Whalegrad 🐳: Lightweight Deep Learning Library in C
In this post, we'll explore Whalegrad, a lightweight deep learning library written in C that I've been developing. This project is inspired by Andrej Karpathy's micrograd but takes a different approach by implementing the autograd engine in C for better performance while maintaining the educational aspects.
Note
- This project is still under development.
- The code examples shown are for educational purposes.
- All the source code is available on GitHub.
What is Whalegrad?
Whalegrad is a compact autograd engine written in pure C, designed for educational purposes. Like its Python counterpart micrograd, Whalegrad demonstrates how automatic differentiation works under the hood, but with the added benefit of C's performance and low-level control.
At its core, Whalegrad automatically computes gradients (derivatives) of operations performed on tensors. These gradients are essential for training neural networks through backpropagation. The beauty of Whalegrad lies in its simplicity: the entire autograd engine is implemented in just a few hundred lines of C code.
Let's fucking go!! We're doing it, get your Diet Coke!! Let's dive into how Whalegrad works and how you can use it to build simple neural networks.
The Core: Tensor Structure
The fundamental building block of Whalegrad is the Tensor structure. Unlike PyTorch or TensorFlow tensors, which support n-dimensional arrays, our implementation currently focuses on scalar values for simplicity.
Here's how our Tensor is defined:
// tensor.h - The core Tensor structure
#include <stdbool.h>  // for bool

#define MAX_CHILDREN 2  // assumed value: binary operations need at most 2 inputs

typedef struct Tensor Tensor;
struct Tensor {
    double data;                      // The scalar value
    double grad;                      // Gradient of the final output w.r.t. this tensor
    bool requires_grad;               // Whether to track gradients
    int num_children;                 // Number of child tensors
    Tensor* children[MAX_CHILDREN];   // Child tensors used in the operation
    void (*backward)(struct Tensor*); // Gradient function
    char op[32];                      // Operation that created this tensor
};
Each Tensor keeps track of its value (data), its gradient (grad), and references to the tensors that were used to create it (children). The backward function pointer holds the operation-specific gradient computation function.
Creating a new tensor is as simple as:
Tensor* tensor_create(double data) {
    return create_tensor_with_children(data, NULL, 0, "");
}
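The post relies on create_tensor_with_children but does not show it. Below is a minimal sketch of what it might look like. It assumes each node is heap-allocated with calloc, that MAX_CHILDREN is 2, and that requires_grad defaults to true. None of these details are confirmed by the source.

```c
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

#define MAX_CHILDREN 2  /* assumption: binary ops are the widest */

typedef struct Tensor Tensor;
struct Tensor {
    double data;
    double grad;
    bool requires_grad;
    int num_children;
    Tensor *children[MAX_CHILDREN];
    void (*backward)(Tensor *);
    char op[32];
};

/* Hypothetical sketch: allocate a node, record its inputs and op label. */
Tensor *create_tensor_with_children(double data, Tensor **children,
                                    int num_children, const char *op) {
    if (num_children < 0 || num_children > MAX_CHILDREN) return NULL;
    Tensor *t = calloc(1, sizeof *t);  /* zeroes grad, children, backward, op */
    if (!t) return NULL;
    t->data = data;
    t->requires_grad = true;           /* assumed default */
    t->num_children = num_children;
    for (int i = 0; i < num_children; i++) {
        t->children[i] = children[i];
    }
    /* Copy the label, always leaving room for the terminating '\0'. */
    strncpy(t->op, op ? op : "", sizeof t->op - 1);
    return t;
}

Tensor *tensor_create(double data) {
    return create_tensor_with_children(data, NULL, 0, "");
}
```

Leaf tensors from tensor_create have no children and a NULL backward pointer. This is why tensor_backward checks topo[i]->backward before calling it.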
Building the Computational Graph
When you perform operations on tensors, Whalegrad automatically builds a computational graph. Each node in this graph is a tensor. Each edge links a tensor to one of the inputs (children) of the operation that produced it.
Let's look at how the basic operations are implemented:
// Addition operation
Tensor* tensor_add(Tensor* a, Tensor* b) {
    Tensor* children[] = {a, b};
    Tensor* out = create_tensor_with_children(a->data + b->data, children, 2, "+");
    out->backward = add_backward;
    return out;
}

// Multiplication operation
Tensor* tensor_mul(Tensor* a, Tensor* b) {
    Tensor* children[] = {a, b};
    Tensor* out = create_tensor_with_children(a->data * b->data, children, 2, "*");
    out->backward = mul_backward;
    return out;
}
These operations build the computational graph by creating new tensors that remember the tensors they were computed from (their children). For addition, we compute the sum of the input values and set the backward function to add_backward. For multiplication, we compute the product and set the backward function to mul_backward. Each resulting tensor stores its operation type ("+", "*") and keeps references to its inputs. This allows gradients to flow back during backpropagation.
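To make the recorded structure visible, here is a hypothetical print_graph helper. It is not part of Whalegrad as shown. It walks children recursively and prints each node's op label and value, indenting inputs under the node they feed. For brevity it uses a trimmed-down Tensor with only the fields it needs.

```c
#include <stdio.h>

#define MAX_CHILDREN 2

typedef struct Tensor Tensor;
struct Tensor {
    double data;
    int num_children;
    Tensor *children[MAX_CHILDREN];
    char op[32];
};

/* Hypothetical helper: print the graph rooted at t, one node per line.
   Returns the number of lines printed. A shared input is printed once
   for every path that reaches it. */
int print_graph(const Tensor *t, int depth) {
    printf("%*s%s %.4f\n", depth * 2, "", t->op[0] ? t->op : "leaf", t->data);
    int count = 1;
    for (int i = 0; i < t->num_children; i++) {
        count += print_graph(t->children[i], depth + 1);
    }
    return count;
}
```

Take s = m + a with m = a * b, a = 2 and b = 3. The helper prints five lines, and a appears twice because it is reached along two paths. The backpropagation code has to handle exactly this sharing.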
Automatic Differentiation
The real magic of Whalegrad lies in its ability to automatically compute gradients. Each operation defines how gradients should flow backward through the computational graph.
At its core, automatic differentiation is based on the chain rule from calculus, which states that if \(z = f(y)\) and \(y = g(x)\), then:
\[\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx}\]
In a computational graph, a variable can feed into several operations, so the chain rule generalizes to a sum:
\[\frac{\partial L}{\partial x_i} = \sum_j \frac{\partial L}{\partial y_j} \cdot \frac{\partial y_j}{\partial x_i}\]
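Here \(L\) is the final output (for example, a loss), and the \(y_j\) are the nodes that take \(x_i\) directly as an input. The sum collects the contribution of each of them. Applied recursively, this accounts for every path from \(x_i\) to \(L\).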
For example, let's look at how gradients are computed for addition and multiplication:
// Backward function for addition
void add_backward(Tensor* self) {
    for (int i = 0; i < self->num_children; i++) {
        self->children[i]->grad += self->grad;
    }
}

// Backward function for multiplication
void mul_backward(Tensor* self) {
    self->children[0]->grad += self->children[1]->data * self->grad;
    self->children[1]->grad += self->children[0]->data * self->grad;
}
These backward functions implement the chain rule from calculus. For addition, the upstream gradient passes to both inputs unchanged, since ∂(a+b)/∂a = ∂(a+b)/∂b = 1. For multiplication, each input receives the upstream gradient multiplied by the other input's value, because:
\[\frac{\partial (a \cdot b)}{\partial a} = b \quad \text{and} \quad \frac{\partial (a \cdot b)}{\partial b} = a\]
So if \(z = a \cdot b\) and we know \(\frac{\partial L}{\partial z}\), then:
\[\frac{\partial L}{\partial a} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial a} = \frac{\partial L}{\partial z} \cdot b\]
\[\frac{\partial L}{\partial b} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial b} = \frac{\partial L}{\partial z} \cdot a\]
This is exactly what our mul_backward function implements, with self->grad representing \(\frac{\partial L}{\partial z}\).
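Every backward function accumulates into grad with += rather than assigning. This is what implements the sum over \(j\) in the multivariate chain rule: a tensor used by several operations collects one contribution from each.

The self-contained sketch below checks this on y = x·x, where both lines of mul_backward write to the same tensor. It uses a trimmed Tensor struct and a hypothetical mul helper in place of Whalegrad's allocator.

```c
#include <stdlib.h>

typedef struct Tensor Tensor;
struct Tensor {
    double data, grad;
    Tensor *children[2];
    void (*backward)(Tensor *);
};

void mul_backward(Tensor *self) {
    /* Same rule as in the post: each input gets the other input's value
       times the upstream gradient, accumulated with += */
    self->children[0]->grad += self->children[1]->data * self->grad;
    self->children[1]->grad += self->children[0]->data * self->grad;
}

/* Minimal multiply node for this illustration (hypothetical helper). */
Tensor *mul(Tensor *a, Tensor *b) {
    Tensor *out = calloc(1, sizeof *out);
    if (!out) return NULL;
    out->data = a->data * b->data;
    out->children[0] = a;
    out->children[1] = b;
    out->backward = mul_backward;
    return out;
}
```

With x = 3, each line contributes 3, so x.grad ends up at 6 = dy/dx = 2x. If the code used = instead of +=, the second write would overwrite the first and leave the wrong value 3.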
To perform backpropagation, we first need to build a topological ordering of the computational graph:
void build_topo(Tensor* v, Tensor** topo, int* topo_size, Tensor** visited, int* visited_size) {
    if (find_tensor(visited, *visited_size, v) == -1) {
        visited[(*visited_size)++] = v;
        for (int i = 0; i < v->num_children; i++) {
            build_topo(v->children[i], topo, topo_size, visited, visited_size);
        }
        topo[(*topo_size)++] = v;
    }
}
This function builds a topological ordering of the computational graph using depth-first search. The key requirement is that each node appears in the ordering only after all of its children (the inputs it depends on). The function records visited nodes so that a tensor reachable along several paths is added only once; for example, a and b in the example later in this post feed both c and d. The graph itself cannot contain cycles, since every operation creates a new tensor from existing ones. The ordering is built by recursively exploring each node's children before appending the node itself.
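find_tensor is used but not shown. The sketch below assumes it is a linear search that returns the index or -1. It pairs find_tensor with build_topo as written above, plus bounds checks against an assumed MAX_TOPO_SIZE, since the fixed-size arrays would otherwise overflow silently on large graphs.

```c
#include <assert.h>
#include <stddef.h>

#define MAX_CHILDREN 2
#define MAX_TOPO_SIZE 1024  /* assumed capacity */

typedef struct Tensor Tensor;
struct Tensor {
    double data;
    int num_children;
    Tensor *children[MAX_CHILDREN];
};

/* Hypothetical linear search: index of t in list, or -1 if absent.
   O(n) per lookup, so build_topo is O(n^2) overall; fine for small graphs. */
int find_tensor(Tensor **list, int size, const Tensor *t) {
    for (int i = 0; i < size; i++) {
        if (list[i] == t) return i;
    }
    return -1;
}

/* Same DFS as in the post, with capacity checks added. */
void build_topo(Tensor *v, Tensor **topo, int *topo_size,
                Tensor **visited, int *visited_size) {
    if (find_tensor(visited, *visited_size, v) != -1) return;
    assert(*visited_size < MAX_TOPO_SIZE);
    visited[(*visited_size)++] = v;
    for (int i = 0; i < v->num_children; i++) {
        build_topo(v->children[i], topo, topo_size, visited, visited_size);
    }
    assert(*topo_size < MAX_TOPO_SIZE);
    topo[(*topo_size)++] = v;  /* appended only after all its inputs */
}
```

Consider a diamond-shaped graph where c and d both consume a and b, and e consumes c and d. This yields the order a, b, c, d, e: shared inputs are listed once and every node comes after its inputs. A visited flag stored in each tensor, or a hash set, would make the lookup O(1) at the cost of extra state.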
Then we traverse this ordering in reverse, calling each tensor's backward function:
void tensor_backward(Tensor* t) {
    Tensor* topo[MAX_TOPO_SIZE];
    Tensor* visited[MAX_TOPO_SIZE];
    int topo_size = 0;
    int visited_size = 0;
    build_topo(t, topo, &topo_size, visited, &visited_size);

    t->grad = 1.0;
    for (int i = topo_size - 1; i >= 0; i--) {
        if (topo[i]->backward) {
            topo[i]->backward(topo[i]);
        }
    }
}
This is the main backpropagation function. It first builds a topological ordering of the computational graph. It then sets the gradient of the output tensor to 1.0, which seeds the chain rule since ∂t/∂t = 1. Finally, it processes each tensor in reverse topological order. Traversing in reverse guarantees that a node's gradient has received every contribution from the nodes that use it before that gradient is passed on to the node's own children. Each tensor's specific backward function handles the gradient computation for the operation that created it.
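Gradients accumulate, and tensor_backward only sets the output's grad to 1.0. So calling it a second time (for example, on the next training step with the same weight tensors) adds new gradients on top of the old ones.

A hypothetical tensor_zero_grad helper can reset every reachable node first. The post does not show such a helper. This sketch uses a trimmed Tensor struct:

```c
#include <stddef.h>

#define MAX_CHILDREN 2

typedef struct Tensor Tensor;
struct Tensor {
    double data, grad;
    int num_children;
    Tensor *children[MAX_CHILDREN];
};

/* Hypothetical helper: reset grad on every node reachable from t.
   Shared nodes may be visited more than once; writing 0.0 twice is
   harmless but costs extra time on graphs with many shared nodes. */
void tensor_zero_grad(Tensor *t) {
    t->grad = 0.0;
    for (int i = 0; i < t->num_children; i++) {
        tensor_zero_grad(t->children[i]);
    }
}
```

Call it on the output (or loss) before each tensor_backward. For large graphs, iterating over the ordering produced by build_topo avoids revisiting shared nodes.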
Implementing Activation Functions
Neural networks rely on non-linear activation functions like ReLU. Here's how we implement ReLU in Whalegrad:
void relu_backward(Tensor* self) {
    if (self->data > 0) {
        self->children[0]->grad += self->grad;
    }
}

Tensor* tensor_relu(Tensor* a) {
    Tensor* children[] = {a};
    Tensor* out = create_tensor_with_children(
        a->data < 0 ? 0 : a->data,
        children,
        1,
        "ReLU"
    );
    out->backward = relu_backward;
    return out;
}
The ReLU activation function and its gradient implementation demonstrate how non-linearities work in neural networks. Mathematically, ReLU is defined as:
\[\text{ReLU}(x) = \max(0, x) = \begin{cases} x & \text{if } x > 0 \\ 0 & \text{if } x \leq 0 \end{cases}\]
Its derivative is a step function. Strictly speaking it is undefined at \(x = 0\); like the code above, we take it to be 0 there:
\[\text{ReLU}'(x) = \begin{cases} 1 & \text{if } x > 0 \\ 0 & \text{if } x \leq 0 \end{cases}\]
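The backward function implements this gradient. It passes the gradient through when the input was positive (ReLU'(x) = 1 for x > 0) and blocks it when the input was zero or negative (ReLU'(x) = 0 for x ≤ 0). Checking the output self->data rather than the input works because ReLU's output is positive exactly when its input is. This simple non-linearity is essential: without it, a stack of linear layers would collapse into a single linear function, however deep the network.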
Building a Neural Network
With our tensor operations in place, we can now build a simple neural network. Here's how we define a multi-layer perceptron (MLP) structure:
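typedef struct {
    Tensor** weights;   // one scalar weight tensor per layer
    Tensor** biases;    // one scalar bias tensor per layer
    int* layer_sizes;   // size of each layer, as passed to mlp_create
    int num_layers;     // number of weight layers (entries in layer_sizes minus one)
} MLP;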
This structure defines a Multi-Layer Perceptron (MLP), a basic feedforward neural network. It stores arrays of weights and biases for each layer, the size of each layer, and the number of weight layers. In the current implementation, each weight and bias is a scalar value rather than a matrix and a vector. As a result, layer_sizes is recorded but every layer is effectively one neuron wide. The structure already supports arbitrary depth; real width will come with n-dimensional tensors.
Creating an MLP
Creating an MLP is straightforward. We initialize random weights and biases for each layer:
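// Requires <stdlib.h> (malloc, free, srand) and <time.h> (time)
MLP* mlp_create(int* layer_sizes, int num_layers) {
    if (num_layers < 2) return NULL;  // need at least an input and an output layer

    MLP* mlp = (MLP*)malloc(sizeof(MLP));
    if (!mlp) return NULL;

    // Seed the random number generator (usually better done once, in main)
    srand((unsigned)time(NULL));

    mlp->num_layers = num_layers - 1;  // number of weight layers
    mlp->layer_sizes = (int*)malloc(num_layers * sizeof(int));
    mlp->weights = (Tensor**)malloc((num_layers - 1) * sizeof(Tensor*));
    mlp->biases = (Tensor**)malloc((num_layers - 1) * sizeof(Tensor*));
    if (!mlp->layer_sizes || !mlp->weights || !mlp->biases) {
        free(mlp->layer_sizes);
        free(mlp->weights);
        free(mlp->biases);
        free(mlp);
        return NULL;
    }

    // Store layer sizes
    for (int i = 0; i < num_layers; i++) {
        mlp->layer_sizes[i] = layer_sizes[i];
    }

    // Initialize one scalar weight and bias per layer with random values in [-1, 1]
    for (int i = 0; i < num_layers - 1; i++) {
        mlp->weights[i] = tensor_create(random_double());
        mlp->biases[i] = tensor_create(random_double());
    }
    return mlp;
}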
This function creates a new MLP with randomly initialized weights and biases. It allocates memory for the network structure and parameter arrays. It then initializes each weight and bias with a random value between -1 and 1, using random_double, whose definition is not shown. Random initialization matters once layers have several neurons: neurons that start with identical weights receive identical gradients and stay identical. Randomness breaks that symmetry and lets them learn different features. Proper memory management is important in C, so every allocation made here also needs a matching free once the network is no longer used.
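Two helpers the MLP code depends on are not shown: random_double, and a way to release the network. The sketch below makes three assumptions: random_double draws from [-1, 1] using rand(), each parameter tensor was allocated individually with malloc, and mlp_free is a hypothetical name.

```c
#include <stdlib.h>

typedef struct Tensor { double data, grad; } Tensor;  /* trimmed for this sketch */

typedef struct {
    Tensor **weights;
    Tensor **biases;
    int *layer_sizes;
    int num_layers;  /* number of weight layers, as set in mlp_create */
} MLP;

/* Hypothetical: value in [-1, 1] from rand(). rand() is fine for a demo
   but is a low-quality generator for serious work. */
double random_double(void) {
    return 2.0 * ((double)rand() / RAND_MAX) - 1.0;
}

/* Hypothetical counterpart to mlp_create. Assumes each weights[i] and
   biases[i] is either a malloc'd pointer or NULL. */
void mlp_free(MLP *mlp) {
    if (!mlp) return;
    for (int i = 0; i < mlp->num_layers; i++) {
        if (mlp->weights) free(mlp->weights[i]);
        if (mlp->biases) free(mlp->biases[i]);
    }
    free(mlp->weights);
    free(mlp->biases);
    free(mlp->layer_sizes);
    free(mlp);
}
```

mlp_free releases only the parameters. Intermediate tensors created during the forward pass are separate allocations that also need freeing, which is the memory-management work listed under What's Next.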
Forward Pass
The forward pass function propagates input data through the network. For each layer, it performs a linear transformation followed by a non-linear activation. Mathematically, for a layer with weights \(W\), bias \(b\), and activation function \(f\), the output \(y\) for input \(x\) is:
\[y = f(Wx + b)\]
In our case, for the hidden layers, \(f\) is the ReLU function. For the output layer, we typically use the identity function (no activation). Each operation in this function builds on our tensor operations, automatically constructing the computational graph that will later be used for backpropagation.
Tensor* mlp_forward(MLP* mlp, Tensor* input) {
    Tensor* current = input;

    // Process through each layer
    for (int i = 0; i < mlp->num_layers; i++) {
        // Linear transformation: wx + b
        Tensor* linear = tensor_add(
            tensor_mul(mlp->weights[i], current),
            mlp->biases[i]
        );

        // Apply activation for hidden layers
        if (i < mlp->num_layers - 1) {
            current = tensor_relu(linear);
        } else {
            current = linear;  // No activation for output layer
        }
    }
    return current;
}
Though our implementation uses scalar values for simplicity, in a full neural network library, \(W\) would be a matrix, \(x\) and \(b\) would be vectors, and the computation would involve matrix-vector multiplication.
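The post stops at the forward pass. To train, one would:

1. Compute a loss from the output, for instance a squared error built with tensor_sub and tensor_pow.
2. Call tensor_backward on the loss.
3. Nudge each parameter against its gradient.

Below is a hypothetical plain gradient-descent step over the MLP's parameters. It uses a trimmed Tensor with only data and grad; mlp_sgd_step is not part of the post.

```c
typedef struct Tensor { double data, grad; } Tensor;  /* trimmed */

typedef struct {
    Tensor **weights;
    Tensor **biases;
    int *layer_sizes;
    int num_layers;
} MLP;

/* Hypothetical gradient-descent update: p <- p - lr * dL/dp for every
   weight and bias, then reset grad, since the engine accumulates with +=. */
void mlp_sgd_step(MLP *mlp, double lr) {
    for (int i = 0; i < mlp->num_layers; i++) {
        Tensor *params[2] = {mlp->weights[i], mlp->biases[i]};
        for (int k = 0; k < 2; k++) {
            params[k]->data -= lr * params[k]->grad;
            params[k]->grad = 0.0;
        }
    }
}
```

Here lr is the learning rate. Resetting grad inside the update keeps the next backward pass from building on stale gradients.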
Example Usage
Let's see Whalegrad in action with a simple example. This code shows how to build a computational graph, compute gradients, and verify the results:
#include <stdio.h>
#include "tensor.h"  // assumed to declare the tensor_* functions used below

int main(void) {
    // Create tensors
    Tensor* a = tensor_create(-4.0);
    Tensor* b = tensor_create(2.0);

    // c = a + b
    Tensor* c = tensor_add(a, b);

    // d = a * b + b**3
    Tensor* temp1 = tensor_mul(a, b);
    Tensor* temp2 = tensor_pow(b, 3.0);
    Tensor* d = tensor_add(temp1, temp2);

    // e = c - d
    Tensor* e = tensor_sub(c, d);

    // f = e**2
    Tensor* f = tensor_pow(e, 2.0);

    // Compute gradients
    tensor_backward(f);

    // Display results
    printf("f --> %.4f\n", f->data);
    printf("df/da = %.4f\n", a->grad);
    printf("df/db = %.4f\n", b->grad);
    return 0;
}
The mathematical operations in this example can be expressed as:
\[ \begin{aligned} c &= a + b\\ d &= a \cdot b + b^3\\ e &= c - d\\ f &= e^2 \end{aligned} \]
To compute the gradients \(\frac{\partial f}{\partial a}\) and \(\frac{\partial f}{\partial b}\), we apply the chain rule:
\[ \begin{aligned} \frac{\partial f}{\partial e} &= 2e\\ \frac{\partial e}{\partial c} &= 1\\ \frac{\partial e}{\partial d} &= -1\\ \frac{\partial c}{\partial a} &= 1\\ \frac{\partial c}{\partial b} &= 1\\ \frac{\partial d}{\partial a} &= b\\ \frac{\partial d}{\partial b} &= a + 3b^2 \end{aligned} \]
So, the full gradients are:
\[ \begin{aligned} \frac{\partial f}{\partial a} &= \frac{\partial f}{\partial e} \cdot \frac{\partial e}{\partial c} \cdot \frac{\partial c}{\partial a} + \frac{\partial f}{\partial e} \cdot \frac{\partial e}{\partial d} \cdot \frac{\partial d}{\partial a}\\ &= 2e \cdot 1 \cdot 1 + 2e \cdot (-1) \cdot b\\ &= 2e(1 - b) \end{aligned} \]
\[ \begin{aligned} \frac{\partial f}{\partial b} &= \frac{\partial f}{\partial e} \cdot \frac{\partial e}{\partial c} \cdot \frac{\partial c}{\partial b} + \frac{\partial f}{\partial e} \cdot \frac{\partial e}{\partial d} \cdot \frac{\partial d}{\partial b}\\ &= 2e \cdot 1 \cdot 1 + 2e \cdot (-1) \cdot (a + 3b^2)\\ &= 2e(1 - a - 3b^2) \end{aligned} \]
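Plugging in a = -4 and b = 2 gives c = -2, d = -8 + 8 = 0, e = -2 and f = 4. Therefore \(\frac{\partial f}{\partial a} = 2(-2)(1 - 2) = 4\) and \(\frac{\partial f}{\partial b} = 2(-2)(1 + 4 - 12) = 28\). Assuming tensor_pow and tensor_sub implement their derivatives correctly (their code is not shown), the program should print f --> 4.0000, df/da = 4.0000 and df/db = 28.0000. These match the hand-derived values, demonstrating that the automatic differentiation engine reproduces the mathematical derivatives.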
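A standard way to check an autograd engine is to compare its gradients with central finite differences, df/dx ≈ (f(x+h) - f(x-h)) / 2h. These are computed from a plain C version of f that never touches the engine. The function names below are illustrative.

```c
#include <math.h>

/* f(a, b) from the example, evaluated directly without the autograd engine. */
double f_direct(double a, double b) {
    double c = a + b;
    double d = a * b + b * b * b;
    double e = c - d;
    return e * e;
}

/* Central differences: truncation error is O(h^2). */
double dfda_numeric(double a, double b, double h) {
    return (f_direct(a + h, b) - f_direct(a - h, b)) / (2.0 * h);
}

double dfdb_numeric(double a, double b, double h) {
    return (f_direct(a, b + h) - f_direct(a, b - h)) / (2.0 * h);
}

/* Relative-tolerance comparison between a numeric and an analytic gradient. */
int grad_close(double numeric, double analytic, double tol) {
    return fabs(numeric - analytic) <= tol * (1.0 + fabs(analytic));
}
```

At a = -4 and b = 2 with h = 1e-5, both numeric derivatives agree with 4 and 28 well within a relative tolerance of 1e-6. Comparing them against a->grad and b->grad from the engine is therefore a quick regression test for any new backward function.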
What's Next for Whalegrad
- N-dimensional tensors: Extending support beyond scalar values to matrices and higher-dimensional tensors.
- More operations: Implementing additional operations like convolution, pooling, and softmax.
- Optimized memory management: Improving how tensors are created and destroyed to minimize memory leaks.
- SIMD optimization: Using low-level CPU features for better performance on numerical operations.
- GPU support: Adding optional GPU acceleration for matrix operations.
Thank you for reading!