Broadly speaking, there are two ways one can approach machine learning:1 from the top, or from the bottom. Being a physicist, I of course chose the latter, starting with the fundamentals. After reading a few statistics and deep learning textbooks, playing around with MNIST, completing some transformer/LLM mini-projects, and getting up to speed on reinforcement learning through OpenAI Spinning Up, I decided it was time to get some hands-on experience with fine-tuning. This felt special because it was the last missing piece I needed to be able to make a practically useful LLM from scratch, and it finally made contact with the first step one would take in the top-down approach to learning ML: loading a pretrained model from Hugging Face and using their API to fine-tune it.
This mini-project took longer than I expected, and it was less conceptual than most of my self-study so far, but I still think it was a good use of time. Likewise, this blog post will not be particularly conceptual, but it was a good exercise to organize and share my thoughts. The main target audience for this post is my former self at the start of this mini-project, as the title suggests.
In the context of language models, the term “fine-tuning” is used in contrast with pretraining. In a sense both are just training, so the distinction is more about intent and the dataset than the training loop itself. The one-sentence explanation would be that pretraining teaches the model language, and fine-tuning teaches it some particular language-based task. The pretraining phase is both more extensive and more general, and this is reflected in the choice of dataset. To first approximation, modern LLMs are pretrained on the text of the whole internet. This is the biggest and broadest dataset imaginable, consisting of many billions or even trillions of tokens. In contrast, datasets for fine-tuning consist of examples of the task of interest, and contain only (many) thousands of data points.
Before this mini-project I didn’t appreciate how bad raw pretrained models are at most particular tasks, and therefore how important fine-tuning is. In retrospect this makes sense though, since a random prompt on the internet won’t necessarily be followed by a good answer. Frequently it’s followed by a bad answer, or a refusal to answer, or even a non-sequitur. And in the case of classification tasks like sentiment analysis, which requires switching out the final layer’s language model head for a randomly initialized classification head, the model initially performs at the level of random chance. So fine-tuning is essential.
The most basic form of fine-tuning simply consists of further training of the whole model with a new dataset. Because the fine-tuning dataset is much smaller than the pretraining dataset, this is still much faster and cheaper than pretraining. But it still seems a bit like killing a mosquito with a hammer, and it may be prohibitively expensive (in compute or memory) on small local devices. To address this, researchers have proposed a zoo of lighter-weight forms of fine-tuning that achieve comparable performance while training far fewer parameters. The leader of the pack is called low-rank adaptation (LoRA), and was introduced in 2021 by researchers at Microsoft. It modifies matrix-valued parameters by adding a low-rank term to them (parametrized as the product of two narrow rectangular matrices) and training just that low-rank part. This reduces the number of trainable parameters by hundreds or even thousands of times, and yet performs slightly better than full fine-tuning for simple tasks!2 A true win-win scenario.
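To make this concrete, here is a toy sketch of the LoRA parametrization in PyTorch. This is my own illustration of the idea, not the peft library implementation I use later: the base weights are frozen, and only the two narrow matrices A and B are trained.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wraps a frozen linear layer W and adds a trainable low-rank update B @ A."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # the original weights stay frozen
        # A is r x d_in (small random init), B is d_out x r (zero init),
        # so the update B @ A starts out as exactly zero.
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scale = alpha / r  # standard LoRA scaling factor

    def forward(self, x):
        # base output plus the low-rank correction x @ (B A)^T
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)
```

Because B starts at zero, the wrapped layer initially behaves exactly like the original one, and during training gradients only flow into A and B.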
Dirty Hands
To get hands-on experience with these techniques I loaded the pretrained version of GPT-2 from Hugging Face, swapped out the language model head for a binary classification head, and fine-tuned it in a few different ways for sentiment analysis on the SST-2 dataset of movie reviews. My Kaggle notebook is linked above.
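Concretely, the model swap looks something like this (a minimal sketch using the standard Hugging Face classes; the checkpoint name and padding choice are illustrative rather than copied from my notebook):

```python
from transformers import GPT2TokenizerFast, GPT2ForSequenceClassification

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

# Same transformer body as the pretrained LM, but with a randomly
# initialized 2-way classification head in place of the LM head.
model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=2)
model.config.pad_token_id = tokenizer.pad_token_id  # so padding positions are ignored
```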
The dataset needed to be tokenized and cleaned a bit, through truncation and padding, so that all examples contained the same number of token IDs. This was fairly straightforward after loading the tokenizer from Hugging Face and familiarizing myself with their DatasetDict class, which is basically a dict of Dataset objects. In turn, a Dataset is like a dict of lists or “columns”. In the early stages it helped to restrict to smaller datasets using Dataset’s .select() method.
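A sketch of that data preparation, reusing the tokenizer from the previous snippet (the max length of 128 tokens is my choice for illustration):

```python
from datasets import load_dataset

raw = load_dataset("glue", "sst2")  # a DatasetDict with train/validation/test splits

def tokenize(batch):
    # Truncate long reviews and pad short ones so every example has the same length
    return tokenizer(batch["sentence"], truncation=True,
                     padding="max_length", max_length=128)

tokenized = raw.map(tokenize, batched=True)
tiny_train = tokenized["train"].select(range(1000))  # small slice for quick experiments
```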
The first approach I took was coding a basic training loop by hand and feeding the model to it. It was gratifying to see most of what I knew about smaller models carry over seamlessly to this larger model of around 100M parameters. The main hiccups came in the form of numerical issues with the optimizers. With its default values, AdamW was spitting out NaN right away. Increasing its eps parameter to 1e-4 from the default 1e-8 resolved this, but it still struggled to improve the model beyond 50% accuracy, so I switched to SGD to simplify the troubleshooting. SGD got the test accuracy up to around 85%, but the training accuracy was stuck under 90%, which meant that something was going wrong. Tweaking hyperparameters and increasing the number of epochs didn’t fix it, so it seems to have been a genuine shortcoming of SGD at this scale.3 This is plausible, since SGD is notoriously finicky for large transformers.

I switched back to AdamW, and after lowering the learning rate to 1e-5 from the default of 1e-3, it worked like a charm! It was interesting to see a more pronounced difference between SGD and AdamW at this scale, since at smaller scales both perform adequately. With AdamW the training accuracy kept climbing with each epoch, and the test accuracy plateaued around 91.5% after getting above 90% on the first epoch. This was with batch size 32. Putting the model in .train() or .eval() mode didn’t make much of a difference (with the default dropout of 0.1). One epoch on the full dataset of 67k data points took about 3 minutes on a single T4 GPU.
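For reference, the skeleton of that by-hand loop looked roughly like this (a simplified sketch with the hyperparameters above; the evaluation pass and fp16 details are omitted):

```python
import torch
from torch.utils.data import DataLoader

# Expose only the tensor columns the model needs
tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
train_loader = DataLoader(tokenized["train"], batch_size=32, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# eps raised to 1e-4 and lr lowered to 1e-5, as discussed above
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-4)

model.train()
for batch in train_loader:
    batch = {k: v.to(device) for k, v in batch.items()}
    out = model(input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"])  # the model computes cross-entropy loss itself
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```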
Next, I embarked on the process of learning Hugging Face’s Trainer API, which is the “quick” and “convenient” way to train models. First I just asked Claude to write a throwaway script for me. It had persistent errors and I eventually gave up on it. The API has over 100 kwargs, and I didn’t realize at first how few of them are necessary to get up and running. Next I tried code from several tutorials and GitHub repos, but most weren’t recent enough and fell victim to breaking changes in the packages they used. Eventually I found a tutorial that was free of errors, but when I called Trainer.train() my Kaggle notebook just hung indefinitely. After much consternation, it turned out that Trainer.train() was logging to wandb.ai (Weights & Biases) by default and trying to prompt me to log in, but Kaggle’s UI wasn’t surfacing that prompt, so I had no clue what the problem was. I only figured this out when I eventually pasted the same code into a Google Colab notebook, whose UI did show me the Weights & Biases login prompt. The main lesson I took away from this experience is that I should index more heavily on getting code running, even if it doesn’t do exactly what I want at first. Modifying code that already works is a much more robust feedback loop than debugging code that mysteriously doesn’t work.
With that bug finally out of the way, using Trainer actually went pretty smoothly. It has various convenient features built in, such as automatic detection and use of multiple GPUs. To make the comparison with my by-hand loop as direct as possible, I passed my AdamW optimizer object to Trainer and used the same batch size of 32, the same constant learning rate schedule, and the same fp16 precision. At 4.5 minutes per epoch it was a bit slower than my by-hand loop, but this was forgivable considering the extra overhead of thorough checkpointing. What was less forgivable, and more confusing, was its poor accuracy. Sometimes it got to 85% after an epoch, but most of the time it stayed near 50%. This was a bit of a surprise to me, since naively the training loops should have been identical. Even after lowering the learning rate from 1e-5 to 5e-6 to improve stability, Trainer’s test accuracy still lagged behind, plateauing around 89.5% compared with my by-hand loop’s 91.5%.
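For reference, a minimal Trainer setup along these lines looks something like this (a sketch with a placeholder output path; report_to="none" sidesteps the silent Weights & Biases login prompt described earlier):

```python
import numpy as np
from transformers import Trainer, TrainingArguments

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    return {"accuracy": (np.argmax(logits, axis=-1) == labels).mean()}

args = TrainingArguments(
    output_dir="gpt2-sst2",           # placeholder path
    per_device_train_batch_size=32,
    num_train_epochs=1,
    fp16=True,
    lr_scheduler_type="constant",     # match the by-hand loop's constant schedule
    report_to="none",                 # disable the wandb logging integration
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None),     # reuse the same AdamW object; Trainer builds the scheduler
)
trainer.train()
trainer.evaluate()
```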
After being unable to find a bug causing this low test accuracy, I accepted that it might be legitimate. Thankfully Kaggle provides several different GPUs, and switching from the T4 to the P100 offered some further insight. On the P100 GPU, Trainer could handle the higher learning rate of 1e-5, and several times in a row its accuracy after a single epoch was 88.99%.4 Continued training plateaued around 90.4%. I don’t fully understand why the performances differed at all between Trainer and my by-hand loop, but the fact that Trainer’s own performance varied between different GPUs suggests to me that the culprit is fairly low-level, akin to numerical instability, and therefore not worth obsessing over.
For completeness I also tested my by-hand loop on the P100, and its performance was basically unchanged, still usually exceeding 90% test accuracy after a single epoch and plateauing around 91.5%. It seems useful at this point to collect these rough results in a table. For unknown reasons, my by-hand loop outperforms the Trainer API by 1-2%. At an eye-test level the standard deviation for single epoch runs was around 1% in all cases, so it seems unlikely to me that the difference in plateaued accuracies would wash out with more training runs.
|              | Test accuracy (T4) | Test accuracy (P100) |
| ------------ | ------------------ | -------------------- |
| Trainer API  | 89.5%              | 90.4%                |
| By-hand loop | 91.5%              | 91.4%                |
Now we finally turn to low-rank adaptation (LoRA). All it took was importing the peft library (parameter-efficient fine-tuning), initializing a LoraConfig object, and wrapping the model as:
model_lora = get_peft_model(model_og, config)
My config used rank r=8 and alpha parameter lora_alpha=16, following the rule of thumb that alpha should be twice the rank. I also used a LoRA dropout of 0.05. The API automatically chose which modules to apply LoRA to, and by printing the model I could see that it was only applied to the QKV matrices, not the MLPs. The QKV matrices are packaged together into a single 768×2304 matrix, so the low-rank update was parametrized as a (768×8)×(8×2304) product. The number of trainable parameters was 290k, or about 0.24% of the model’s 125M total parameters.
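Putting the pieces together, the LoRA wrapping looked roughly like this (a sketch; model_og is the classification model from before, and peft picks the target modules automatically for this task type):

```python
from peft import LoraConfig, TaskType, get_peft_model

config = LoraConfig(
    task_type=TaskType.SEQ_CLS,  # GPT-2 with a sequence-classification head
    r=8,                         # LoRA rank
    lora_alpha=16,               # scaling factor, twice the rank
    lora_dropout=0.05,
)

model_lora = get_peft_model(model_og, config)
model_lora.print_trainable_parameters()  # reports ~290k trainable of ~125M total
```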
Once I had the LoRA model, I just fed it into my by-hand training loop. Presumably using Trainer would have also worked, but may have led to worse performance. To get robust results all I had to do was increase the learning rate to 1e-4 from my previous 1e-5, keeping eps=1e-4. On both the T4 and P100 GPUs, the test accuracy after one epoch was usually between 88.5% and 89.0%. The early-stopped test accuracies from three runs were 90.8%, 90.7%, and 90.4%. This is slightly below the 91.5% I got with full fine-tuning, but in the same ballpark.
The speed of the LoRA training was very similar to that of full fine-tuning, taking about 3 minutes per epoch. I had expected it to be faster, but the similar speed seems plausible given that the forward pass, and the overhead of loading data points and moving them from CPU to GPU, are the same for both approaches. The original LoRA paper mentions that despite the huge reduction in trainable parameters, one should not expect to save orders of magnitude in time or space, since the whole model must still be stored and a forward pass must still be done. The real savings come from being able to store many different LoRA-fine-tuned versions of the same base model for almost no extra cost. Even for a single task this is still useful, because it makes checkpointing cheaper.
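For example, saving a LoRA checkpoint writes only the small adapter weights rather than the full 125M-parameter model (the path below is a placeholder):

```python
# Saves just the adapter weights and config, a few megabytes instead of ~500 MB
model_lora.save_pretrained("sst2-lora-adapter")
```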
To get a sense for whether my accuracies were reasonable for GPT-2 on the SST-2 dataset, I looked through publicly available models on Hugging Face doing similar tasks. Most of the claimed test accuracies were in the 90-92% range. One of the most relevant comparisons I found was from Michele Cafagna. She trained gpt2-medium (about 350M parameters, as opposed to my 100M) on the SST-2 dataset and reported a test accuracy of 92% on what I assume is the same 872 data point validation dataset that I used. However, when I loaded her model and evaluated it myself, I measured an accuracy of only 90.3%. Either way, this was a good confidence boost that my training runs were reasonable. The full summary of results is shown in this table:
|                                 | Test accuracy (T4) | Test accuracy (P100) |
| ------------------------------- | ------------------ | -------------------- |
| Full fine-tuning (Trainer API)  | 89.5%              | 90.4%                |
| Full fine-tuning (by-hand loop) | 91.5%              | 91.4%                |
| LoRA (by-hand loop)             | 90.7%              | 90.8%                |
| Cafagna’s gpt2-medium           | 90.3%              | 90.3%                |
In the end, the three methods weren’t hugely different. The choice of what to use in practice might hinge more on time and flexibility considerations than quality considerations. In any case, I’m glad to now be familiar with all three approaches. Given more time it would have been nice to play around with RLHF and reward models, but I’ll stop here.
1. Or any new subject, really.
2. Later research showed that for more complicated tasks like math and coding, LoRA is insufficiently powerful.
3. To be fair, I was only using fp16 precision rather than full fp32, but this is standard and still fairly precise.
4. This level of consistency suggests that the model was in model.eval() mode during training. I had thought Trainer enabled model.train() mode by default, but I learned after finishing data collection that this is not actually the case.