In this post, we'll implement llama3 from scratch using pure JAX in just 100 lines of code. Why JAX? Because I think it has good aesthetics. JAX also looks like a NumPy wrapper, but it comes with cool features like XLA (an accelerated linear algebra compiler), jit, vmap, and pmap, which make your training go brr brr.
JAX is also one of the first libraries that strongly embraces pure functional programming, which makes it even cooler.1
Note:
This post assumes familiarity with Python and basic understanding of transformer architectures.
This implementation is for educational purposes; it isn't meant for production, but it covers all components of the model.2
If you don't wanna read this amazing blog post, you can check out all the code directly.3
Table of Contents
LLaMA3
Model Weights Initialization
Tokenization
Embeddings
Root Mean Square Layer Normalization
Rotary Positional Encoding
Grouped-Query Attention
Feed-Forward
Transformer-Block
Forward-Pass
Dataset
Training
LLaMA3
At its core, LLaMA 3 is a decoder-only transformer language model that generates text one token at a time,
building on previous tokens to predict what comes next, like completing a sentence word by word.
So let's fucking go!! We're doing it, get your diet coke!! First, we'll set up the device5 and configure the model.
# Configure JAX to use the GPU and prevent memory preallocation
# (set these environment variables before JAX initializes its backend)
import os
os.environ['JAX_PLATFORM_NAME'] = 'gpu'
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
print("JAX devices:", jax.devices())
These are the hyperparameters we need to train a model with approximately 2 million parameters.
# Define model hyperparameters
args = ModelArgs(
    vocab_size=enc.n_vocab,  # Size of vocabulary
    dim=256,                 # Embedding dimension
    n_layers=6,              # Number of transformer layers
    n_heads=8,               # Number of attention heads
    n_kv_heads=4,            # Number of key-value heads for GQA
    max_seq_len=512,         # Maximum sequence length
    norm_eps=1e-5            # Normalization epsilon
)
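The ModelArgs container itself isn't defined in the snippet above, and enc (the tiktoken encoder) only appears later in the Tokenization section. Here's a minimal sketch of what ModelArgs could look like; the field names follow the call above, while batch_size and learning_rate are my assumptions based on how config is used later in the training code.
from dataclasses import dataclass

@dataclass
class ModelArgs:
    # Fields mirror the hyperparameters passed above
    vocab_size: int
    dim: int = 256
    n_layers: int = 6
    n_heads: int = 8
    n_kv_heads: int = 4
    max_seq_len: int = 512
    norm_eps: float = 1e-5
    # Assumed fields: the training code later reads config.batch_size and config.learning_rate
    batch_size: int = 32
    learning_rate: float = 3e-4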
Model Weights Initialization
In pure JAX, we don't use classes like in PyTorch; we use only pure functions. Why? Because it makes our code more predictable and easier to parallelize.
A pure function always returns the same output for the same input and doesn’t cause any side effects.6 For example, if you call F(x), you'll always get the same y.
Since we aren’t using a framework like PyTorch’s nn.Module to automatically track parameters, we must initialize and update our weights manually.
Handling randomness is also different. Instead of relying on a single global seed as in NumPy or PyTorch, in JAX we manage randomness with explicit pseudo-random number generator (PRNG) keys. Each random operation gets its own unique key, derived by splitting a parent key. This helps with reproducibility and parallelism.
For example, below we create a key, split it into subkeys, and pass a subkey to any function that involves randomness.
import jax

# Generate and split random keys for reproducibility
key = jax.random.PRNGKey(42)
# Create a new subkey for random operations
key, subkey = jax.random.split(key)
# Initialize random weights using the subkey
weights = jax.random.normal(subkey, (784, 512))
Now let's start with our model weights initialization. First, we create the random values for our parameters from a normal distribution.
import math

# Initialize weights with optional scaling
def init_weight(key, shape, scale=None):
    # Default scale: 1/sqrt(fan_in) if none is provided
    scale = 1.0 / math.sqrt(shape[0]) if scale is None else scale
    # Return a scaled sample from the standard normal distribution
    return jax.random.normal(key, shape) * scale
Next, we'll identify all the learnable parameters of our model (llama3), assign each a unique key to ensure reproducibility, and apply the initialization to them.
Since weights are essentially numbers stored in arrays, we can use dictionaries to manage them as key-value pairs.
First, we start with the attention module, which has four trainable weight matrices.
# Initialize attention weights for multi-head / grouped-query attention
def init_attention_weights(key, dim, n_heads, n_kv_heads):
    # Split the key, one per weight matrix
    keys = jax.random.split(key, 4)
    head_dim = dim // n_heads
    # Return a dictionary of weight matrices
    return {
        'wq': init_weight(keys[0], (dim, n_heads * head_dim)),     # Query weights
        'wk': init_weight(keys[1], (dim, n_kv_heads * head_dim)),  # Key weights
        'wv': init_weight(keys[2], (dim, n_kv_heads * head_dim)),  # Value weights
        'wo': init_weight(keys[3], (n_heads * head_dim, dim))      # Output projection
    }
Next, we have our feed-forward network, which has three trainable weight matrices.
# Initialize feed-forward network weights
def init_ffn_weights(key, dim):
    # Split the key into three, one per weight matrix
    keys = jax.random.split(key, 3)
    return {
        'w1': init_weight(keys[0], (dim, 4 * dim)),  # First projection
        'w2': init_weight(keys[1], (4 * dim, dim)),  # Output projection
        'w3': init_weight(keys[2], (dim, 4 * dim))   # Gate projection
    }
Then we combine these weights into a transformer block, adding two more parameters for the block's two RMSNorm layers.
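The post doesn't show that combination step, so here's a minimal sketch of how the per-block parameters could be assembled with the helpers above; the dictionary keys follow how transformer_block reads them later, and initializing the RMSNorm weights to ones is my assumption (it's the usual choice).
import jax
import jax.numpy as jnp

# Sketch: initialize all weights for one transformer block
def init_transformer_block(key, dim, n_heads, n_kv_heads):
    # One subkey for attention, one for the feed-forward network
    attn_key, ffn_key = jax.random.split(key)
    return {
        'attention': init_attention_weights(attn_key, dim, n_heads, n_kv_heads),
        'ffn': init_ffn_weights(ffn_key, dim),
        'attention_norm': jnp.ones(dim),  # RMSNorm scale before attention
        'ffn_norm': jnp.ones(dim)         # RMSNorm scale before the FFN
    }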
Tokenization
Tokenization means dividing the text into words and subwords (tokens).
We will be using Byte Pair Encoding (BPE) for training our model (BPE was used in training Llama 3).7
We won't build BPE from scratch; instead, we'll use OpenAI's tiktoken library.
import jax.numpy as jnp
import tiktoken
# Load GPT-2 BPE encoding
enc = tiktoken.get_encoding("gpt2")
# Read a line from the dataset
with open('../shakespeare.txt', 'r') as f:
    text = f.readlines()[0]  # Take the first line

# Encode the text into token IDs
tokens = enc.encode(text)
data = jnp.array(tokens, dtype=jnp.int32)  # Store as a JAX array

# Decode back to text
decoded_text = enc.decode(tokens)
print("original Text:", text.strip())
print("encoded Tokens:", tokens)
print("decoded Text:", decoded_text)
## Ouput ##
# Original Text: From fairest creatures we desire increase,
# Encoded Tokens: [220, 3574, 37063, 301, 8109, 356, 6227, 2620, 11, 198]
# Decoded Text: From fairest creatures we desire increase,
Embeddings
We cannot feed tokens directly to the model because tokens are discrete, while neural networks operate on continuous numerical data, which is what the underlying mathematical operations require. Therefore, we use an embedding layer to convert the discrete tokens into a continuous vector space. These embeddings also help encode the semantic and syntactic relationships between tokens.
There are two types of embeddings: static and dynamic. We use dynamic embeddings to train LLMs. Why? Static embeddings work well for finding similarities between words and representing them in a shared vector space, as seen in the first image.
However, they suffer from semantic ambiguity: the same word gets the same vector regardless of context, as shown in the second image.
This is where self-attention comes in: it refines these embeddings to incorporate context. So we start with random embeddings and update them according to the context.
# Converting the input tokens into embeddings
h = params["token_embedding"][inputs]

# token_embedding is a matrix of shape (vocab_size, dim).
# inputs are token IDs (integers).
Root Mean Square Layer Normalization
RMS normalization is an important layer in LLaMA 3. It helps keep training stable by making sure the values flowing through the network don't become too large or too small. This balance is very important, especially in deep networks.
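Concretely, RMSNorm rescales each vector by the reciprocal of its root mean square and multiplies by a learnable weight $\gamma$ (this matches the code below):

$$\mathrm{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^{2} + \epsilon}}$$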
# RMSNorm function for stabilizing training
def rms_norm(x, weight, eps=1e-5):
    # Mean of squares over the last dimension (the squared RMS)
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    # Normalize and scale by the learnable weight
    return x * weight * jnp.reciprocal(jnp.sqrt(variance + eps))
Rotary Positional Encoding
Transformers don't naturally know the order of tokens, so we need to add some position info. LLaMA 3 solves this with RoPE (Rotary Position Embedding), which "rotates" the query and key vectors based on their position.8
How It Works:
Precompute Rotation Factors:
First, we create a table of rotation factors over a range of frequencies, so each position gets its own unique set of rotation angles.
# Compute rotary position embeddings
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Generate the frequency bands
    freqs = 1.0 / (theta ** (jnp.arange(0, dim // 2, dtype=jnp.float32) / dim))
    # Generate the position indices
    t = jnp.arange(end, dtype=jnp.float32)
    # Outer product: one angle per (position, frequency) pair
    freqs = jnp.outer(t, freqs)
    # Convert to complex exponentials e^(i * angle)
    return jnp.complex64(jnp.exp(1j * freqs))
Apply the Rotation:
Pair Up Features:
We reshape the vectors so that every two numbers form a pair; think of them as the real and imaginary parts of a complex number.
Rotate:
We multiply these complex numbers by our precomputed rotation factors. This rotates each pair in the complex plane.
Convert Back:
Finally, we split the rotated complex numbers back into their real and imaginary parts to restore the original shape.
Math Behind It:
For each pair $(x_{2i}, x_{2i+1})$, the rotation is given by:

$$\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$$

where $\theta_i$ is the rotation angle for that pair at the token's position.
In short, RoPE embeds positional information directly into the token features by rotating them. This way the attention module gets information about token order without extra position vectors.
# Apply rotary embeddings to queries and keys
def apply_rotary_emb(xq, xk, freqs_cis):
    # Reshape inputs so every two features form a (real, imaginary) pair
    xq_r = jnp.reshape(xq, (*xq.shape[:-1], -1, 2))
    xk_r = jnp.reshape(xk, (*xk.shape[:-1], -1, 2))
    # Convert to complex numbers
    xq_complex = jnp.complex64(xq_r[..., 0] + 1j * xq_r[..., 1])
    xk_complex = jnp.complex64(xk_r[..., 0] + 1j * xk_r[..., 1])
    # Reshape frequencies for broadcasting
    freqs_cis = jnp.reshape(freqs_cis, (1, freqs_cis.shape[0], 1, freqs_cis.shape[1]))
    # Apply the rotation through complex multiplication
    xq_out = xq_complex * freqs_cis
    xk_out = xk_complex * freqs_cis
    # Convert back to real numbers and restore the original shape
    xq = jnp.stack([jnp.real(xq_out), jnp.imag(xq_out)], axis=-1).reshape(xq.shape)
    xk = jnp.stack([jnp.real(xk_out), jnp.imag(xk_out)], axis=-1).reshape(xk.shape)
    return xq, xk
Grouped-Query Attention
Now it's time for attention. Grouped Query Attention (GQA) is an optimized version of Multi-Head Attention that improves efficiency by sharing key and value representations among multiple query heads. This reduces computational overhead and memory usage, enabling faster inference and better scaling for transformer models.
At its core, it's just self-attention with a few modifications.
Scaled Dot-Product Attention:
$$A = \mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_h}}\right)V$$
# Attention mechanism with grouped-query attention
def attention(params, x, mask, freqs_cis, n_heads, n_kv_heads, cache=None, position=0):
    # Get input dimensions
    B, T, C = x.shape
    head_dim = C // n_heads
    # Project inputs to queries, keys, and values
    q = jnp.dot(x, params['wq']).reshape(B, T, n_heads, head_dim)
    k = jnp.dot(x, params['wk']).reshape(B, T, n_kv_heads, head_dim)
    v = jnp.dot(x, params['wv']).reshape(B, T, n_kv_heads, head_dim)
    # Apply rotary embeddings
    q, k = apply_rotary_emb(q, k, freqs_cis[position:position + T])
    # Handle the KV-cache for inference: append new k/v along the sequence axis
    if cache is not None:
        k = jnp.concatenate([cache[0], k], axis=1)
        v = jnp.concatenate([cache[1], v], axis=1)
    new_cache = (k, v)
    # Repeat k/v heads for grouped-query attention
    k = repeat_kv(k, n_heads // n_kv_heads)
    v = repeat_kv(v, n_heads // n_kv_heads)
    # Move the head dimension forward: (B, n_heads, T, head_dim)
    q, k, v = map(lambda t: t.transpose(0, 2, 1, 3), (q, k, v))
    # Compute scaled dot-product attention scores
    scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / math.sqrt(head_dim)
    # Apply the causal mask if provided
    if mask is not None:
        scores = scores + mask[:, :, :T, :T]
    # Compute attention weights and the final output
    scores = jax.nn.softmax(scores, axis=-1)
    output = jnp.matmul(scores, v)
    output = output.transpose(0, 2, 1, 3).reshape(B, T, -1)
    return jnp.dot(output, params['wo']), new_cache
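The attention function above calls repeat_kv, which the post's snippet doesn't include. Here's a minimal sketch of it, assuming k and v have shape (B, T, n_kv_heads, head_dim) as in the code above.
import jax.numpy as jnp

# Sketch: duplicate each key/value head n_rep times so the n_kv_heads
# key/value heads line up with the n_heads query heads
def repeat_kv(x, n_rep):
    if n_rep == 1:
        return x
    B, T, n_kv_heads, head_dim = x.shape
    # (B, T, n_kv_heads, 1, head_dim) -> (B, T, n_kv_heads, n_rep, head_dim)
    x = jnp.broadcast_to(x[:, :, :, None, :], (B, T, n_kv_heads, n_rep, head_dim))
    # Merge the repeated heads back into a single head dimension
    return x.reshape(B, T, n_kv_heads * n_rep, head_dim)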
KV-cache: it stores the previously computed key (K) and value (V) tensors from past tokens, so during inference we don't have to recompute them for tokens we've already processed.
if cache is not None:
    k = jnp.concatenate([cache[0], k], axis=1)  # Concatenate cached keys with new keys
    v = jnp.concatenate([cache[1], v], axis=1)  # Concatenate cached values with new values
new_cache = (k, v)  # Updated cache with the full k/v history
Feed-forward
This is a simple feed-forward network with the SiLU activation function.
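The feed_forward function itself isn't shown in the post's snippets, so here's a minimal sketch consistent with the three weight matrices we initialized earlier and with LLaMA's SiLU-gated (SwiGLU-style) FFN; treat it as an assumed implementation.
import jax
import jax.numpy as jnp

# Sketch: SiLU-gated feed-forward network
def feed_forward(params, x):
    # Gate branch SiLU(x @ w1) is multiplied element-wise by the plain projection x @ w3,
    # then w2 projects the result back down to the model dimension
    return jnp.dot(jax.nn.silu(jnp.dot(x, params['w1'])) * jnp.dot(x, params['w3']),
                   params['w2'])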
Transformer-Block
This is where all the important components come together in the transformer block. We unpack the pre-initialized weights and distribute them to their respective layers. Each transformer block consists of attention, normalization, feed-forward layers, and residual connections.
# Transformer block implementation
def transformer_block(params, x, mask, freqs_cis, n_heads, n_kv_heads, cache=None, position=0):
    # Apply attention with pre-normalization
    attn_output, new_cache = attention(
        params['attention'],
        rms_norm(x, params['attention_norm']),
        mask,
        freqs_cis,
        n_heads,
        n_kv_heads,
        cache,
        position
    )
    # First residual connection
    h = x + attn_output
    # Apply the feed-forward network with pre-normalization
    ffn_output = feed_forward(params['ffn'], rms_norm(h, params['ffn_norm']))
    # Second residual connection
    out = h + ffn_output
    return out, new_cache
Forward-Pass
The forward pass takes your data through the entire model: it converts the input tokens into embeddings, runs them through the stack of transformer blocks, and finally applies the output layer. In other words, it connects all the layers (embedding, transformer blocks, and output) to produce the final predictions.
# Forward pass through the entire model
def model_forward(params, inputs, config, cache=None, position=0):
    # Get batch dimensions
    B, T = inputs.shape
    # Convert input tokens to embeddings
    h = params['token_embedding'][inputs]
    # Compute freqs_cis for this forward pass
    freqs_cis = precompute_freqs_cis(config.dim // config.n_heads, config.max_seq_len)
    # Create the causal mask
    mask = jnp.tril(jnp.ones((config.max_seq_len, config.max_seq_len)))
    mask = jnp.where(mask == 0, -1e9, 0.0)
    mask = mask.astype(h.dtype)
    mask = mask[None, None, :, :]
    # Process through the transformer blocks
    new_caches = []
    for i, block in enumerate(params['blocks']):
        layer_cache = cache[i] if cache is not None else None
        h, layer_cache = transformer_block(
            block, h, mask, freqs_cis,
            config.n_heads, config.n_kv_heads,
            layer_cache, position)
        new_caches.append(layer_cache)
    # Final normalization and output projection
    h = rms_norm(h, params['norm_f'])
    logits = jnp.dot(h, params['output'])
    return logits, new_caches
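model_forward expects a params dictionary with token_embedding, blocks, norm_f, and output entries, but the post doesn't show how that dictionary is assembled. Here's a minimal sketch that reuses the init helpers from earlier, including the hypothetical init_transformer_block; initializing the final RMSNorm weight to ones is again my assumption.
# Sketch: assemble the full parameter dictionary read by model_forward
def init_model_params(key, config):
    keys = jax.random.split(key, config.n_layers + 2)
    return {
        'token_embedding': init_weight(keys[0], (config.vocab_size, config.dim)),
        'blocks': [
            init_transformer_block(keys[i + 1], config.dim,
                                   config.n_heads, config.n_kv_heads)
            for i in range(config.n_layers)
        ],
        'norm_f': jnp.ones(config.dim),                                   # final RMSNorm scale
        'output': init_weight(keys[-1], (config.dim, config.vocab_size))  # LM head
    }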
Dataset
Now that the model is complete, it's time to train it on the Shakespeare dataset. First we read the data from a .txt file, encode it with BPE, and convert it into a JAX array.
# Initialize the tokenizer and load the data
enc = tiktoken.get_encoding("gpt2")

# Read the text file
with open('shakespeare.txt', 'r') as f:
    text = f.read()

# Convert text to token IDs
tokens = enc.encode(text)
# Convert to a JAX array
data = jnp.array(tokens)
Get Batches
The get_batch function creates training batches from our Shakespeare dataset. We need to feed the model chunks of data, so for each batch we randomly select starting positions in the text; this way the model sees a variety of contexts.
Now, here's where JAX's cool vmap feature comes into play. Instead of writing a loop to extract each chunk, we use vmap to automate it.
How does it work ?
vmap is like a vectorized loop; it takes a function that processes a single index (using lax.dynamic_slice to get a sequence of tokens) and applies it to every element in our array of indices. This means our input sequences (x) and corresponding target sequences (y, which are shifted by one token for next-word prediction) are created in one go.
from jax import lax, random, vmap

def get_batch(key, data, batch_size, seq_len):
    # Generate random starting indices
    ix = random.randint(key, (batch_size,), 0, len(data) - seq_len)
    # Vectorized slicing: input sequences and targets shifted by one token
    x = vmap(lambda i: lax.dynamic_slice(data, (i,), (seq_len,)))(ix)
    y = vmap(lambda i: lax.dynamic_slice(data, (i + 1,), (seq_len,)))(ix)
    return x, y
Loss Function
This function computes the cross-entropy loss for a batch during training. It first performs a forward pass using the model to generate logits for the input data. Then, it reshapes both the logits and targets to merge the batch and sequence dimensions. After applying the log softmax to the logits, it extracts the log probabilities corresponding to the correct target tokens and computes their negative mean as the final loss value.
The cross-entropy loss is defined as:
$$L = -\frac{1}{N}\sum_{i=1}^{N} \log P(y_i)$$
Where:
$P(y_i)$ is the probability of the correct class, calculated using the softmax function:

$$P(y_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}$$
# Compute the cross-entropy loss
def compute_loss(params, batch):
    # Split the batch into inputs and targets
    inputs, targets = batch
    # Forward pass to get logits (the returned cache isn't needed during training)
    logits, _ = model_forward(params, inputs, config)
    # Reshape for the loss computation
    logits = logits.reshape(-1, config.vocab_size)
    targets = targets.reshape(-1)
    # Negative log-likelihood of the correct tokens
    loss = -jnp.mean(jnp.take_along_axis(jax.nn.log_softmax(logits),
                                         targets[:, None], axis=1))
    return loss
Update Function
Now we need to write a function to update our weights. For simplicity, we're using Stochastic Gradient Descent (SGD) here, though you can also use Adam or AdamW for faster convergence.
In the code, you'll notice the @jax.jit decorator. This is one of the features that sets JAX apart. JIT (just-in-time) compilation speeds up execution by compiling your Python code into optimized machine code.
How does it work ?
When you decorate a function with JAX’s jit, it changes how the function executes. Normally, when you call a function, Python runs it line by line. For example, if you have:
def sqr(x):
    print("HI jitted")  # side effect
    return x * x

print(sqr(2))
print(sqr(3))
print(sqr(4))
Every time you call sqr, it prints "HI jitted" and then returns the square of the number. However, when you add the @jax.jit decorator:
@jax.jit
def sqr(x):
    print("HI jitted")  # side effect
    return x * x

print(sqr(2))
print(sqr(3))
print(sqr(4))
JAX first traces your function to build an optimized computation graph. This tracing happens the first time the function is called, and the traced computation is then compiled into optimized machine code.
Because of this tracing, side effects like the print statement are only executed during the initial trace. Once the function is compiled, subsequent calls use the compiled version, so you won't see the print output every time.
@jax.jit
def update_step(params, batch):
    # Compute both the loss and the gradients in a single pass using value_and_grad
    # (more efficient than computing them separately)
    loss, grads = jax.value_and_grad(compute_loss)(params, batch)
    # Update parameters with gradient descent:
    # jax.tree.map applies p_new = p_old - learning_rate * gradient to every parameter
    params = jax.tree.map(
        lambda p, g: p - config.learning_rate * g,
        params,
        grads
    )
    # Return the updated parameters and the loss for monitoring training
    return params, loss
In our update_step function, @jax.jit compiles the code. The function computes the loss and gradients simultaneously with jax.value_and_grad, updates the parameters using gradient descent with the help of jax.tree.map, and returns the updated parameters and the loss.
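If you'd rather use AdamW (mentioned above) instead of plain SGD, a minimal sketch with the optax library could look like the following; optax isn't used anywhere in this post, so this is an assumed alternative rather than the post's method.
import jax
import optax

# Assumed alternative: AdamW via optax instead of hand-rolled SGD
optimizer = optax.adamw(learning_rate=config.learning_rate)
# Initialize the optimizer state with your model parameters (e.g. params_state below)
opt_state = optimizer.init(params)

@jax.jit
def update_step_adamw(params, opt_state, batch):
    loss, grads = jax.value_and_grad(compute_loss)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss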
Training Loop
Finally, it's time to train our 2-million-parameter model on the Shakespeare dataset.
We first prepare batches using get_batch, which splits our data into chunks so we can train efficiently with our limited compute.
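The loop below relies on a few names that the post sets up before training (params_state, key, epoch_losses, num_epochs, steps_per_epoch) but doesn't show. Here's a minimal sketch of that setup, assuming the hypothetical init_model_params from earlier; the epoch and step counts are made-up values.
from jax import random

# Assumed setup before the training loop
config = args                         # the ModelArgs defined earlier
key = random.PRNGKey(0)
key, init_key = random.split(key)
params_state = init_model_params(init_key, config)  # hypothetical helper from above
num_epochs = 10                       # assumed value
steps_per_epoch = 100                 # assumed value
epoch_losses = []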
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for step in range(steps_per_epoch):
        # Generate a new random key for each batch
        key, batch_key = random.split(key)
        # Sample a random batch of sequences
        batch = get_batch(batch_key, data, config.batch_size, config.max_seq_len)
        # Forward pass, compute the loss and update the parameters
        params_state, loss = update_step(params_state, batch)
        # Accumulate the loss for the epoch average
        epoch_loss += loss
        if step % 100 == 0:
            print(f"epoch {epoch + 1}, step {step}/{steps_per_epoch}: loss = {loss:.4f}")
    avg_epoch_loss = epoch_loss / steps_per_epoch
    epoch_losses.append(avg_epoch_loss)
    print(f"\nepoch {epoch + 1} | average loss: {avg_epoch_loss:.4f}")