Implementing LLaMA3 in 100 Lines of Pure Jax

In this post, we'll implement LLaMA 3 from scratch using pure JAX in just 100 lines of code. Why JAX? Because I think it has good aesthetics. Also, JAX looks like a NumPy wrapper, but it has some cool features like XLA (the Accelerated Linear Algebra compiler), jit, vmap, pmap, etc., which make your training go brr brr.

JAX is one of the first libraries to strongly embrace the soul of pure functional programming, which makes it even cooler.1

Note:

  • This post assumes familiarity with Python and basic understanding of transformer architectures.
  • This implementation is for educational purposes; it is not production-ready, but it covers all components of the model.2
  • If you don't wanna read this amazing blog post, you can check out all the code here.3
Llama architecture

LLaMA3

At its core, LLaMA 3 is a decoder-only transformer language model that generates text one token at a time, building on previous tokens to predict what comes next, like completing a sentence word by word.

So let's fucking go!! We're doing it, get your Diet Coke!! First, we'll set up the device5 and configure the model.

import os

# Configure JAX to use GPU and prevent memory preallocation
# (set these before JAX initializes its backend)
os.environ['JAX_PLATFORM_NAME'] = 'gpu'
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
print("JAX devices:", jax.devices())

These are the hyperparameters we need to train a model with approximately 2 million parameters.

# Define model hyperparameters
args = ModelArgs(
    vocab_size=enc.n_vocab,    # Size of vocabulary
    dim=256,                # Embedding dimension
    n_layers=6,            # Number of transformer layers
    n_heads=8,             # Number of attention heads
    n_kv_heads=4,          # Number of key-value heads for GQA
    max_seq_len=512,       # Maximum sequence length
    norm_eps=1e-5          # Normalization epsilon
)
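
The ModelArgs container isn't defined anywhere in the post; a minimal sketch, assuming it's just a plain dataclass holding these hyperparameters (plus the batch_size and learning_rate that the training code later reads from the same object, which it refers to as config), could look like this:

from dataclasses import dataclass

@dataclass
class ModelArgs:
    vocab_size: int
    dim: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    max_seq_len: int
    norm_eps: float = 1e-5
    # Training settings used later in the post (assumed values)
    batch_size: int = 32
    learning_rate: float = 3e-4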

Model Weights Initialization

In pure JAX, we don't use classes like in PyTorch. We use only pure functions. Why? Because it makes our code more predictable and easier to parallelize. A pure function always returns the same output for the same input and doesn't cause any side effects.6 For example, if you call F(x), you'll always get the same y.

Since we aren’t using a framework like PyTorch’s nn.Module to automatically track parameters, we must initialize and update our weights manually.

Handling randomness is also different. Instead of relying on a single global seed as in NumPy or PyTorch, in JAX we manage randomness with explicit pseudo-random number generator (PRNG) keys. Each random operation gets its own unique key, which is derived by splitting a parent key. This helps with reproducibility and parallelism.

For example, below we create a key, split it into subkeys, and then pass the subkey to any function that involves randomness.

# Generate and split random keys for reproducibility
key = jax.random.PRNGKey(42)

# Create a new subkey for random operations
key, subkey = jax.random.split(key)

# Initialize random weights using the subkey
weights = jax.random.normal(subkey, (784, 512))

Now let's start with our model weights initialization. First, we create random values for our parameters from a normal distribution.

import math

# Initialize weights with optional scaling
def init_weight(key, shape, scale=None):
    # Calculate default scale if none provided (1/sqrt(fan_in))
    scale = 1.0 / math.sqrt(shape[0]) if scale is None else scale
    # Return scaled normal distribution
    return jax.random.normal(key, shape) * scale
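
For example (a quick illustrative check, not from the original post), the default scale gives weights with a standard deviation of roughly 1/sqrt(fan_in):

key = jax.random.PRNGKey(0)
w = init_weight(key, (256, 1024))   # default scale = 1 / sqrt(256) = 0.0625
print(w.shape, float(w.std()))      # (256, 1024), std ≈ 0.0625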

Next, we'll identify all the learnable parameters of our model (LLaMA 3), assign each a unique key to ensure reproducibility, and apply the initialization to them.

Since weights are essentially numbers stored in arrays, we can use dictionaries to manage them as key-value pairs.

First, we start with the attention module, which has four trainable weight matrices.

# Initialize attention weights for multi-head attention
def init_attention_weights(key, dim, n_heads, n_kv_heads):
    # Split key for each weight matrix
    keys = jax.random.split(key, 4)
    head_dim = dim // n_heads
    # Return dictionary of weight matrices
    return {
        'wq': init_weight(keys[0], (dim, n_heads * head_dim)),     # Query weights
        'wk': init_weight(keys[1], (dim, n_kv_heads * head_dim)),  # Key weights
        'wv': init_weight(keys[2], (dim, n_kv_heads * head_dim)),  # Value weights
        'wo': init_weight(keys[3], (n_heads * head_dim, dim))      # Output projection
    }

Next, we have our feed-forward network, which has three trainable weight matrices.

# Initialize feed-forward network weights
def init_ffn_weights(key, dim):
    # Split key into three for each weight matrix
    keys = jax.random.split(key, 3)
    return {
        'w1': init_weight(keys[0], (dim, 4 * dim)),  # First projection
        'w2': init_weight(keys[1], (4 * dim, dim)),  # Output projection
        'w3': init_weight(keys[2], (dim, 4 * dim))   # Gate projection
    }

Then we combine these weights into a transformer block, adding two extra parameter vectors for the two RMSNorm layers.

# Initialize a complete transformer block
def init_transformer_block(key, dim, n_heads, n_kv_heads):
    # Split key for each component
    keys = jax.random.split(key, 4)
    return {
        'attention': init_attention_weights(keys[0], dim, n_heads, n_kv_heads),  # Self-attention
        'ffn': init_ffn_weights(keys[1], dim),                                   # Feed-forward network
        'attention_norm': init_weight(keys[2], (dim,), scale=1.0),               # Pre-attention norm
        'ffn_norm': init_weight(keys[3], (dim,), scale=1.0)                      # Pre-ffn norm
    }

Finally, we assemble the whole model's weight initialization in one place.

# Initialize complete model parameters
def init_model_params(key, vocab_size, dim, n_layers, n_heads, n_kv_heads):
    # Split keys for different components
    keys = jax.random.split(key, 4)
    params = {
        'token_embedding': init_weight(keys[0], (vocab_size, dim)),  # Token embeddings
        'norm_f': init_weight(keys[1], (dim,), scale=1.0),  # Final normalization
        'output': init_weight(keys[2], (dim, vocab_size))  # Output projection
    }
    # Initialize transformer blocks
    block_keys = jax.random.split(keys[3], n_layers)
    params['blocks'] = [
        init_transformer_block(k, dim, n_heads, n_kv_heads)
        for k in block_keys
    ]
    return params
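
As a quick sanity check (not from the original post), you can initialize the parameters with the hyperparameters above and count them with jax.tree_util.tree_leaves:

key = jax.random.PRNGKey(42)
params = init_model_params(
    key,
    vocab_size=args.vocab_size,
    dim=args.dim,
    n_layers=args.n_layers,
    n_heads=args.n_heads,
    n_kv_heads=args.n_kv_heads,
)
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"total parameters: {n_params:,}")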

Tokenization

Tokenization means dividing text into words and subwords (tokens). We will use Byte Pair Encoding (BPE) to tokenize our data (BPE was also used for training Llama 3).7 I will not build BPE from scratch; we will use OpenAI's tiktoken library for it.

import jax.numpy as jnp
import tiktoken

# Load GPT-2 BPE encoding
enc = tiktoken.get_encoding("gpt2")


# Read the first line from the Shakespeare dataset
with open('../shakespeare.txt', 'r') as f:
    text = f.readlines()[0]  # Take the first line

# Encode the text into token IDs
tokens = enc.encode(text)
data = jnp.array(tokens, dtype=jnp.int32)  # Store as JAX array

# Decode back to text
decoded_text = enc.decode(tokens)

print("original Text:", text.strip())
print("encoded Tokens:", tokens)
print("decoded Text:", decoded_text)

## Output ##

# Original Text: From fairest creatures we desire increase,
# Encoded Tokens: [220, 3574, 37063, 301, 8109, 356, 6227, 2620, 11, 198]
# Decoded Text:   From fairest creatures we desire increase,

Embeddings

We cannot feed tokens directly to a model because tokens are discrete, while neural networks operate on continuous numerical data; this is what makes the mathematical operations possible. Therefore, we use an embedding layer to convert the discrete tokens into a continuous vector space. These embeddings also help encode the semantic and syntactic relationships between tokens.

Llama architecture

There are two types of embeddings: static and dynamic. We use dynamic embeddings to train LLMs. Why? Static embeddings work well for finding similarities between words and placing them close together in vector space, as seen in the first image.

However, they suffer from semantic ambiguity, as shown in the second image: the same word always gets the same vector regardless of context. This is where self-attention comes in; it refines these embeddings to incorporate context. So, we start with random embeddings and update them according to the context.

# Converting the input tokens into embeddings

h = params["token_embedding"][inputs]

# token_embedding is a matrix of shape (vocab_size, dim).
# inputs are token IDs (integers) of shape (batch, seq_len).
# h is the resulting embedding tensor of shape (batch, seq_len, dim).

Root Mean Square Layer Normalization

RMS normalization is an important layer in LLaMA 3 models. It helps keep training stable by making sure the activations flowing through the network don't blow up or shrink to nothing. This balance is very important, especially in deep networks.

Llama architecture
# RMS Norm function for stabilizing training
def rms_norm(x, weight, eps=1e-5):
    # Mean of squared values over the last dimension (the "RMS" statistic)
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    # Normalize and scale
    return x * weight * jnp.reciprocal(jnp.sqrt(variance + eps))
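
For reference, this is the standard RMSNorm formula the code above implements (notation mine, not from the original post):

$$\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot w$$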

Rotary Positional Encoding

Transformers don't naturally know the order of tokens, so we need to add position information. LLaMA 3 solves this with RoPE (Rotary Position Embedding), which "rotates" the query and key vectors based on their position.8

Llama architecture

How It Works:

Precompute Rotation Factors: First, we create a table of rotation factors from a range of frequencies, so each position gets its own unique set of rotation angles.

# Compute rotary position embeddings
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Generate frequency bands
    freqs = 1.0 / (theta ** (jnp.arange(0, dim // 2, dtype=jnp.float32) / dim))
    # Generate position indices
    t = jnp.arange(end, dtype=jnp.float32)
    # Compute outer product
    freqs = jnp.outer(t, freqs)
    # Convert to complex exponential
    return jnp.complex64(jnp.exp(1j * freqs))
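
For example (illustrative, using the hyperparameters above: head_dim = 256 // 8 = 32 and max_seq_len = 512), the table has one row per position and one complex rotation factor per feature pair:

freqs_cis = precompute_freqs_cis(dim=256 // 8, end=512)
print(freqs_cis.shape)   # (512, 16): 512 positions x 16 complex rotation factors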

Apply the Rotation:

Pair Up Features: we reshape the vectors so that every two numbers form a pair; imagine them as the real and imaginary parts of a complex number.

Rotate: We multiply these complex numbers by our precomputed rotation factors. This rotates each pair in the complex plane.

Convert Back: Finally, we split the rotated complex numbers back into their real and imaginary parts to restore the original shape.

Math Behind It: For each pair $(x_{2i}, x_{2i+1})$, the rotation is given by:

$$\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$$

where $\theta_i$ is the rotation angle for that position. In short, RoPE embeds positional information directly into the token features by rotating them. This way the attention module gets information about token order without extra position vectors.

# Apply rotary embeddings to queries and keys
def apply_rotary_emb(xq, xk, freqs_cis):
    # Reshape inputs for complex multiplication
    xq_r = jnp.reshape(xq, (*xq.shape[:-1], -1, 2))
    xk_r = jnp.reshape(xk, (*xk.shape[:-1], -1, 2))
    
    # Convert to complex numbers
    xq_complex = jnp.complex64(xq_r[..., 0] + 1j * xq_r[..., 1])
    xk_complex = jnp.complex64(xk_r[..., 0] + 1j * xk_r[..., 1])
    
    # Reshape frequencies for broadcasting
    freqs_cis = jnp.reshape(freqs_cis, (1, freqs_cis.shape[0], 1, freqs_cis.shape[1]))
    
    # Apply rotation through complex multiplication
    xq_out = xq_complex * freqs_cis
    xk_out = xk_complex * freqs_cis
    
    # Convert back to real numbers and reshape
    xq = jnp.stack([jnp.real(xq_out), jnp.imag(xq_out)], axis=-1).reshape(xq.shape)
    xk = jnp.stack([jnp.real(xk_out), jnp.imag(xk_out)], axis=-1).reshape(xk.shape)
    
    return xq, xk

Group-Query Attention

Now it's time for attention. Grouped-Query Attention (GQA) is an optimized version of multi-head attention that improves efficiency by sharing key and value representations among multiple query heads. This reduces computational overhead and memory usage, enabling faster inference and better scaling for transformer models. At its core, it's just self-attention with a few modifications.

Scaled Dot-Product Attention:

$$A = \mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_h}}\right)V$$

# Attention mechanism with grouped-query attention
def attention(params, x, mask, freqs_cis, n_heads, n_kv_heads, cache=None, position=0):
    # Get input dimensions
    B, T, C = x.shape
    head_dim = C // n_heads
    
    # Project inputs to queries, keys, and values
    q = jnp.dot(x, params['wq']).reshape(B, T, n_heads, head_dim)
    k = jnp.dot(x, params['wk']).reshape(B, T, n_kv_heads, head_dim)
    v = jnp.dot(x, params['wv']).reshape(B, T, n_kv_heads, head_dim)
    
    # Apply rotary embeddings
    q, k = apply_rotary_emb(q, k, freqs_cis[position:position + T])
    
    # Handle KV-cache for inference: append new keys/values along the sequence axis
    if cache is not None:
        k = jnp.concatenate([cache[0], k], axis=1)
        v = jnp.concatenate([cache[1], v], axis=1)
    new_cache = (k, v)
    
    # Repeat k/v heads for grouped-query attention
    k = repeat_kv(k, n_heads // n_kv_heads)
    v = repeat_kv(v, n_heads // n_kv_heads)
    
    # Compute attention scores and apply attention
    q, k, v = map(lambda x: x.transpose(0, 2, 1, 3), (q, k, v))
    scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / math.sqrt(head_dim)
    
    # Apply attention mask if provided
    if mask is not None:
        scores = scores + mask[:, :, :T, :T]
    
    # Compute attention weights and final output
    scores = jax.nn.softmax(scores, axis=-1)
    output = jnp.matmul(scores, v)
    output = output.transpose(0, 2, 1, 3).reshape(B, T, -1)
    
    return jnp.dot(output, params['wo']), new_cache
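
One thing the snippet above relies on but never defines is repeat_kv, which duplicates each key/value head so it can be shared by a group of query heads. A minimal sketch (my implementation, matching the (B, T, n_kv_heads, head_dim) layout used above) could be:

# Repeat each kv head n_rep times so kv heads line up with query heads
def repeat_kv(x, n_rep):
    if n_rep == 1:
        return x
    B, T, n_kv_heads, head_dim = x.shape
    # (B, T, n_kv_heads, head_dim) -> (B, T, n_kv_heads, n_rep, head_dim)
    x = jnp.repeat(x[:, :, :, None, :], n_rep, axis=3)
    # Merge the kv-head and repeat axes -> (B, T, n_kv_heads * n_rep, head_dim)
    return x.reshape(B, T, n_kv_heads * n_rep, head_dim)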

KV-cache: it stores the key (K) and value (V) tensors computed for past tokens, so during inference each new token only needs to compute fresh keys and values for itself instead of recomputing them for the whole sequence.

Llama architecture
if cache is not None:
    k = jnp.concatenate([cache[0], k], axis=1)  # Append new keys along the sequence axis
    v = jnp.concatenate([cache[1], v], axis=1)  # Append new values along the sequence axis
new_cache = (k, v)  # Create new cache with updated k/v pairs

Feed-forward

This is a simple feed-forward network with a SwiGLU (SiLU-gated) activation.

def feed_forward(params, x):
    # Gate branch: project with w3
    w3_ = jnp.dot(x, params['w3'])

    # SwiGLU(a, b) = SiLU(a) ⊙ b
    activated = jax.nn.silu(w3_)

    # Linear branch: project with w1
    w1_ = jnp.dot(x, params['w1'])

    # Element-wise gating
    combined = activated * w1_

    # Final output projection with w2
    output = jnp.dot(combined, params['w2'])

    return output
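
Written out with the weight names from the code above, the block computes:

$$\mathrm{FFN}(x) = \big(\mathrm{SiLU}(xW_3) \odot xW_1\big)\, W_2$$

(Note that w1 and w3 play swapped roles compared to Meta's reference implementation; since both project dim → 4·dim, the two are equivalent.)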

Transformer-block

This is where all the important components come together in the transformer block. We unpack the pre-initialized weights and distribute them to their respective layers. The transformer blocks include attention, normalization, feed-forward processing layers and residual connections.

# Transformer block implementation
def transformer_block(params, x, mask, freqs_cis, n_heads, n_kv_heads, cache=None, position=0):
    # Apply attention with normalization and residual connection
    attn_output, new_cache = attention(
        params['attention'],
        rms_norm(x, params['attention_norm']),
        mask,
        freqs_cis,
        n_heads,
        n_kv_heads,
        cache,
        position
    )
    
    # First residual connection
    h = x + attn_output
    
    # Apply feed-forward network with normalization and residual
    ffn_output = feed_forward(params['ffn'], rms_norm(h, params['ffn_norm']))
    
    # Second residual connection
    out = h + ffn_output
    
    return out, new_cache

Forward-Pass

The forward pass takes your data through the entire model from converting input tokens into embeddings, through a series of transformer blocks, and finally to the output layer. In other words, it connects all the layers (embedding, transformers, and output) to produce the final predictions.

# Forward pass through the entire model
def model_forward(params, inputs, config, cache=None, position=0):
    # Get batch dimensions
    B, T = inputs.shape
    
    # Convert input tokens to embeddings
    h = params['token_embedding'][inputs]
    
    # Compute freqs_cis for this forward pass
    freqs_cis = precompute_freqs_cis(config.dim // config.n_heads, config.max_seq_len)
    
    # Create causal mask
    mask = jnp.tril(jnp.ones((config.max_seq_len, config.max_seq_len)))
    mask = jnp.where(mask == 0, -1e9, 0.0)
    mask = mask.astype(h.dtype)
    mask = mask[None, None, :, :]

    # Process through transformer blocks
    new_caches = []
    for i, block in enumerate(params['blocks']):
        layer_cache = cache[i] if cache is not None else None
        h, layer_cache = transformer_block(
            block, h, mask, freqs_cis,
            config.n_heads, config.n_kv_heads,
            layer_cache, position)
        new_caches.append(layer_cache)

    # Final normalization and output projection
    h = rms_norm(h, params['norm_f'])
    logits = jnp.dot(h, params['output'])
    
    return logits, new_caches
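
To see it end to end (an illustrative check, assuming the params and args objects defined earlier), a single forward pass on a dummy batch produces one row of logits per input token:

dummy = jnp.ones((1, 16), dtype=jnp.int32)   # batch of 1, sequence of 16 token IDs
logits, caches = model_forward(params, dummy, args)
print(logits.shape)                          # (1, 16, vocab_size)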

Dataset

Now that the model part is complete, it's time to train it on the Shakespeare dataset. First we read our data from a .txt file, then we encode it with BPE and convert it into a JAX array.

# Initialize tokenizer and load data
enc = tiktoken.get_encoding("gpt2")

# Read text file
with open('shakespeare.txt', 'r') as f:
    text = f.read()

# Convert text to token IDs
tokens = enc.encode(text)
# Convert to JAX array
data = jnp.array(tokens)

Get Batches

The get_batch function creates training batches from our Shakespeare dataset. We need to feed our model with chunks of data. For each batch, we randomly select starting positions in the text, this way the model sees a variety of contexts.

Now, here's where JAX's cool vmap feature comes into play. Instead of writing a loop to extract each chunk, we use vmap to automate it.

How does it work ?

vmap is like a vectorized loop; it takes a function that processes a single index (using lax.dynamic_slice to get a sequence of tokens) and applies it to every element in our array of indices. This means our input sequences (x) and corresponding target sequences (y, which are shifted by one token for next-word prediction) are created in one go.

from jax import lax, random, vmap

def get_batch(key, data, batch_size, seq_len):
    # Generate random starting indices
    ix = random.randint(key, (batch_size,), 0, len(data) - seq_len)
    
    # Vectorized operation to get input and target sequences
    x = vmap(lambda i: lax.dynamic_slice(data, (i,), (seq_len,)))(ix)
    y = vmap(lambda i: lax.dynamic_slice(data, (i + 1,), (seq_len,)))(ix)
    
    return x, y
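
For example (illustrative, assuming the tokenized data array from above and a batch size of 32):

key = jax.random.PRNGKey(0)
x, y = get_batch(key, data, batch_size=32, seq_len=512)
print(x.shape, y.shape)   # (32, 512) (32, 512); y is x shifted by one token (the next-token targets)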

Loss Function

This function computes the cross-entropy loss for a batch during training. It first performs a forward pass using the model to generate logits for the input data. Then, it reshapes both the logits and targets to merge the batch and sequence dimensions. After applying the log softmax to the logits, it extracts the log probabilities corresponding to the correct target tokens and computes their negative mean as the final loss value.

The cross-entropy loss is defined as:

$$L = -\frac{1}{N}\sum_{i=1}^{N} \log P(y_i)$$

Where:

  • $P(y_i)$ is the probability of the correct class, calculated using the softmax function:

$$P(y_i) = \frac{e^{z_{y_i}}}{\sum_{j} e^{z_j}}$$

# Compute cross-entropy loss
def compute_loss(params, batch):
    # Split batch into inputs and targets
    inputs, targets = batch
    # Forward pass to get logits (ignore the returned caches)
    logits, _ = model_forward(params, inputs, config)
    # Reshape for loss computation
    logits = logits.reshape(-1, config.vocab_size)
    targets = targets.reshape(-1)
    # Calculate negative log likelihood of the correct tokens
    log_probs = jax.nn.log_softmax(logits)
    loss = -jnp.mean(jnp.take_along_axis(log_probs, targets[:, None], axis=1))
    return loss
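
One quick sanity check (not from the original post): with freshly initialized weights the model's output is roughly uniform over the vocabulary, so the very first loss values should sit near ln(vocab_size):

import math
print(math.log(50257))   # ≈ 10.82, so early losses around 10-11 mean the wiring is correct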

Update function

Now we need to write a function to update our weights. For simplicity, we're using Stochastic Gradient Descent (SGD) here, though you can also use Adam or AdamW for faster convergence.

In the code, you'll notice the @jax.jit decorator. This is one of the features that sets jax apart. JIT (Just-In-Time) compilation speeds up execution by converting your Python code into optimized machine code.

How does it work ?

When you decorate a function with JAX’s jit, it changes how the function executes. Normally, when you call a function, Python runs it line by line. For example, if you have:

def sqr(x): 
    print("HI jiited") # side effect 
    return x * x

print(sqr(2)) 
print(sqr(3)) 
print(sqr(4))

Every time you call sqr, it prints "HI jitted" and then returns the square of the number. However, when you add the @jax.jit decorator:

@jax.jit
def sqr(x): 
    print("HI jiited") # side effect  
    return x * x

print(sqr(2)) 
print(sqr(3)) 
print(sqr(4))

JAX first traces your function to build an optimized computation graph. This tracing happens the first time the function is called (for each new input shape and dtype), and the traced graph is then compiled by XLA into optimized machine code.

Because of this tracing, side effects like the print statement only run during the initial trace. Once the function is compiled, subsequent calls reuse the compiled version, so you won't see the print output every time.

@jax.jit
def update_step(params, batch):
    # Compute both loss and gradients in a single pass using value_and_grad
    # This is more efficient than computing them separately
    loss, grads = jax.value_and_grad(compute_loss)(params, batch)

    # Update parameters using gradient descent
    # jax.tree.map applies the update rule to each parameter in the model
    # The lambda function implements: p_new = p_old - learning_rate * gradient
    params = jax.tree.map(
        lambda p, g: p - config.learning_rate * g,
        params,
        grads
    )

    # Return updated parameters and the loss value for monitoring training
    return params, loss

In our update_step function, @jax.jit compiles the code. The function computes loss and gradients simultaneously with jax.value_and_grad, updates the parameters using gradient descent with help of jax.tree.map, and returns the updated parameters and loss.
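
If you'd rather use AdamW, as mentioned above, a minimal sketch with optax (not used in this post; parameter values assumed) would replace the tree_map update like this:

import optax

optimizer = optax.adamw(learning_rate=3e-4)
opt_state = optimizer.init(params)

@jax.jit
def update_step_adamw(params, opt_state, batch):
    # Loss and gradients in one pass, as before
    loss, grads = jax.value_and_grad(compute_loss)(params, batch)
    # Let optax compute the AdamW parameter updates
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss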

Training-Loop

Finally, it's time to train our 2 million parameter model on the Shakespeare dataset. We first prepare batches using the get_batch function, which splits our data into batches so we can train faster with our limited compute.
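
The loop below uses a few names the post doesn't define explicitly (num_epochs, steps_per_epoch, params_state, epoch_losses); a minimal setup, with assumed values, could be:

num_epochs = 10                      # assumed value
steps_per_epoch = 100                # assumed value
params_state = params                # start from the freshly initialized weights
epoch_losses = []                    # track the average loss of each epoch
key = jax.random.PRNGKey(0)          # PRNG key for batch sampling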

for epoch in range(num_epochs):
    epoch_loss = 0.0

    for step in range(steps_per_epoch):
        # Generate new random keys for reproducibility
        key, batch_key = random.split(key)

        # Sample random batch of sequences
        batch = get_batch(batch_key, data, config.batch_size, config.max_seq_len)

        # Forward pass, compute loss and update parameters
        params_state, loss = update_step(params_state, batch)

        # Accumulate loss for the epoch average
        epoch_loss += loss

        if step % 100 == 0:
            print(f"epoch {epoch + 1}, step {step}/{steps_per_epoch}: loss = {loss:.4f}")

    avg_epoch_loss = epoch_loss / steps_per_epoch
    epoch_losses.append(avg_epoch_loss)

    print(f"\nepoch {epoch + 1} | average loss: {avg_epoch_loss:.4f}")

Thank you for reading this far !!

You can support me:

1. You can read more about functional programming here.
2. I have referred to this blog post a lot.
3. This repository contains all the code for llama3 in jax, from model building to training.
4. The architecture walkthrough of llama is from this video by Umar Jamil.
5. If you don't set XLA_PYTHON_CLIENT_PREALLOCATE = 'false', JAX will preallocate 75% of your GPU memory, which can sometimes cause "out of memory" errors.
6. If you prefer an object-oriented style, libraries like Haiku and Flax give you that without losing JAX's features.
7. In this video, Andrej Karpathy builds BPE from scratch and explains it.
8. The math behind RoPE is very interesting. Check out these videos and blog posts to learn more.