LLMs from Scratch: Day 4
After a busy week of work, we're back! Before building the full network, we still need to implement a few things, and now is also a good time to think about how we mask attention in the decoder during training.
As we know, during inference the decoder predicts the next token of a sequence given the input so far. We then feed this new token back into the decoder, along with the tokens that came before it, and repeat until the model emits a stop token, often something like <stop> or <EOS>. The decoder input is shifted right by one position, with a start token (usually <start> or similar) placed at the front. During training, however, we already know the full target output. This means we can train the decoder on an entire sentence in a single forward pass instead of several consecutive ones. The catch is that at inference time the model has no access to future tokens, so during training those future tokens must be masked within our attention blocks; otherwise the model could simply look ahead at the answer.
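As a minimal sketch of that shift, here is one way to build the decoder input and the training targets for a single sentence. The helper name `make_decoder_pair` is hypothetical, the special tokens are assumed to be literally `<start>` and `<EOS>` (the real names depend on the vocab), and the two German tokens are just a toy example:

```python
def make_decoder_pair(tokens, start="<start>", eos="<EOS>"):
    # Decoder input: the target sentence shifted right by one, with <start> in front
    decoder_input = [start] + tokens
    # Targets: the same tokens followed by <EOS>, so at every position i the
    # target is the token that should follow decoder_input[i]
    targets = tokens + [eos]
    return decoder_input, targets

dec_in, tgt = make_decoder_pair(["ein", "Haus"])
print(dec_in)  # ['<start>', 'ein', 'Haus']
print(tgt)     # ['ein', 'Haus', '<EOS>']
```

Because every target position is known up front, all of them can be predicted in one forward pass, which is exactly why each position must be prevented from seeing the positions after it.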
Imagine a sequence of 100 tokens embedded into $\mathbb{R}^{512}$. For self-attention, this means $Q$, $K$, and $V$ all have shape (100, 512). Consider just the first token, i.e. the first row of 512 values. When we perform the $QK^T$ operation in attention, we take the dot product of the first token with itself, which is totally fine. Now imagine the first two tokens stacked on top of each other, with shape (2, 512). The first row of $QK^T$ compares the first token with itself, and then the first token against the second token. This second comparison pairs the current token (1) with a future token (2), which isn't possible at inference time, when we don't have the full sequence. The second row compares the second token with the first, which is fine since the second token comes after the first, and finally the second token with itself. So we end up with a 2x2 matrix whose top-right entry is an invalid comparison whenever we don't have the full sequence.
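To make the 2x2 case concrete, here's a tiny sketch using two random toy tokens of dimension 4 instead of 512. The sizes, the random values, and skipping the learned query/key projections are all simplifying assumptions for illustration:

```python
import jax
import jax.numpy as jnp

# Two toy tokens with 4-dimensional embeddings (512 in the real model)
x = jax.random.normal(jax.random.key(0), (2, 4))
Q = K = x  # self-attention: queries and keys come from the same tokens (learned projections skipped)

scores = Q @ K.T  # shape (2, 2); scores[i, j] is the dot product of token i with token j
# scores[0, 1] pairs token 1 (the query) with token 2 (a future key)
future_score = Q[0] @ K[1]

# Allowed comparisons: key position j is not after query position i,
# i.e. True on and below the diagonal, False in the top-right corner
valid = jnp.tril(jnp.ones((2, 2))).astype(bool)
```

Only `scores[0, 1]`, the top-right entry, falls outside the valid region.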
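Extending this to longer sequences, the entries that do not need to be masked form the lower-triangular portion of $QK^T$ (including the diagonal). So for masked attention, we generate a lower-triangular matrix of ones (jnp.tril(jnp.ones((seq_len, seq_len))) does this), broadcast or repeat it along the batch dimension, and wherever it is 0 we replace the corresponding score with $-\infty$ before the softmax. Simply multiplying the scores by the mask is not enough: a masked score of 0 still becomes $e^0 = 1$ inside the softmax, so future tokens would keep receiving weight, whereas $e^{-\infty} = 0$ removes them entirely. I'll add a mask parameter to the ScaledDotProductAttention function from earlier:

```python
import jax
import jax.numpy as jnp

def ScaledDotProductAttention(Q, K, V, mask=None):
    # Raw attention scores, shape (..., seq_len_q, seq_len_k)
    kq = Q @ jnp.swapaxes(K, -1, -2)
    # Scale by sqrt(d_k), the size of the last (feature) axis of K
    kq = kq / jnp.sqrt(K.shape[-1])
    if mask is not None:
        # Where mask == 0 (future positions), set the score to -inf so softmax gives it zero weight
        kq = jnp.where(mask == 0, -jnp.inf, kq)
    # Normalize over the key axis; jax.nn.softmax stands in for the softmax helper from earlier
    kq = jax.nn.softmax(kq, axis=-1)
    return kq @ V
```

I also took the time to implement a very simple dropout function that can be applied to activations during training: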
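```python
import jax
import jax.numpy as jnp

def dropout(arr, key, P=0.1):
    # Pass a fresh PRNG key on every call (e.g. via jax.random.split); a fixed
    # default key would produce the exact same dropout mask every time
    # Keep each activation with probability 1 - P: a boolean mask with arr's shape
    keep = jax.random.bernoulli(key, 1.0 - P, arr.shape)
    # Zero out dropped activations and scale the kept ones by 1 / (1 - P), so the
    # expected activation matches inference time, when dropout is switched off
    return jnp.where(keep, arr / (1.0 - P), 0.0)
```

The last piece that I think needs to be considered before jumping into full model construction and training is the process of tokenizing our inputs.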
I used the BPE tokenizer from the Moses decoder, which was used by Seq2Seq and then by Attention Is All You Need. AIAYN noted that model training was highly sensitive to the tokenizer, so I opted to reuse theirs rather than implement it from scratch. The tokenizer ultimately gave us two relevant files: a joint vocab file for English and German (it also produces individual vocab files, but Seq2Seq notes that the joint vocab actually does better), and a BPE merges file that describes the merges between substrings.
The basic process for tokenizing a sentence then becomes:
- Split the sentence into individual words.
- Split each word into its characters.
- Iterate through all the possible merges we can make from the BPE file for that word.
- Find the minimum number of tokens needed to represent the word.
- Repeat for the entire sentence.
At first this sounds like a relatively straightforward process, but the challenge becomes clear when we think about how to find the relevant merges and optimize for the fewest tokens per word. Naively, you could loop through the merges file for every word, saving any merge that produces a substring present in that word, and then iterate through all combinations of those substrings, tracking the fewest substrings used. However, the outer loop over the merges file is $O(MN)$, where $M$ is the number of merges (in this case 32,000) and $N$ is the sentence length. The inner loop depends on the number of valid substrings for each word, but assuming an average of $k$ per word, trying every combination is something like $O(N \cdot 2^k)$. In other words, really, really slow.
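For contrast, a brute-force version of the "try every combination" idea might look like the sketch below. It assumes the valid substrings have already been gathered into a Python set (the toy `vocab` is made up for illustration, not taken from the real merges file), and `brute_force_segment` is a hypothetical helper. It recursively tries every split of the word, so the number of calls can grow exponentially with word length:

```python
def brute_force_segment(word, vocab):
    """Return the split of `word` into substrings from `vocab` with the fewest
    pieces, or None if no split exists. Tries every split, so it is exponential."""
    if word == "":
        return []
    best = None
    # Try every known prefix, then recursively split the rest of the word
    for end in range(1, len(word) + 1):
        prefix = word[:end]
        if prefix in vocab:
            rest = brute_force_segment(word[end:], vocab)
            if rest is not None and (best is None or 1 + len(rest) < len(best)):
                best = [prefix] + rest
    return best

vocab = {"l", "o", "w", "e", "s", "t", "lo", "low", "est"}
print(brute_force_segment("lowest", vocab))  # ['low', 'est']
```

Nothing is cached here, so the same suffixes get re-split over and over; that repeated work is what the approach below avoids.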
Fortunately, there's a better way. LeetCode fans may recall problems about "prefixes" that involve Tries. Begrudgingly, I have found an actual use for them, and for dynamic programming. The implementation is roughly the following:
- Create a Trie from the merges file, where each child node adds one character to its parent's substring.
- To find the valid substrings of a word, start at each character and walk down the Trie as far as the following characters allow, storing every valid substring that originates from that character.
- Build a DP array that reconstructs the word from those substrings, recording the "score" (the number of tokens) required to reach each position.
- Once we have the optimal number of substrings, work backwards to recover which specific substrings were used.
- Finally, look up each substring in a hashmap to get its corresponding token.
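Here's a minimal sketch of the Trie + DP idea in plain Python. It assumes the substrings derived from the merges file are already available as a list of strings; the toy `substrings` list, the `token_ids` dict, and all function names are hypothetical, and parsing the actual merges file (plus any end-of-word markers it uses) is left out:

```python
class TrieNode:
    def __init__(self):
        self.children = {}     # next character -> TrieNode
        self.is_token = False  # True if the path from the root spells a known substring


def build_trie(substrings):
    root = TrieNode()
    for s in substrings:
        node = root
        for ch in s:
            # Each child extends its parent's substring by one character
            node = node.children.setdefault(ch, TrieNode())
        node.is_token = True
    return root


def valid_substrings(word, root):
    # matches[i] lists every exclusive end index e such that word[i:e] is a known substring
    matches = [[] for _ in word]
    for i in range(len(word)):
        node = root
        for j in range(i, len(word)):
            node = node.children.get(word[j])
            if node is None:
                break  # no known substring continues past this character
            if node.is_token:
                matches[i].append(j + 1)
    return matches


def segment(word, root):
    n = len(word)
    matches = valid_substrings(word, root)
    INF = float("inf")
    # dp[e] = fewest substrings needed to spell word[:e]; back[e] = start of the last piece
    dp = [0] + [INF] * n
    back = [-1] * (n + 1)
    for i in range(n):
        if dp[i] == INF:
            continue  # word[:i] cannot be built, so nothing can extend it
        for e in matches[i]:
            if dp[i] + 1 < dp[e]:
                dp[e] = dp[i] + 1
                back[e] = i
    if dp[n] == INF:
        return None  # the word cannot be built from the known substrings
    # Follow the back-pointers from the end of the word to recover the pieces
    pieces, e = [], n
    while e > 0:
        pieces.append(word[back[e]:e])
        e = back[e]
    return pieces[::-1]


substrings = ["l", "o", "w", "e", "s", "t", "lo", "low", "est"]
token_ids = {s: i for i, s in enumerate(substrings)}  # hypothetical substring -> token ID map
root = build_trie(substrings)
pieces = segment("lowest", root)
print(pieces)                          # ['low', 'est']
print([token_ids[p] for p in pieces])  # [7, 8]
```

The Trie is built once, and after that the per-word cost depends only on the word's length (roughly quadratic in its number of characters at worst) rather than on scanning all 32,000 merges.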
I'll try to implement these in Day 5.