Porting GPT-2 to MLX for fun!
what is MLX?
MLX is an array framework purpose-built for Apple Silicon; think NumPy, but tuned for Apple’s hardware. It also has higher-level abstractions with an interface similar to PyTorch. I wanted to learn MLX, and what better way to do that than to port the model that started it all for LLMs?
the constituents of a GPT
GPT is a decoder-only transformer: it takes a sequence of embedded tokens (the prompt, or pre-fill) and auto-regressively generates a continuation from it, one token at a time, feeding each new token back in as input for the next step. It also uses learned positional embeddings to encode the position of each token in the sequence.
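A minimal sketch of the input side, written in NumPy rather than MLX so it runs anywhere (the operations map closely onto `mlx.core`). Each token id looks up a row of the token embedding table, and each position looks up a row of the learned positional table; the two are summed. The tables here are tiny random stand-ins, not the real GPT-2 weights, and `embed` is a hypothetical helper name.

```python
import numpy as np

def embed(token_ids, wte, wpe):
    """Token embeddings plus learned positional embeddings.

    token_ids: (T,) integer array of token ids
    wte: (vocab_size, d_model) token embedding table
    wpe: (max_positions, d_model) positional embedding table
    """
    positions = np.arange(len(token_ids))
    # Look up each token's vector and add the vector for its position.
    return wte[token_ids] + wpe[positions]

# Tiny random tables standing in for the real GPT-2 weights.
rng = np.random.default_rng(0)
wte = rng.normal(size=(10, 4))  # vocab of 10, hidden size 4
wpe = rng.normal(size=(8, 4))   # up to 8 positions
x = embed(np.array([3, 1, 4]), wte, wpe)
print(x.shape)  # (3, 4)
```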
Each transformer block is composed of causal (masked) multi-headed self-attention and an MLP, each preceded by its own layer norm (pre-layer norm) and wrapped in a residual connection. This block is then stacked multiple times in series.
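To make the block structure concrete, here is a NumPy sketch of one pre-layer-norm block for a single (unbatched) sequence. It assumes weights in `(in, out)` layout so that projections are `x @ W`, uses the tanh approximation of GELU that GPT-2 uses, and stores parameters in a dict whose key names (`ln1_g`, `w_qkv`, …) are hypothetical, not the names used in the port.

```python
import numpy as np

def layer_norm(x, g, b, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return g * (x - mu) / np.sqrt(var + eps) + b

def gelu(x):
    # tanh approximation of GELU, as used by GPT-2
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def causal_self_attention(x, w_qkv, b_qkv, w_out, b_out, n_heads):
    T, d = x.shape
    hd = d // n_heads
    q, k, v = np.split(x @ w_qkv + b_qkv, 3, axis=-1)  # each (T, d)
    # Split into heads: (T, d) -> (n_heads, T, hd)
    q, k, v = (t.reshape(T, n_heads, hd).transpose(1, 0, 2) for t in (q, k, v))
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(hd)  # (n_heads, T, T)
    # Causal mask: position i may only attend to positions <= i.
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    out = softmax(scores) @ v  # (n_heads, T, hd)
    out = out.transpose(1, 0, 2).reshape(T, d)  # merge heads back
    return out @ w_out + b_out

def block(x, p, n_heads):
    # Pre-LN: normalise before each sub-layer, then add a residual connection.
    a = layer_norm(x, p["ln1_g"], p["ln1_b"])
    x = x + causal_self_attention(a, p["w_qkv"], p["b_qkv"], p["w_attn_out"], p["b_attn_out"], n_heads)
    h = gelu(layer_norm(x, p["ln2_g"], p["ln2_b"]) @ p["w_fc"] + p["b_fc"])
    return x + h @ p["w_mlp_out"] + p["b_mlp_out"]

# Random stand-in parameters: hidden size 8, 2 heads, MLP width 4 * d.
rng = np.random.default_rng(0)
d, T, n_heads = 8, 5, 2
p = {
    "ln1_g": np.ones(d), "ln1_b": np.zeros(d),
    "w_qkv": rng.normal(size=(d, 3 * d)) * 0.1, "b_qkv": np.zeros(3 * d),
    "w_attn_out": rng.normal(size=(d, d)) * 0.1, "b_attn_out": np.zeros(d),
    "ln2_g": np.ones(d), "ln2_b": np.zeros(d),
    "w_fc": rng.normal(size=(d, 4 * d)) * 0.1, "b_fc": np.zeros(4 * d),
    "w_mlp_out": rng.normal(size=(4 * d, d)) * 0.1, "b_mlp_out": np.zeros(d),
}
x = rng.normal(size=(T, d))
y = block(x, p, n_heads)  # same shape as x: (5, 8)
```

Because of the causal mask, changing a later token cannot change the outputs at earlier positions, which is what makes left-to-right generation possible. Stacking the model is then just applying `block` repeatedly with each layer's own parameters.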
Finally, a layer norm and an LM head project the hidden dimension up to the vocabulary size of the tokenizer. The outputs are treated as logits over the vocabulary, which can then be decoded into the next token.
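A small NumPy sketch of this last step. In the original GPT-2 the LM head shares its weights with the token embedding table, so the sketch projects with `wte.T`; whether the MLX port ties the weights the same way is an assumption here, and `lm_head` is a hypothetical helper name.

```python
import numpy as np

def lm_head(h, ln_g, ln_b, wte, eps=1e-5):
    """Final layer norm followed by a weight-tied LM head.

    h: (T, d_model) hidden states from the last block
    wte: (vocab_size, d_model) token embedding table, reused as the head
    """
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    h = ln_g * (h - mu) / np.sqrt(var + eps) + ln_b
    return h @ wte.T  # (T, vocab_size) logits

rng = np.random.default_rng(0)
d, vocab = 4, 10
wte = rng.normal(size=(vocab, d))
h = rng.normal(size=(3, d))  # hidden states for 3 positions
logits = lm_head(h, np.ones(d), np.zeros(d), wte)
# Only the last position's logits matter for predicting the next token.
next_token = int(np.argmax(logits[-1]))
```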
loading pre-trained weights from HF Transformers
Mostly boring stuff here: this involves loading the model weights from the HF Transformers library and then renaming the state_dict keys into a form that can be used with MLX.
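One detail worth knowing: HF's GPT-2 implements its linear layers as `Conv1D`, which stores weights as `(in, out)`, while MLX's `nn.Linear` stores them as `(out, in)`, so those matrices need a transpose. The sketch below is a hedged example of the conversion, operating on a plain dict of NumPy arrays (e.g. `{k: v.numpy() for k, v in hf_model.state_dict().items()}`). The target key names (just stripping `transformer.`) and dropping `lm_head.weight` because the head is tied to `wte` are assumptions about the MLX model's layout, and `convert_state_dict` is a hypothetical name.

```python
import numpy as np

# Linear layers that HF GPT-2 implements as Conv1D, storing weights as (in, out).
# MLX's nn.Linear stores weights as (out, in), so these need a transpose.
CONV1D_SUFFIXES = (
    "attn.c_attn.weight", "attn.c_proj.weight",
    "mlp.c_fc.weight", "mlp.c_proj.weight",
)

def convert_state_dict(hf_state):
    """Rename and transpose HF GPT-2 weights into an assumed MLX layout."""
    out = {}
    for key, value in hf_state.items():
        # Causal-mask buffers (present in some transformers versions) are not weights.
        if key.endswith((".attn.bias", ".attn.masked_bias")):
            continue
        # Assumes the MLX model ties its LM head to wte, so this copy is redundant.
        if key == "lm_head.weight":
            continue
        if key.endswith(CONV1D_SUFFIXES):
            value = value.T
        out[key.removeprefix("transformer.")] = value
    return out
```

The resulting dict can then be converted with `mx.array` and handed to the model, e.g. as a list of `(name, array)` pairs via `nn.Module.load_weights`.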
improving generation with sampling
Greedy decoding picks the most likely next token at every step, so a given prompt always produces the same sequence.
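A minimal sketch of the auto-regressive greedy loop, assuming a hypothetical `next_token_logits` callable that stands in for the full model (a toy function here, not GPT-2). Each step appends the argmax token and feeds the extended sequence back in.

```python
import numpy as np

def greedy_generate(next_token_logits, prompt_ids, max_new_tokens):
    """Auto-regressive greedy decoding.

    next_token_logits: hypothetical callable mapping a list of token ids
    to a 1-D array of logits for the next token (stands in for the model).
    """
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)
        ids.append(int(np.argmax(logits)))  # always the single most likely token
    return ids

# Toy "model": always favours the token (last_id + 1) mod 5.
def toy_logits(ids):
    logits = np.zeros(5)
    logits[(ids[-1] + 1) % 5] = 1.0
    return logits

print(greedy_generate(toy_logits, [0], 6))  # [0, 1, 2, 3, 4, 0, 1]
```

Running it twice gives the identical sequence, which is exactly the lack of diversity that sampling addresses. A practical implementation would usually also cache attention keys and values rather than re-running the whole sequence at each step.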
This can be improved by sampling from the distribution instead, controlled by two parameters, temperature and top_k, which increases the diversity of generated sequences.
The temperature parameter controls the sharpness of the probability distribution over tokens (the logits are divided by the temperature, so values below 1 sharpen it and values above 1 flatten it), whereas the top_k parameter controls how many of the most likely tokens are considered while sampling.
Setting a temperature of 0 falls back to greedy decoding.
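Putting the two parameters together, here is a NumPy sketch of picking one token from a logits vector. The function name `sample_next` is hypothetical; in MLX the final draw would typically be done with `mx.random.categorical` on the filtered logits rather than with `rng.choice`.

```python
import numpy as np

def sample_next(logits, temperature=1.0, top_k=None, rng=None):
    """Pick the next token id from a 1-D array of logits."""
    if temperature == 0:
        return int(np.argmax(logits))  # greedy fallback
    if rng is None:
        rng = np.random.default_rng()
    logits = np.asarray(logits, dtype=np.float64) / temperature  # <1 sharpens, >1 flattens
    if top_k is not None:
        top_k = min(top_k, len(logits))
        # Keep only the k largest logits (ties may keep a few more); the rest get zero probability.
        kth_largest = np.sort(logits)[-top_k]
        logits = np.where(logits < kth_largest, -np.inf, logits)
    probs = np.exp(logits - logits.max())  # softmax, shifted for numerical stability
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

logits = np.array([1.0, 3.0, 2.0, 0.5])
rng = np.random.default_rng(0)
print(sample_next(logits, temperature=0))  # 1 (the argmax)
draws = [sample_next(logits, temperature=0.8, top_k=2, rng=rng) for _ in range(10)]
```

With `top_k=2` every draw is either token 1 or token 2, and lowering the temperature shifts more of the probability mass onto token 1.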
what’s next?
Lots of ideas to try! Here are a few of them, in no particular order:
- Instruction-tuned weights for GPT-2
- Training/fine-tuning GPT-2 for some niche task
- Implementing other LLMs such as Llama