Nick Wall

Scratch

About Scratch

This is a collection of ML models and methods implemented from scratch, compiled for my learning and reference. I use this project to digest new technologies and research, and just to implement things for fun. Not all of the implementations are optimized, but they are all easy to read and understand.

The source code is available on GitHub.

I am always adding new models and methods when I find notable papers or need to implement something myself to fully grok the implementation details. If you have a suggestion for my roadmap, please open an issue on GitHub.

Update: July 2024

I have fully migrated the entire project to JAX. Flax’s new NNX API is really promising, and I am excited to see how it evolves. I have no problem being an early adopter, and I personally think this is the correct direction for ML research and development. NNX is powerful because it is a two-part solution: it uses a graph with reference semantics to enable intuitive Module graph operations such as weight tying and model surgery, while still preserving PyTrees as the core data structure. That gets the best of both worlds.

This simplicity is what sold me on NNX and convinced me to dedicate the energy to building a comprehensive set of implementations, since to my knowledge no one else has done this yet.

from flax import nnx
import optax

# Define a model using the new API. This API is so similar to PyTorch, enabling
# better interoperability with the rest of the research ecosystem.
class Model(nnx.Module):
    def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dmid, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.dropout = nnx.Dropout(0.2, rngs=rngs)
        self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.dropout(self.bn(self.linear(x))))
        return self.linear_out(x)

# Initialize the model eagerly and create an optimizer
model = Model(2, 64, 3, rngs=nnx.Rngs(0))
# Note that the optimizer can leverage reference sharing while still having
# interoperability with the rest of the JAX ecosystem because it is a PyTree.
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

# Lifted transforms enable writing normal JAX code that can be JIT compiled and the NNX
# semantics are preserved.
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        y_pred = model(x)
        return ((y_pred - y) ** 2).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    # The optimizer state is still a PyTree, but it can be updated in-place because
    # references are shared.
    optimizer.update(grads)
    return loss

This is a really great improvement to the JAX ecosystem, and I am excited to see how it evolves.


I have fully moved all of the deep learning models over to NNX, and have added some new models as well. Additionally, I have moved the more traditional ML methods over to using JAX as a backend.

Currently Implemented

The following methodologies are currently implemented in the main branch:

Deep Learning

I have developed annotated implementations of many popular deep learning models, as outlined in their respective papers. The models are optimized for readability and ease of understanding. This is not to suggest that they are not performant, but rather I choose to spell out the details of the implementation for clarity.

Beyond full models, many common deep learning components are implemented as composable layers, such as:

  • MultiQueryAttention as proposed in Fast Transformer Decoding: One Write-Head is All You Need. This is one of my favorites. Noam Shazeer strikes again; just look how simple this is:

    $$
    \begin{aligned}
    Q &= \text{einsum}(\text{'bnd,hdk} \rightarrow \text{bhnk'}, X, P_q) \\
    K &= \text{einsum}(\text{'bmd,dk} \rightarrow \text{bmk'}, M, P_k) \\
    V &= \text{einsum}(\text{'bmd,dv} \rightarrow \text{bmv'}, M, P_v) \\
    \text{logits} &= \text{einsum}(\text{'bhnk,bmk} \rightarrow \text{bhnm'}, Q, K) \\
    \text{weights} &= \text{softmax}(\text{logits} + \text{mask}) \\
    O &= \text{einsum}(\text{'bhnm,bmv} \rightarrow \text{bhnv'}, \text{weights}, V) \\
    Y &= \text{einsum}(\text{'bhnv,hdv} \rightarrow \text{bnd'}, O, P_o)
    \end{aligned}
    $$
    So simple, yet it dramatically improves decoding throughput and memory utilization, because all of the query heads share a single set of keys and values.

  • MultiHeadAttention as proposed in Attention Is All You Need

  • GroupedQueryAttention as proposed in GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
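As a rough illustration (this is a sketch, not the code from the repository), the MultiQueryAttention equations in the list above translate almost one-to-one into `jnp.einsum` calls. The tensor sizes, the causal mask, and the randomly initialized projections below are assumptions made purely for the example.

```python
import jax
import jax.numpy as jnp


def multi_query_attention(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-query attention: h query heads share a single key/value head.

    X: (b, n, d) query-side input, M: (b, m, d) memory-side input.
    mask: broadcastable to (b, h, n, m), large negative where attention is blocked.
    P_q: (h, d, k), P_k: (d, k), P_v: (d, v), P_o: (h, d, v).
    """
    Q = jnp.einsum("bnd,hdk->bhnk", X, P_q)      # per-head queries
    K = jnp.einsum("bmd,dk->bmk", M, P_k)        # one shared set of keys
    V = jnp.einsum("bmd,dv->bmv", M, P_v)        # one shared set of values
    logits = jnp.einsum("bhnk,bmk->bhnm", Q, K)
    weights = jax.nn.softmax(logits + mask, axis=-1)
    O = jnp.einsum("bhnm,bmv->bhnv", weights, V)
    return jnp.einsum("bhnv,hdv->bnd", O, P_o)   # merge heads back to model dim


# Hypothetical sizes and parameters, chosen only for this example.
b, n, d, h, d_k, d_v = 2, 5, 16, 4, 8, 8
kx, kq, kk, kv, ko = jax.random.split(jax.random.PRNGKey(0), 5)
X = jax.random.normal(kx, (b, n, d))
P_q = 0.1 * jax.random.normal(kq, (h, d, d_k))
P_k = 0.1 * jax.random.normal(kk, (d, d_k))
P_v = 0.1 * jax.random.normal(kv, (d, d_v))
P_o = 0.1 * jax.random.normal(ko, (h, d, d_v))

# Causal self-attention: position i may only attend to positions <= i.
causal = jnp.tril(jnp.ones((n, n), dtype=bool))
mask = jnp.where(causal, 0.0, -1e9)

Y = multi_query_attention(X, X, mask, P_q, P_k, P_v, P_o)
print(Y.shape)
```

Note that `K` and `V` carry no head axis at all, which is where the savings come from during incremental decoding: the cached keys and values are `h` times smaller than in the multi-head case.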

New layers are being added constantly, such as GNN elements, advancements in MoE routing, and more.

I also have an interest in new optimization techniques such as schedule-free optimizers and distributed-aware training methods such as DistributedShampoo, as outlined in Scalable Second Order Optimization for Deep Learning. Examples such as these are on my roadmap.

Language Modeling

I have implemented a number of language models that are used in industry and have been used to train state-of-the-art systems. These include models used in a variety of applications such as text generation, classification, and summarization. I intend to keep adding and reimplementing models: I plan to add GPT-2, GPT-Neo, Mistral/Mixtral, and a few others that have become incredibly popular. Additionally, I will be adding a few less popular models whose papers I found really interesting and high quality, such as MPT, T5, Mixture of Depths, and more. Gemma 2 seems like an interesting project as well: I really liked Gemma 1 and it performed well, but I have found the existing fine-tuning code for v2 to be really underwhelming in its results.

Llama

The Llama series of models has become ubiquitous, as the models are simple to implement, have incredible performance, and are very easy to fine-tune. I believe that the choice not to deviate from the standard transformer components of the 2022-2024 era (RMS pre-norm, Grouped Query Attention, and so on) is the key to their success. Scratch contains implementations of Llama 2 and Llama 3, as the main differences between Llama 1 and Llama 2 are the training data and context length, with Grouped Query Attention used only in the larger Llama 2 models.

Llama has been the foundation of much research and engineering development. This is most notable in the movement towards local LLMs with projects like llama.cpp, ollama, and many more. Going forward, I intend to use the Llama models as a baseline implementation for a feature to convert HuggingFace Safetensors checkpoints into orbax weights, to enable using pretrained checkpoints with all of the models in the scratch repository.

BERT

I have implemented BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding because it is so ubiquitous as one of the original papers applying the Transformer architecture to language modeling. BERT is really interesting because its popularity arose not from generative CLM objectives, but from its performance on a variety of different tasks such as question answering, sentiment analysis, and summarization. To honor the original glory of BERT, my implementation comes in several flavors: BertForSequenceClassification, BertForTokenClassification, and BertForQuestionAnswering.

OLMo

OLMo: Accelerating the Science of Language Models is a paper that I thought was really underappreciated. The work of the team, primarily out of the Allen Institute for Artificial Intelligence (AI2) and the University of Washington, is honestly groundbreaking. It is so rare for the entire stack for generative language models to be open sourced and well documented. So many things can go wrong, and fixing them requires expert knowledge that is not documented anywhere other than in discussions between practitioners. OLMo is a great example of how the community can come together to build a better future for generative language models. Even though the model is not SOTA, it is still impressive and includes interesting notes about training a large-scale model. Many of the issues they faced are not really relevant on the JAX stack that I implemented, but being forced to consider them and highlight the JAX way of doing things is definitely a useful exercise.

Mamba 2

When the original Mamba paper was released I was so intrigued. The linear and sub-quadratic attention mechanisms that have been proposed over the last few years have always felt like a band-aid solution to a fundamental problem. I really like the exploration of SSMs and the hardware-informed design. My only real complaint with the original Mamba paper stems from the fact that the direction of AI research into transformers has forced hardware designers to double down on matrix multiplication. As a result, any architecture needs to accept the reality of matmuls to make the most of the hardware we have. Mamba 2, in Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality, does this, and the theory of the state space duality is utterly fascinating. My implementation is annotated to make clear how the decomposition of the SSM, and the generalization of the scalar case with the semiseparable matrix, enable redefining the SSM as effectively a linear attention mechanism.

The details in the paper and the write up by the authors are, in my opinion, must reads.

The paper first defines the SSD layer as a selective SSM, building on the foundations of SSMs and the selective scan from the original paper:

$$
\begin{aligned}
h_t &= A_t h_{t-1} + B_t x_t \\
y_t &= C_t^T h_t
\end{aligned}
$$

It then proposes the scalar-identity structure on $A$, which, when formally defined, can be written as a sequence transformation $X \mapsto Y$:

$$
Y^{(\text{T,P})} = \text{SSM}(A^{(\text{T})}, B^{(\text{T,N})}, C^{(\text{T,N})})(X^{(\text{T,P})})
$$

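The progression that they define then reveals a dual form that is similar to common linear attention mechanisms:

$$
M = L \circ C B^T \in \mathbb{R}^{(\text{T,T})}
$$

Here $Y = MX$, and $L$ is a lower-triangular matrix whose entries are products of the scalar $A_t$ values, so it plays the role of a decaying causal mask.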

This blew my mind when it first clicked. If you have not read any of the work from Tri Dao and Albert Gu I would recommend you do so. Their work is so fascinating, and reveals how much the way we think about sequence tasks can be rediscovered from first principles.
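To make the duality concrete, here is a small sketch (my own toy check, not the annotated implementation from the repository and not the paper's chunked algorithm) that computes the same output two ways: once with the recurrence, and once by materializing the matrix $M = L \circ C B^T$. It assumes the scalar case where each $A_t$ is a single positive number `a[t]`, and all sizes and random inputs are made up for the example.

```python
import jax
import jax.numpy as jnp


def ssm_recurrent(a, B, C, X):
    """Linear-time form: h_t = a_t * h_{t-1} + B_t x_t^T,  y_t = C_t^T h_t."""
    def step(h, inputs):
        a_t, B_t, C_t, x_t = inputs
        h = a_t * h + jnp.outer(B_t, x_t)   # state has shape (N, P)
        return h, C_t @ h                   # output has shape (P,)

    h0 = jnp.zeros((B.shape[1], X.shape[1]))
    _, Y = jax.lax.scan(step, h0, (a, B, C, X))
    return Y


def ssm_dual(a, B, C, X):
    """Quadratic, attention-like form: Y = (L * (C B^T)) X."""
    # L[i, j] = a_{j+1} * ... * a_i for i >= j (1 on the diagonal), 0 above it.
    cum = jnp.cumsum(jnp.log(a))            # assumes a > 0
    L = jnp.tril(jnp.exp(cum[:, None] - cum[None, :]))
    M = L * (C @ B.T)                       # (T, T), like masked attention scores
    return M @ X


# Hypothetical sizes: sequence length T, state size N, head dimension P.
T, N, P = 6, 4, 3
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
a = jax.random.uniform(k1, (T,), minval=0.5, maxval=1.0)
B = jax.random.normal(k2, (T, N))
C = jax.random.normal(k3, (T, N))
X = jax.random.normal(k4, (T, P))

y_rec = ssm_recurrent(a, B, C, X)
y_dual = ssm_dual(a, B, C, X)
print(jnp.max(jnp.abs(y_rec - y_dual)))
```

The two functions agree up to floating point error. The recurrent form scales linearly in sequence length, while the dual form is quadratic but made entirely of matmuls, which is the trade-off the paper exploits.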

Vision Models

Currently the vision models are exclusively focused on image recognition. I have implemented a number of vision models that are used in industry and have been used to train state-of-the-art systems. I intend to extend this section to include more vision objectives, architectures, and techniques. Focusing on a simpler objective enabled me to really standardize a set of primitives that would be useful for other models, such as my dataset interface, the trainer design, and the development feedback loop.

ResNet 1.5

I have implemented Deep Residual Learning for Image Recognition with certain adjustments that have become popular in the industry. ResNet 1.5 is a very popular architecture that has been used across many different fields, and is a reference architecture for much emerging research. The ResNet architecture really spearheaded the use of residual connections in deep models.

Vision Transformer

I have implemented An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. The Vision Transformer pioneered the use of Transformers outside of NLP and has been a very popular architecture in the deep learning community. The main contribution of the Vision Transformer is splitting the input image into fixed-size spatial patches that are treated as a sequence of tokens.
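The patch idea can be sketched in a few lines. This is only an illustration under my own assumptions (a single channels-last image whose height and width are divisible by the patch size, and a hypothetical `patchify` helper); in the full model each flattened patch is then passed through a learned linear projection and combined with position embeddings.

```python
import jax.numpy as jnp


def patchify(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C),
    ordered row by row across the image.
    """
    H, W, C = image.shape
    p = patch_size
    x = image.reshape(H // p, p, W // p, p, C)
    x = x.transpose(0, 2, 1, 3, 4)          # (H/p, W/p, p, p, C)
    return x.reshape(-1, p * p * C)


# A toy 32x32 RGB "image" gives four 16x16 patches of 768 values each.
image = jnp.arange(32 * 32 * 3, dtype=jnp.float32).reshape(32, 32, 3)
patches = patchify(image, 16)
print(patches.shape)
```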

Swin Transformer

I have implemented Swin Transformer: Hierarchical Vision Transformer using Shifted Windows with a focus on the Swin Transformer V2 architecture. The Swin Transformer is a very interesting evolution of Vision Transformers, as it implements a new attention mechanism that leverages the spatial information of the input image with a shifting window. The shifting window lets information flow between neighboring windows, while the hierarchical design lets the model attend to the image at different scales.
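A minimal sketch of the windowing step, assuming a single channels-last feature map and hypothetical helper names: attention is computed inside each window, and alternating layers cyclically shift the map by half a window so that the new windows straddle the old window boundaries. The attention mask that the real model applies to the wrapped-around regions is deliberately left out here.

```python
import jax.numpy as jnp


def window_partition(x, window_size):
    """Split an (H, W, C) feature map into (num_windows, window_size**2, C)."""
    H, W, C = x.shape
    w = window_size
    x = x.reshape(H // w, w, W // w, w, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, w * w, C)


def shifted_window_partition(x, window_size):
    """Cyclically shift by half a window, then partition as usual."""
    s = window_size // 2
    shifted = jnp.roll(x, shift=(-s, -s), axis=(0, 1))
    return window_partition(shifted, window_size)


# Toy 8x8 single-channel feature map with 4x4 windows.
feature_map = jnp.arange(64, dtype=jnp.float32).reshape(8, 8, 1)
windows = window_partition(feature_map, 4)
shifted_windows = shifted_window_partition(feature_map, 4)
print(windows.shape, shifted_windows.shape)
```

After the shift, the first window starts at row 2, column 2 of the original map, so it covers pixels that previously belonged to four different windows.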

Convolutional Neural Network

For completeness and the ability to compare against other models, I have implemented a simple CNN using modern techniques aligned with the original description in Gradient-based learning applied to document recognition.

Classical Machine Learning Methods

Regression Methods

Linear Regression is one of the most well understood methods of statistical learning. I have implemented a number of regression methods in a way that is natural to JAX. The Linear Regression implementation, $\hat{y} = \mathbf{w}^T \mathbf{x} + b$, with L1 and L2 regularization in the Lasso and Ridge cases respectively,

$$
\mathcal{L}(\mathbf{w}, b) = \frac{1}{2n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 + \lambda \|\mathbf{w}\|_1
$$

$$
\mathcal{L}(\mathbf{w}, b) = \frac{1}{2n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 + \frac{\lambda}{2} \|\mathbf{w}\|_2^2
$$

as well as ElasticNet, demonstrates a highly composable API that is easy to reason about and easy to extend. The implementation is annotated with notes that explain the math and the implementation details.

$$
\mathcal{L}(\mathbf{w}, b) = \frac{1}{2n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 + \lambda_1 \|\mathbf{w}\|_1 + \frac{\lambda_2}{2} \|\mathbf{w}\|_2^2
$$
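The losses above compose naturally in JAX. The following is a sketch with a hypothetical function name and made-up data rather than the repository's API: one loss function covers all three cases, since Lasso is `l2=0`, Ridge is `l1=0`, and plain least squares sets both to zero, and `jax.grad` supplies the gradients.

```python
import jax
import jax.numpy as jnp


def elastic_net_loss(params, X, y, l1=0.0, l2=0.0):
    """Half mean squared error plus L1 and L2 penalties on the weights."""
    w, b = params
    y_hat = X @ w + b
    mse = jnp.sum((y_hat - y) ** 2) / (2 * X.shape[0])
    return mse + l1 * jnp.sum(jnp.abs(w)) + 0.5 * l2 * jnp.sum(w ** 2)


# Toy data and parameters, chosen only for illustration.
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
params = (jnp.array([0.5, -0.5]), jnp.array(0.0))

ols = elastic_net_loss(params, X, y)
lasso = elastic_net_loss(params, X, y, l1=0.1)
ridge = elastic_net_loss(params, X, y, l2=0.1)

# Gradients with respect to (w, b) come for free from autodiff.
grad_w, grad_b = jax.grad(elastic_net_loss)(params, X, y, l1=0.1, l2=0.1)
print(ols, lasso, ridge)
```

Note that the L1 term is not differentiable at zero, so a plain gradient step is only one option; coordinate descent or proximal updates are common alternatives for Lasso.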

I have extended the linear methods to include logistic regression, which just applies the sigmoid function

$$
\sigma(x) = \frac{1}{1 + e^{-x}}
$$

to the linear output to enable classification use cases.
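In code, that amounts to a one-line change to the linear prediction. This is a sketch with hypothetical function names, and the 0.5 decision threshold is an assumption for the example rather than something fixed by the method.

```python
import jax
import jax.numpy as jnp


def predict_proba(params, X):
    """Linear model followed by a sigmoid, giving P(y = 1 | x)."""
    w, b = params
    return jax.nn.sigmoid(X @ w + b)


def predict(params, X, threshold=0.5):
    """Turn probabilities into hard 0/1 class labels."""
    return (predict_proba(params, X) >= threshold).astype(jnp.int32)


# Toy inputs chosen so the logits are -2.0, 0.5 and 3.0.
params = (jnp.array([1.0, 0.0]), jnp.array(0.0))
X = jnp.array([[-2.0, 7.0], [0.5, -1.0], [3.0, 2.0]])

probs = predict_proba(params, X)
labels = predict(params, X)
print(probs, labels)
```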

K-Nearest Neighbors

K Nearest Neighbors is a simple and intuitive method for classification and regression. As outlined in Nearest Neighbor pattern classification, KNN predicts from the labels of the $k$ training points closest to the query, using the local training data as an approximation of the underlying distribution. The KNN algorithm is non-parametric and lazy, meaning it makes no assumptions about the underlying data distribution and does not learn a discriminative function from the training data. Instead, it stores all the training data and performs computation only during the prediction phase.

$$
\hat{y}_q = \begin{cases} \text{mode}(y_{(1)}, y_{(2)}, \ldots, y_{(k)}) & \text{if classification} \\ \frac{1}{k} \sum_{i=1}^{k} y_{(i)} & \text{if regression} \end{cases}
$$
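A minimal sketch of both cases for a single query point. The squared Euclidean distance, the `knn_predict` name, and the toy data are my assumptions for the example; the "lazy" aspect shows up in the fact that there is no fit step at all.

```python
import jax
import jax.numpy as jnp


def knn_predict(X_train, y_train, x_query, k, num_classes=None):
    """Predict for one query point from its k nearest training points.

    If num_classes is given, y_train holds integer labels and the mode is
    returned (classification); otherwise the mean target is returned (regression).
    """
    dists = jnp.sum((X_train - x_query) ** 2, axis=1)
    _, idx = jax.lax.top_k(-dists, k)        # indices of the k smallest distances
    neighbors = y_train[idx]
    if num_classes is None:
        return jnp.mean(neighbors)
    return jnp.argmax(jnp.bincount(neighbors, length=num_classes))


# Two well separated toy clusters.
X_train = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
                     [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]])
y_class = jnp.array([0, 0, 0, 1, 1, 1])
y_value = jnp.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])

near_origin = jnp.array([0.2, 0.1])
far_corner = jnp.array([5.5, 5.2])

label_a = knn_predict(X_train, y_class, near_origin, k=3, num_classes=2)
label_b = knn_predict(X_train, y_class, far_corner, k=3, num_classes=2)
value_a = knn_predict(X_train, y_value, near_origin, k=3)
print(label_a, label_b, value_a)
```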

Support Vector Machines (SVMs)

Support Vector Machines (SVMs) are supervised learning models used for classification and regression tasks. Introduced in the paper, Support-vector networks by Cortes and Vapnik, SVMs aim to find the optimal hyperplane that maximizes the margin between different classes in the feature space.

SVMs work by implicitly mapping the input data into a higher-dimensional space using a kernel function, and then finding the hyperplane that best separates the classes in this space. The points that lie closest to the decision boundary are called support vectors, and they play a crucial role in defining the hyperplane.

For classification, the decision function for SVMs can be defined as:

$$
f(\mathbf{x}) = \text{sign}(\mathbf{w}^T \mathbf{x} + b)
$$

For regression, SVMs use a method called Support Vector Regression (SVR), which aims to keep the prediction error within a certain threshold. The decision function for SVR is:

$$
f(\mathbf{x}) = \sum_{i=1}^{n} (\alpha_i - \alpha_i^*) K(\mathbf{x}_i, \mathbf{x}) + b
$$

where:

  • $\alpha_i$ and $\alpha_i^*$ are the Lagrange multipliers
  • $K(\mathbf{x}_i, \mathbf{x})$ is the kernel function
  • $b$ is the bias term
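Both decision functions above are short enough to write out directly. The sketch below only covers prediction: it assumes the multipliers and bias have already been found by some solver, it picks an RBF kernel as one common choice (the text does not specify a kernel), and every value and function name is hypothetical.

```python
import jax.numpy as jnp


def svm_classify(x, w, b):
    """Linear SVM decision: sign(w^T x + b)."""
    return jnp.sign(jnp.dot(w, x) + b)


def rbf_kernel(X, x, gamma=1.0):
    """K(x_i, x) = exp(-gamma * ||x_i - x||^2) for every row x_i of X."""
    return jnp.exp(-gamma * jnp.sum((X - x) ** 2, axis=-1))


def svr_decision(x, support_vectors, dual_coef, b, gamma=1.0):
    """SVR prediction, where dual_coef[i] stands for alpha_i - alpha_i^*."""
    return jnp.dot(dual_coef, rbf_kernel(support_vectors, x, gamma)) + b


# Made-up "trained" quantities, only to exercise the formulas.
support_vectors = jnp.array([[0.0, 0.0], [10.0, 10.0]])
dual_coef = jnp.array([2.0, -1.0])
bias = 0.5

# At the first support vector its kernel value is 1 and the other is ~0.
prediction = svr_decision(jnp.array([0.0, 0.0]), support_vectors, dual_coef, bias)
side = svm_classify(jnp.array([1.0, 1.0]), w=jnp.array([1.0, -2.0]), b=0.5)
print(prediction, side)
```

Only the support vectors appear in the sum, which is why prediction cost depends on how many of them the solver keeps rather than on the full training set.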

SVMs are powerful models that are particularly effective in high-dimensional spaces and are known for their robustness against overfitting, especially in cases where the number of dimensions exceeds the number of samples.