
LLMs From Scratch: Day 1

May 21, 2026

The first step in all of this was going back and re-reading Attention Is All You Need (AIAYN). Despite being almost a decade old, it does a good job of introducing many of the key concepts used in modern LLM architectures. Most notably, it introduces the transformer, which the authors built around their own attention mechanism, Scaled Dot Product Attention.
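For reference, the paper defines Scaled Dot Product Attention over query, key, and value matrices $Q$, $K$, $V$, where $d_k$ is the dimension of the keys:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V
```

The $1/\sqrt{d_k}$ factor is the "scaled" part. Without it, the dot products grow with $d_k$ and push the softmax into regions where its gradients are tiny.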

I also decided to do this project in JAX, for a few reasons. First, I have never really used JAX, so it seemed like a good chance to learn it. Second, it does away with many of the abstractions that other ML frameworks like PyTorch provide. I'll use JAX's underlying linear algebra and automatic differentiation, but I'm building everything else on top of that, and JAX makes it hard to cheat by default. Finally, because it departs from PyTorch's class-based design, it reinforces many of the lower-level concepts you take for granted in PyTorch, so again it's an opportunity to experience a lot of ML advances from first principles.
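To make the "no classes" point concrete, here is a minimal sketch of the functional style JAX pushes you toward. It is my own illustration rather than code from this project: the names `init_params`, `linear`, and `mse_loss` are hypothetical. Parameters are a plain dict (a pytree), the model is a pure function of `(params, inputs)`, and `jax.grad` hands back gradients with the same structure as the parameters.

```python
import jax
import jax.numpy as jnp

# Parameters live in a plain dict (a "pytree"), not inside a class instance.
def init_params(key, in_dim, out_dim):
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) * 0.01,
        "b": jnp.zeros(out_dim),
    }

# The model is a pure function of (params, inputs); nothing is stored on self.
def linear(params, x):
    return x @ params["w"] + params["b"]

def mse_loss(params, x, y):
    return jnp.mean((linear(params, x) - y) ** 2)

# jax.grad differentiates with respect to the first argument (params) and
# returns gradients in the same dict structure.
grad_fn = jax.grad(mse_loss)

key = jax.random.PRNGKey(0)
params = init_params(key, 3, 1)
x = jnp.ones((4, 3))
y = jnp.ones((4, 1))
grads = grad_fn(params, x, y)

# A plain SGD step: no optimizer object, just a tree_map over the pytree.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

Nothing here mutates state in place: each update produces a new `params` pytree, which is exactly the bookkeeping that `nn.Module` and `optimizer.step()` hide in PyTorch.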

After reading and taking notes on AIAYN, I implemented two things in JAX.

First, I implemented the Scaled Dot Product Attention. This was pretty straightforward after consulting the JAX documentation.
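A minimal sketch of what a from-scratch version can look like, following the formula from the paper. This is my reconstruction, not necessarily the exact code from this project: the function names and the optional boolean `mask` argument (True where attention is allowed, as in the paper's masked decoder self-attention) are assumptions. The softmax is written by hand with the usual max-subtraction trick for numerical stability.

```python
import jax
import jax.numpy as jnp

def softmax(x, axis=-1):
    # Subtracting the max before exponentiating avoids overflow; the result is
    # unchanged because softmax is invariant to adding a constant.
    x_max = jnp.max(x, axis=axis, keepdims=True)
    exps = jnp.exp(x - x_max)
    return exps / jnp.sum(exps, axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
    """q: (..., seq_q, d_k), k: (..., seq_k, d_k), v: (..., seq_k, d_v)."""
    d_k = q.shape[-1]
    # Similarity scores of shape (..., seq_q, seq_k), scaled by 1/sqrt(d_k).
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(d_k)
    if mask is not None:
        # Blocked positions get a large negative score, so softmax gives them ~0 weight.
        scores = jnp.where(mask, scores, -1e9)
    weights = softmax(scores, axis=-1)
    # Each output row is a weighted average of the value rows.
    return weights @ v

# Usage: 5 tokens, d_k = 8, d_v = 16.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (5, 8))
k = jax.random.normal(kk, (5, 8))
v = jax.random.normal(kv, (5, 16))
out = scaled_dot_product_attention(q, k, v)

# A causal mask: token i may only attend to tokens 0..i.
causal = jnp.tril(jnp.ones((5, 5), dtype=bool))
out_causal = scaled_dot_product_attention(q, k, v, mask=causal)
```

With the causal mask, the first token can only attend to itself, so `out_causal[0]` is just `v[0]`. That makes a handy sanity check.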

Second, I implemented a simple feed-forward network that used the softmax I wrote for Scaled Dot Product Attention as an activation. While it wasn't theoretically difficult, I spent a lot of time just learning JAX's conventions and debugging the parts of the code that JAX didn't like. I adapted the MNIST example from the JAX documentation into a quick training loop to make sure everything trained correctly, and reached ~91.66% accuracy after 20 epochs.
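Below is a rough sketch of that kind of network and training loop. It is not the project's actual code: the post doesn't say which layer uses the softmax, so I assume ReLU hidden layers with the hand-written softmax as the output activation. The names (`init_mlp`, `predict`, `update`, `STEP_SIZE`) are hypothetical. To keep it self-contained, it trains with full-batch gradient descent on a small synthetic 3-class dataset instead of downloading MNIST, so it says nothing about the 91.66% figure.

```python
import jax
import jax.numpy as jnp

STEP_SIZE = 0.05  # hypothetical learning rate for plain gradient descent

def softmax(x, axis=-1):
    x_max = jnp.max(x, axis=axis, keepdims=True)
    exps = jnp.exp(x - x_max)
    return exps / jnp.sum(exps, axis=axis, keepdims=True)

def init_mlp(key, sizes):
    # One (weights, bias) pair per layer, e.g. sizes = [784, 512, 10] for MNIST.
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (m, n)) * jnp.sqrt(2.0 / m), jnp.zeros(n))
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]

def predict(params, x):
    # Hidden layers use ReLU; the last layer uses the hand-written softmax
    # to turn logits into class probabilities.
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return softmax(x @ w + b)

def cross_entropy(params, x, y_onehot):
    probs = predict(params, x)
    # The epsilon guards against log(0); a fused log-softmax would be more stable.
    return -jnp.mean(jnp.sum(y_onehot * jnp.log(probs + 1e-9), axis=-1))

@jax.jit
def update(params, x, y_onehot):
    loss, grads = jax.value_and_grad(cross_entropy)(params, x, y_onehot)
    params = jax.tree_util.tree_map(lambda p, g: p - STEP_SIZE * g, params, grads)
    return params, loss

def accuracy(params, x, labels):
    return jnp.mean(jnp.argmax(predict(params, x), axis=-1) == labels)

# Synthetic stand-in for MNIST: 512 samples, 20 features, 3 classes
# labelled by a random linear rule.
data_key, rule_key, init_key = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(data_key, (512, 20))
labels = jnp.argmax(x @ jax.random.normal(rule_key, (20, 3)), axis=-1)
y = jax.nn.one_hot(labels, 3)

params = init_mlp(init_key, [20, 64, 3])
initial_loss = cross_entropy(params, x, y)
for step in range(300):
    params, loss = update(params, x, y)
final_loss = cross_entropy(params, x, y)
train_acc = accuracy(params, x, labels)
```

Keeping `update` pure and wrapping it in `jax.jit` is one of the JAX conventions that takes some getting used to. The parameters go in and new parameters come out, and the whole step compiles once because the shapes never change.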

See you on Day 2.