Deconstructing LLaMA 2: A PyTorch Deep Dive


Tags: LLM, Transformers, AI
Published: May 21, 2025
Author: Philip Redford
Large Language Models (LLMs) have taken the world by storm, and Meta's LLaMA 2 stands out for its strong capabilities and openly released weights. While many people interact with LLaMA 2 through high-level libraries, really understanding it means looking at its underlying architecture, and in particular at how it is built in PyTorch.
This post will peel back the layers of LLaMA 2, exploring its fundamental components and how they're implemented in PyTorch, giving you a clearer picture of what makes this powerful model tick.

The Foundation: Transformer Architecture

At its heart, LLaMA 2 is an autoregressive, decoder-only Transformer. It is trained to predict the next token in a sequence, and each position can attend only to the tokens that precede it (causal attention). The architecture is a stack of identical Transformer blocks, each of which refines the token representations it receives. The implementation discussed here names this block EncoderBlock, even though it is a decoder-style block with causal self-attention.

Key Components of a LLaMA 2 Encoder Block (in PyTorch)

Each Encoder Block in LLaMA 2 is a carefully engineered sequence of operations designed for efficiency and performance. Let's break down the core modules you'd find in a PyTorch implementation:

1. Input Normalisation: RMSNorm

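LLaMA 2 uses pre-normalisation: the input to each sub-layer (attention and feed-forward) is normalised with RMSNorm (Root Mean Square Normalisation), rather than normalising the sub-layer's output. RMSNorm is a computationally lighter alternative to traditional Layer Normalisation. Instead of both centring and scaling the values, RMSNorm skips the mean subtraction and only rescales each vector by its root mean square, then multiplies the result by a learnable per-feature gain.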
In PyTorch, an RMSNorm module would typically look like this:
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """RMSNorm layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): The epsilon value. Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain, initialised to 1
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt(v) = 1 / sqrt(v); eps avoids division by zero
        # (B, seq_len, dim) * (B, seq_len, 1) -> (B, seq_len, dim)
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in float32, cast back to the input dtype, then apply the gain
        # (dim) * (B, seq_len, dim) -> (B, seq_len, dim)
        return self._norm(x.float()).type_as(x) * self.weight
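To see what skipping the centring means in practice, here is a small functional sketch (the name rms_norm is hypothetical and not part of the module above). It computes RMSNorm(x) = x / sqrt(mean(x²) + eps) · g and checks two properties: with a unit gain every output vector has a root mean square of about 1, and for an input that already has zero mean, RMSNorm gives the same result as PyTorch's LayerNorm without affine parameters. For a shifted input the two differ, because LayerNorm subtracts the mean first.

```python
import torch
import torch.nn.functional as F


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Divide each vector by its root mean square, then apply the per-feature gain
    rms_inv = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms_inv * weight


torch.manual_seed(0)
dim = 8
weight = torch.ones(dim)  # unit gain, as at initialisation
x = torch.randn(2, 3, dim) + 5.0  # shifted input with a clearly non-zero mean

y = rms_norm(x, weight)
# With a unit gain, every output vector has an RMS close to 1
print(y.pow(2).mean(dim=-1).sqrt())

# For zero-mean inputs, RMSNorm matches LayerNorm (no affine parameters)
x_centred = x - x.mean(dim=-1, keepdim=True)
same = torch.allclose(rms_norm(x_centred, weight), F.layer_norm(x_centred, (dim,), eps=1e-6), atol=1e-5)
# For the shifted input they differ, because LayerNorm also removes the mean
different = not torch.allclose(rms_norm(x, weight), F.layer_norm(x, (dim,), eps=1e-6), atol=1e-3)
print(same, different)  # True True
```

Dropping the mean computation is where RMSNorm saves work compared with LayerNorm; the gain (weight) is kept, but there is no bias term.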

2. Self-Attention: Understanding Context with Efficiency

The attention mechanism is crucial for LLMs to understand the relationships between different tokens in a sequence. LLaMA 2's SelfAttention module features several optimisations:
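  • Query, Key, Value Projections: Each token is projected into three different representations: Query (xq), Key (xk), and Value (xv). These are bias-free linear projections of the normalised hidden state. With Grouped Query Attention (described below), the Key and Value projections produce fewer heads than the Query projection.
      # Query projection: one head_dim-sized slice per query head
      self.wq = nn.Linear(args.dim, args.n_heads_q * self.head_dim, bias=False)
      # Key/Value projections: only n_kv_heads heads (fewer than the query heads under GQA)
      self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
      self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
      # Output projection: concatenated heads back to the model dimension
      self.wo = nn.Linear(args.n_heads_q * self.head_dim, args.dim, bias=False)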
  • Rotary Positional Encoding (RoPE): Rather than adding positional embeddings to the token embeddings, RoPE rotates each consecutive pair of features in the Query and Key vectors by an angle proportional to the token's position, with a different frequency for each pair. Because the dot product of a Query rotated for position m and a Key rotated for position n depends only on the offset m - n, attention scores capture relative distance, while each vector on its own still carries its absolute position.
      import torch


      def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str | None = None, theta: float = 10000.0):
          """Precompute the theta and position frequencies.

          Args:
              head_dim (int): The dimension of the head.
              seq_len (int): The length of the sequence.
              device (str | None, optional): The device to run the model on. Defaults to None.
              theta (float, optional): The theta value. Defaults to 10000.0.

          Returns:
              torch.Tensor: The theta and position frequencies, as complex numbers.
          """
          assert head_dim % 2 == 0, "head_dim must be even"
          # Build the theta parameters: theta_i = 10000 ^ (-2i / head_dim) for i = 0 .. head_dim/2 - 1
          # Shape: (head_dim // 2)
          theta_numerator = torch.arange(0, head_dim, 2).float()
          # Shape: (head_dim // 2)
          theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)
          # Build the positions (the m parameter)
          # Shape: (seq_len)
          m = torch.arange(seq_len, device=device)
          # Outer product of positions and frequencies: (seq_len) x (head_dim // 2) -> (seq_len, head_dim // 2)
          freqs = torch.outer(m, theta).float()
          # Convert to complex numbers in polar form: magnitude 1, angle m * theta_i
          # Shape stays (seq_len, head_dim // 2), now complex-valued
          freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
          return freqs_complex


      def apply_rotary_embedding(x: torch.Tensor, freqs_complex: torch.Tensor, device: str | None = None) -> torch.Tensor:
          """Apply the rotary embedding to the input tensor.

          Args:
              x (torch.Tensor): The input tensor.
              freqs_complex (torch.Tensor): The precomputed frequencies.
              device (str | None, optional): The device to run the model on. Defaults to None.

          Returns:
              torch.Tensor: The input tensor with the rotary embedding applied.
          """
          # Group consecutive pairs of features and view each pair as one complex number
          # (B, seq_len, n_heads, head_dim) -> (B, seq_len, n_heads, head_dim // 2, 2) -> (B, seq_len, n_heads, head_dim // 2)
          x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
          # Add batch and head dimensions for broadcasting
          # (seq_len, head_dim // 2) -> (1, seq_len, 1, head_dim // 2)
          freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
          # Complex multiplication rotates each pair by its position-dependent angle
          # (B, seq_len, n_heads, head_dim // 2) * (1, seq_len, 1, head_dim // 2) -> (B, seq_len, n_heads, head_dim // 2)
          x_rotated = x_complex * freqs_complex
          # (B, seq_len, n_heads, head_dim // 2) -> (B, seq_len, n_heads, head_dim // 2, 2)
          x_rotated = torch.view_as_real(x_rotated)
          # (B, seq_len, n_heads, head_dim // 2, 2) -> (B, seq_len, n_heads, head_dim)
          x_out = x_rotated.reshape(*x.shape)
          # Cast back to the input dtype (e.g. float16) and move to the target device
          return x_out.type_as(x).to(device)
      The frequencies are precomputed once, in the model's __init__:
      class Transformer(nn.Module):
          def __init__(self, args: ModelArgs):
              super().__init__()
              ...
              # Precompute the rotation frequencies for positions up to 2 * max_seq_len
              self.freqs_complex = precompute_theta_pos_frequencies(
                  self.args.dim // args.n_heads,
                  args.max_seq_len * 2,
                  device=args.device,
              )
              ...
      Then, in the forward method, the frequencies for the current positions are sliced out and passed to every layer, where the attention module applies them to the queries and keys:
      def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
          ...
          # Retrieve the (m, theta) pairs for the positions [start_pos, start_pos + seq_len)
          # These were precomputed for all positions in __init__
          freqs_complex = self.freqs_complex[start_pos : start_pos + seq_len]
          # Apply the encoder blocks one after another
          for layer in self.layers:
              h = layer(h, start_pos, freqs_complex)
          ...
  • KV-Cache for Efficient Inference: During text generation (inference), LLaMA 2 uses a Key-Value Cache. The Keys and Values of every token already processed are stored, so at each step only the new token's Query, Key, and Value have to be computed. The new Key and Value are written into the cache, and attention is computed against all cached Keys and Values. This avoids recomputing the projections for the entire prefix at every step, which speeds up inference.
      # Within SelfAttention's __init__: one slot per (batch item, position, KV head)
      self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
      self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))

      # In SelfAttention's forward pass: write the new keys/values at their positions
      self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
      self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv

      # Use all cached keys and values (positions 0 .. start_pos + seq_len - 1) for attention
      keys = self.cache_k[:batch_size, : start_pos + seq_len]
      values = self.cache_v[:batch_size, : start_pos + seq_len]
  • Grouped Query Attention (GQA): To reduce memory use at inference time, the larger LLaMA 2 models use GQA (among the released models, the 70B variant uses it, while the 7B and 13B models use standard multi-head attention). GQA is a middle ground between Multi-Head Attention (MHA), where every Query head has its own Key and Value head, and Multi-Query Attention (MQA), where all Query heads share a single Key and Value head. In GQA, the Query heads are split into groups and each group shares one Key and Value head. This shrinks the KV-Cache by a factor of n_heads_q / n_kv_heads and speeds up inference, with little loss in quality.
      def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
          # Helper function to duplicate KV heads for GQA
          # n_rep: number of query heads that share each KV head
          ...
          # (B, seq_len, n_kv_heads, head_dim) -> (B, seq_len, n_kv_heads, n_rep, head_dim)
          # -> (B, seq_len, n_kv_heads * n_rep, head_dim)
          return (
              x[:, :, :, None, :]
              .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
              .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
          )

      # Applied to keys and values in SelfAttention's forward
      keys = repeat_kv(keys, self.n_rep)
      values = repeat_kv(values, self.n_rep)
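  • Feed-Forward Network: SwiGLU Activation: After self-attention, the FeedForward network processes each position independently. LLaMA 2 uses SwiGLU, a gated linear unit that uses SiLU (Sigmoid Linear Unit, also called Swish) as its gate: the input goes through two parallel linear projections, one is passed through SiLU and multiplied element-wise with the other, and a third projection maps the result back to the model dimension. The theoretical reasons for its advantage are not well understood, but SwiGLU has consistently shown strong empirical performance in LLMs.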
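The relative-position property described in the RoPE bullet can be checked numerically. The sketch below is a simplified, single-vector version of the two RoPE functions above (the names rope_freqs, rotate, and score are hypothetical): it rotates a query as if it sat at position m and a key as if it sat at position n, and shows that their dot product depends only on the offset m - n, and that rotation leaves vector lengths unchanged.

```python
import torch


def rope_freqs(head_dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # theta_i = theta ** (-2i / head_dim) for i = 0 .. head_dim/2 - 1
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)  # unit complex numbers e^(i * m * theta_i)


def rotate(x: torch.Tensor, freqs_complex: torch.Tensor) -> torch.Tensor:
    # View consecutive feature pairs as complex numbers and rotate each pair
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_complex * freqs_complex).reshape(x.shape)


head_dim, max_len = 8, 32
freqs = rope_freqs(head_dim, max_len)
torch.manual_seed(0)
q, k = torch.randn(head_dim), torch.randn(head_dim)


def score(m: int, n: int) -> torch.Tensor:
    # Query placed at position m, key placed at position n
    return torch.dot(rotate(q, freqs[m]), rotate(k, freqs[n]))


# Same offset (m - n = 3) at different absolute positions gives the same score
print(score(5, 2).item(), score(20, 17).item())
# A different offset generally gives a different score
print(score(5, 1).item())
```

This is why RoPE needs no separate relative-position bias: the rotation applied to each vector depends on its absolute position, and the relative offset falls out of the dot product.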
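The KV-Cache bullet says caching changes the cost of generation, not its result. A minimal sketch under simplifying assumptions (a single attention head, one sequence, random projection matrices standing in for trained weights, and hypothetical names such as full_causal_attention) compares causal attention recomputed over the whole sequence with incremental decoding that computes one token's Query, Key, and Value per step and reads earlier Keys and Values from a cache.

```python
import torch

torch.manual_seed(0)
dim, seq_len = 16, 6
# Random projection matrices stand in for trained wq / wk / wv
wq, wk, wv = (torch.randn(dim, dim) / dim**0.5 for _ in range(3))
x = torch.randn(seq_len, dim)  # hidden states for one sequence


def full_causal_attention(x: torch.Tensor) -> torch.Tensor:
    # Recompute Q, K, V for every position and mask out future tokens
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = (q @ k.T) / dim**0.5
    n = x.shape[0]
    future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


# Incremental decoding: one token per step, keys/values kept in a cache
cache_k = torch.zeros(seq_len, dim)
cache_v = torch.zeros(seq_len, dim)
outputs = []
for pos in range(seq_len):
    x_t = x[pos : pos + 1]  # (1, dim): only the newest token
    cache_k[pos : pos + 1] = x_t @ wk  # compute and store its key ...
    cache_v[pos : pos + 1] = x_t @ wv  # ... and its value
    q_t = x_t @ wq
    keys, values = cache_k[: pos + 1], cache_v[: pos + 1]  # all tokens so far
    scores = (q_t @ keys.T) / dim**0.5  # no mask needed: the cache holds no future tokens
    outputs.append(torch.softmax(scores, dim=-1) @ values)

incremental = torch.cat(outputs, dim=0)
print(torch.allclose(incremental, full_causal_attention(x), atol=1e-5))  # True
```

Each step in the loop does one row of projection work instead of pos + 1 rows, at the price of keeping the cache in memory.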
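The effect of repeat_kv from the GQA bullet can be checked with toy sizes. The sketch below makes that helper self-contained; the shape unpacking and the n_rep == 1 shortcut at the top are assumptions about what the elided part of the original does. It shows which KV head each Query head ends up using when 8 Query heads share 2 KV heads (two groups of 4).

```python
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: (batch_size, seq_len, n_kv_heads, head_dim)
    batch_size, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x  # plain multi-head attention: nothing to repeat
    return (
        x[:, :, :, None, :]  # add a repeat axis
        .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)  # broadcast view, no copy
        .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)  # materialise the repeated heads
    )


# Toy sizes (hypothetical): 8 query heads share 2 KV heads -> n_rep = 4
n_heads_q, n_kv_heads, head_dim = 8, 2, 4
keys = torch.randn(1, 3, n_kv_heads, head_dim)
expanded = repeat_kv(keys, n_heads_q // n_kv_heads)
print(expanded.shape)  # torch.Size([1, 3, 8, 4])

# Query heads 0-3 use KV head 0, query heads 4-7 use KV head 1
print(torch.equal(expanded[:, :, 3], keys[:, :, 0]), torch.equal(expanded[:, :, 4], keys[:, :, 1]))  # True True
```

In the attention code above, the repetition happens after the keys and values are read from the cache, so the cache itself only stores the n_kv_heads heads; that is where the memory saving comes from.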
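The EncoderBlock below constructs FeedForward(args), but the module itself is not shown in this post. Here is a minimal sketch of a SwiGLU feed-forward network. To stay self-contained it takes dim and hidden_dim directly rather than a ModelArgs object, and the weight names w1, w2, w3 are assumptions. The hypothetical helper llama_hidden_dim applies a sizing rule common in LLaMA implementations: start from 4 × dim, scale by 2/3 because SwiGLU uses three weight matrices instead of two, then round up to a multiple of 256. For the 7B model's dim of 4096 this gives 11008.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def llama_hidden_dim(dim: int, multiple_of: int = 256) -> int:
    # 4 * dim, shrunk to 2/3, then rounded up to a multiple of `multiple_of`
    hidden_dim = int(2 * (4 * dim) / 3)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # back to the model dimension

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: SiLU(x W1) gates (x W3) element-wise, then W2 projects back
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


ffn = FeedForward(dim=64, hidden_dim=llama_hidden_dim(64, multiple_of=32))
x = torch.randn(2, 5, 64)  # (B, seq_len, dim)
print(ffn(x).shape)  # torch.Size([2, 5, 64])
```

The 2/3 scaling keeps the parameter count close to that of a conventional two-matrix feed-forward layer with a 4 × dim hidden size.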

The Overall Encoder Block Structure

Bringing it all together, a PyTorch EncoderBlock combines these elements:
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # Initialise SelfAttention and FeedForward modules
        self.attention = SelfAttention(args)
        self.feed_forward = FeedForward(args)
        # RMSNorm applied before attention and before the feed-forward network
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor) -> torch.Tensor:
        # Normalise, attend, then add the residual connection
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
        # Normalise, apply the feed-forward network, then add the residual connection
        h = h + self.feed_forward(self.ffn_norm(h))
        return h
This sequential application of normalisation, attention, and feed-forward operations, coupled with residual connections, forms the backbone of LLaMA 2's processing capability.

Building the Full LLaMA 2 Model

A complete LLaMA 2 model in PyTorch would consist of:
  1. Token Embeddings: An nn.Embedding layer that converts input token IDs into dense vectors.
  2. Stacked Encoder Blocks: Multiple instances of the EncoderBlock defined above: 32 layers for the 7B model, 40 for 13B, and 80 for 70B.
  3. Final RMSNorm Layer: Another RMSNorm applied after the last encoder block.
  4. Language Model Head: A linear layer that projects the output of the final RMSNorm to the vocabulary size, producing logits (raw prediction scores for each possible next token).
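One way to sanity-check the component list above is to count parameters. The sketch below is back-of-the-envelope arithmetic assuming the published LLaMA 2 7B configuration (dim 4096, 32 layers, 32 heads with no GQA, feed-forward hidden size 11008, a 32,000-token vocabulary, and separate weights for the embedding and the LM head); the function name is hypothetical, and the 70B values in the second call are likewise assumed from the public configuration. There are no biases, so each linear layer contributes in_features × out_features weights and each RMSNorm contributes dim.

```python
def llama_param_count(dim: int, n_layers: int, n_heads: int, n_kv_heads: int,
                      ffn_hidden: int, vocab_size: int) -> int:
    head_dim = dim // n_heads
    # Attention: wq and wo are dim x dim; wk and wv project to n_kv_heads heads only
    attention = 2 * dim * dim + 2 * dim * (n_kv_heads * head_dim)
    # SwiGLU feed-forward: w1 and w3 (dim -> hidden) plus w2 (hidden -> dim)
    feed_forward = 3 * dim * ffn_hidden
    # Two RMSNorm gains per block
    norms = 2 * dim
    per_layer = attention + feed_forward + norms
    embedding = vocab_size * dim  # token embedding table
    lm_head = vocab_size * dim  # output projection, not tied to the embedding
    final_norm = dim
    return embedding + n_layers * per_layer + final_norm + lm_head


total_7b = llama_param_count(dim=4096, n_layers=32, n_heads=32, n_kv_heads=32,
                             ffn_hidden=11008, vocab_size=32000)
print(f"{total_7b:,}")  # 6,738,415,616 (about 6.7B, hence "7B")

# With GQA (8 KV heads for 64 query heads, as in the 70B configuration),
# wk and wv are 8x smaller than under full multi-head attention
total_70b = llama_param_count(dim=8192, n_layers=80, n_heads=64, n_kv_heads=8,
                              ffn_hidden=28672, vocab_size=32000)
print(f"{total_70b:,}")
```

The feed-forward matrices dominate: in the 7B configuration they account for roughly two thirds of each block's weights.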

The Inference Loop

Generating text with LLaMA 2 is an iterative process that produces one token at a time. The start_pos parameter in the forward method is crucial here: it tells the model the position of the incoming token(s), which determines both which RoPE frequencies are used and where the new Keys and Values are written in the KV-Cache. In each step, the model processes the newly generated token along with the cached Keys and Values from previous steps. The output logits are then turned into a probability distribution from which the next token is chosen, often using strategies such as Top-P sampling together with a temperature setting that controls randomness.
We can use a few strategies for sampling from this distribution:
  • Greedy: Always choose the most probable next token. This is deterministic and tends to produce repetitive, bland text in practice.
  • Beam Search: Keep the k highest-scoring partial sequences (beams). At each step, extend every beam with candidate next tokens, score each extension by its cumulative log-probability, and keep the top k. The highest-scoring complete sequence is returned. This increases inference time but often finds higher-probability outputs than greedy decoding.
  • Random Sampling: Sample from the full distribution over the vocabulary according to the predicted probabilities. Because every token has some probability, very unlikely (and poor) tokens are occasionally selected.
  • Top-K: Keep only the k most likely tokens, renormalise their probabilities, and sample from them. This can still select poor tokens: for example, with k=10, if the two most likely tokens have probabilities of 80% and 19%, the other 8 tokens together hold at most 1% and are clearly bad fits, yet one of them may still be chosen.
  • Top-P (nucleus sampling): Sort the tokens by probability and keep the smallest set whose cumulative probability reaches the threshold p, then renormalise the probabilities within that set and sample from it. The number of candidates therefore adapts to how confident the model is.
Temperature controls how confident the predictions are: the logits are divided by the temperature T before the softmax. A low temperature (T < 1) sharpens the distribution, concentrating probability on the most likely tokens; a high temperature (T > 1) flattens it, making the output more random.
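Temperature, Top-K, and Top-P can be combined in a single sampling step. The function below is a minimal sketch; its name and signature are hypothetical, not taken from the repo's inference code. It divides the logits by the temperature, falls back to greedy decoding when the temperature is 0, optionally keeps only the top k tokens, then keeps the smallest set of tokens whose cumulative probability reaches top_p, renormalises, and samples. The usage example reuses the 80% / 19% distribution from the Top-K bullet.

```python
from typing import Optional

import torch


def sample_next_token(logits: torch.Tensor, temperature: float = 1.0,
                      top_p: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
    """logits: (batch, vocab_size) -> token ids of shape (batch,)."""
    if temperature == 0:
        return torch.argmax(logits, dim=-1)  # greedy decoding
    probs = torch.softmax(logits / temperature, dim=-1)
    # Sort so cumulative sums run from the most to the least likely token
    probs_sorted, idx_sorted = torch.sort(probs, dim=-1, descending=True)
    if top_k is not None:
        probs_sorted[..., top_k:] = 0.0  # Top-K: drop everything after the k-th token
    # Top-P: drop a token if the tokens ranked above it already exceed top_p
    cumulative = torch.cumsum(probs_sorted, dim=-1)
    probs_sorted[cumulative - probs_sorted > top_p] = 0.0
    probs_sorted = probs_sorted / probs_sorted.sum(dim=-1, keepdim=True)  # renormalise
    choice = torch.multinomial(probs_sorted, num_samples=1)  # position in sorted order
    return torch.gather(idx_sorted, -1, choice).squeeze(-1)  # map back to token ids


# Two tokens hold 80% and 19% of the probability; ten others share the last 1%
probs = torch.tensor([[0.80, 0.19] + [0.001] * 10])
logits = torch.log(probs)
print(sample_next_token(logits, temperature=0))  # tensor([0])
print(sample_next_token(logits, temperature=1.0, top_p=0.9))  # token 0 or 1, never the tail
```

With top_p=0.9 only the first two tokens survive, so the low-probability tail that Top-K with k=10 would still allow is never sampled. With top_p=1.0 and top_k=None the function reduces to plain random sampling at the given temperature.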
The code for inference, and the rest of the model code, can be found in the GitHub repo; it is not included here to save space.

Conclusion

Building LLaMA 2 in pure PyTorch offers invaluable insights into the intricacies of modern LLM design. From the efficient RMSNorm to the clever RoPE, and the performance-boosting KV-Cache and Grouped Query Attention, each component plays a vital role in enabling LLaMA 2's impressive capabilities. Understanding these PyTorch implementations provides a solid foundation for anyone looking to build upon or customise these state-of-the-art models.