Intro
This article is part two of a three-part series on building GPT-2 from scratch. You can find part one here. In this post we will focus on optimising our model and training process for efficient training on GPUs. We will assume you are familiar with the base code, which can be found here.
Note that for this post in particular we will be using Lambda Labs for GPUs.
Optimisations on GPU
Data Types and Mixed Precision
By default in PyTorch every number in our tensors is represented by a torch.float32, which uses 32 bits to store the number. This float32 offers a high level of precision, but for deep learning this level of precision may be unnecessary. We can decrease our precision to float16, which enables much faster computation. It also decreases the memory footprint of the model and of training state such as gradients. There is a further benefit in terms of networking, because fewer bits are transferred. We can see the scale of computation as it relates to the data type used in the below table. As you can see, by decreasing our precision from float32 down to float16 (Tensor Core) we increase our teraFLOPS by almost 30 times. This is a huge computational improvement in training, and allows us to train much faster and more cheaply.
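As a rough illustration of the memory saving, the sketch below counts the raw storage needed for a set of values in each format. The parameter count of roughly 124 million (the smallest GPT-2 configuration) is an assumption made for illustration, and the helper name is hypothetical. It counts storage only; it does not model what a mixed-precision training run actually keeps in memory.

```python
# Bytes used to store one value in each format.
BYTES_PER_VALUE = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


def footprint_mib(num_values: int, dtype: str) -> float:
    """Memory needed to store `num_values` numbers of the given type, in MiB."""
    return num_values * BYTES_PER_VALUE[dtype] / 2**20


# Assumption: roughly 124 million parameters (smallest GPT-2 configuration).
NUM_PARAMS = 124_000_000

for name in BYTES_PER_VALUE:
    print(f"{name:>8}: {footprint_mib(NUM_PARAMS, name):7.1f} MiB")
```

Halving the number of bits per value halves both the storage and the number of bytes that have to be moved for every read or write.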
INT8 is a data type used for inference, but it does not work well for training: its evenly spaced integer values are a poor fit for the activations and weights, which roughly follow a normal distribution during training.
Generally, lower-precision data types place much lower demands on memory bandwidth than higher-precision types, meaning data can be accessed and moved faster. This is highly valuable because memory bandwidth is typically the limiting factor in GPU utilisation, and GPUs often hover at less than 60% utilisation. With lower precision the GPU utilisation can improve significantly.
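The floating-point formats compared in this section, including the TF32 and BFloat16 formats described next, differ in how they split their bits between the exponent (range) and the mantissa (precision). The sketch below derives the largest finite value and the spacing around 1.0 from those bit counts. The helper names are hypothetical; the bit layouts are the standard ones for each format.

```python
# (exponent bits, mantissa bits) per format; one further bit stores the sign.
FORMATS = {
    "float32": (8, 23),
    "tf32": (8, 10),
    "bfloat16": (8, 7),
    "float16": (5, 10),
}


def max_finite(exp_bits: int, man_bits: int) -> float:
    """Largest finite value of an IEEE-style binary floating-point format."""
    bias = 2 ** (exp_bits - 1) - 1
    return (2 - 2.0 ** -man_bits) * 2.0 ** bias


def epsilon(man_bits: int) -> float:
    """Gap between 1.0 and the next representable value."""
    return 2.0 ** -man_bits


for name, (e, m) in FORMATS.items():
    print(f"{name:>8}  max={max_finite(e, m):.3e}  eps={epsilon(m):.1e}")
```

Formats with 8 exponent bits share the range of float32, while float16 tops out at 65504 but has a finer spacing than bfloat16.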
The TF32 data type offers an approximation of full FP32 but truncates the precision of the value. This means it can still represent the same range of values, but the number of significant figures is decreased. For our purposes this is fine, as for most weights in neural networks the order of magnitude is more important than the precise number represented.
BFloat16 is similar to the above in that it represents the same range of values, but the precision used is significantly lower. This can be particularly useful for very large models. This is as opposed to Float16, which decreases the range of values that can be represented but has increased precision. In order to use Float16 we have to use gradient scalers to shift values into the representable range, because the number of exponent bits is decreased. BFloat16 solved this problem by maintaining the same exponent range.
Tensor Cores
In the above table you may have spotted the term ‘Tensor Core’, but what exactly is a Tensor Core? A tensor core is a specialised hardware unit designed to accelerate matrix multiplication and convolution operations. This makes them incredibly useful for Deep Learning applications.
Tensor Cores were introduced by NVIDIA in their recent GPUs, and their main benefits over older GPU architectures are:
- They enable mixed-precision computing, allowing for faster computation while maintaining accuracy.
- They significantly speed up training and inference of deep neural networks compared to regular CUDA cores.
- Operations using tensor cores are exposed via APIs like cuBLAS and cuDNN, and in deep learning frameworks like PyTorch and TensorFlow.
- Tensor cores are leveraged automatically by the GPU when using supported data types (e.g. FP16, TF32, BF16) and dimensions that are multiples of 8.
Not all NVIDIA GPUs have Tensor Cores. They are found in workstation, server and HPC oriented cards such as the Quadro, Tesla and A100 series, and also in GeForce RTX gaming cards; older GeForce GTX cards do not include them.
Updating the data types
We can set the precision PyTorch uses for float32 matrix multiplications with the call below:
torch.set_float32_matmul_precision("high")
This allows float32 matrix multiplications to use TensorFloat32 internally. The other options are "highest" (full float32) and "medium" (bfloat16). We will also increase the batch size and sequence length to a size similar to that used when training the original GPT-2. If this batch size does not fit into your GPU memory, decrease it incrementally until it fits. Generally powers of two work best for batch sizes (2, 4, 8, 16, 32, 64, 128). You should increase the batch size to the maximum your GPU can handle.
train_loader = DataLoaderLite(B=16, T=1024)
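The search for the largest batch size that fits can be automated. The sketch below is a minimal, framework-free version: `find_max_batch_size` and `fake_step` are hypothetical names, and `fake_step` merely pretends that anything above 16 runs out of memory. In a real run, `try_step` would build the loader and execute one training step, and the exception to catch would be the out-of-memory error raised by your framework (in recent PyTorch versions, `torch.cuda.OutOfMemoryError`).

```python
def find_max_batch_size(try_step, start=128, oom_error=MemoryError):
    """Halve the batch size until `try_step(batch_size)` stops raising `oom_error`."""
    batch_size = start
    while batch_size >= 1:
        try:
            try_step(batch_size)
            return batch_size
        except oom_error:
            # Stay on powers of two by halving.
            batch_size //= 2
    raise RuntimeError("even a batch size of 1 does not fit")


def fake_step(batch_size):
    """Stand-in for a real training step (assumption: 16 is the largest size that fits)."""
    if batch_size > 16:
        raise MemoryError(f"batch size {batch_size} is too large")


print(find_max_batch_size(fake_step))
```

Starting from a power of two and halving keeps every candidate a power of two, which matches the batch sizes recommended above.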
When we make these changes in the code we see a significant speedup (approx. 3x), but this is not as large an increase as we might have expected. This is because the network training is memory bound: moving data around, rather than the computation itself, is the bottleneck in the pipeline. This moving around of bits (memory) can be optimised slightly.
BFloat16
In order to utilise bfloat16 we only ever want to use the torch.autocast context manager. It should only be applied to the forward pass and loss calculation; it is not recommended to autocast the backward pass. On CUDA, autocast defaults to float16, so we pass dtype=torch.bfloat16 explicitly:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model(input)
    loss = loss_fn(output, target)
loss.backward()
optimizer.step()
When we examine the layers in the network we can see that some layers now use bfloat16 but others still use float32. This is why this method is known as mixed precision. It's not always clear which operations are cast to lower precision with autocast, but you can find more details here. Generally matrix multiplication and convolution layers can be cast to lower precision, but operations like layer norm, softmax and loss function calculations are less robust to precision reduction, so autocast keeps them in float32. You can find more info on mixed precision in PyTorch in this blog here.
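A small check makes the "mixed" part visible. The sketch below uses a single linear layer as a stand-in for the GPT model (an assumption made to keep the example self-contained) and falls back to the CPU when no GPU is available. It shows that the stored weights stay in float32 while the output of the matrix multiplication is produced in bfloat16.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the real model; parameters are created in float32.
model = nn.Linear(8, 4).to(device)
x = torch.randn(2, 8, device=device)

with torch.autocast(device_type=device, dtype=torch.bfloat16):
    out = model(x)  # the matrix multiplication runs in bfloat16

print(model.weight.dtype)  # dtype of the stored weights
print(out.dtype)           # dtype of the activation produced under autocast
```

Autocast changes the precision in which selected operations run; it does not convert the parameters themselves.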
Model Compilation
We can use torch.compile to compile our model and make it much more efficient. Model compilation costs us some time up front in ‘compilation time’, but we gain a lot in model efficiency, so this tradeoff is usually worthwhile. We can implement this in our code with one extra line, the final one below:
model = GPT(GPTConfig())
model.to(device)
model = torch.compile(model)
Generally the only time this is not worth doing is while debugging the model, when we want each run to start quickly.
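Whether compilation pays off is simple arithmetic: the one-off compile time has to be recovered by the time saved on each step. The sketch below works in milliseconds with made-up timings (all three numbers are assumptions, not measurements from this series), and the function name is hypothetical.

```python
import math


def break_even_steps(compile_ms, eager_step_ms, compiled_step_ms):
    """Number of training steps after which compilation has paid for itself."""
    saved_per_step = eager_step_ms - compiled_step_ms
    if saved_per_step <= 0:
        raise ValueError("the compiled step must be faster than the eager step")
    return math.ceil(compile_ms / saved_per_step)


# Assumed timings: 60 s to compile, 300 ms per eager step, 150 ms per compiled step.
print(break_even_steps(compile_ms=60_000, eager_step_ms=300, compiled_step_ms=150))
```

With these assumed numbers the cost is recovered after a few hundred steps, which is negligible for a long training run but significant for a quick debugging session.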
The reason this is much faster is that it speeds up the model's read/write operations. Without compilation, the Python interpreter greedily moves from one operation to the next in order. This is because Python does not look ahead, whereas the Torch compiler looks at the whole model before running it and can therefore optimise the order of operations.
The first thing the Torch compiler does is remove the Python interpreter from the training loop. It then compiles the entire model into lower-level code that runs without Python involved.
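To see why looking at several operations at once can pay off, it helps to count how often values are read from and written to the GPU's main memory. The sketch below is a deliberately simplified model of a chain of elementwise operations, not a description of what the compiler actually emits; the function name and the sizes (including an embedding width of 768) are assumptions.

```python
def memory_transfers(num_elements, num_ops, fused):
    """Values moved to or from main GPU memory for a chain of elementwise ops."""
    if fused:
        # One kernel: read the input once, write the final result once.
        return 2 * num_elements
    # One kernel per op: each reads its input and writes its output.
    return 2 * num_elements * num_ops


# Assumed activation size: batch 16, sequence length 1024, embedding width 768.
n = 16 * 1024 * 768
unfused = memory_transfers(n, num_ops=4, fused=False)
fused = memory_transfers(n, num_ops=4, fused=True)
print(unfused // fused)  # how many times less memory traffic the fused version needs
```

In this simplified model the arithmetic performed is identical in both cases; only the number of trips to memory changes.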
The main bottleneck comes from the time it takes data to travel between the GPU and its equivalent of RAM, called High Bandwidth Memory (HBM). torch.compile therefore ensures that data is retained on the GPU chip rather than travelling back and forth between the GPU and HBM. It can do this because it sees that there is no need to offload data that is about to be used again in the next operation. This is an example of something called kernel fusion.
Flash Attention
Flash attention was introduced in a 2022 research paper. It is an incredibly efficient implementation of attention that uses kernel fusion to combine the different operations required for attention into a single compute step, which can be performed on the GPU without offloading to HBM. This requires a rewrite of the attention code, so it is not automatically discovered by torch.compile.
Technically speaking this operation actually requires more computation, but it does not require any offloading of the attention matrix to HBM. Thus it can be performed up to 7x faster on a GPU.
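The size of the attention matrix shows why avoiding that offload matters. The sketch below computes it for the batch size and sequence length used earlier (B=16, T=1024); the 12 attention heads and the 2 bytes per value (bfloat16) are assumptions, and the function name is hypothetical.

```python
def attention_matrix_bytes(batch, heads, seq_len, bytes_per_value):
    """Bytes needed to store the full (batch, heads, seq_len, seq_len) attention matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_value


# B=16 and T=1024 as used earlier; 12 heads and 2-byte values are assumptions.
size = attention_matrix_bytes(batch=16, heads=12, seq_len=1024, bytes_per_value=2)
print(size / 2**20, "MiB per layer")
```

A standard implementation writes a matrix of this size to memory and reads it back in every layer, and the cost grows with the square of the sequence length.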
With Flash attention we can remove the following lines from our causal self attention forward pass:
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
att = F.softmax(att, dim=-1)
y = att @ v # (B, nh, T, hs)
And replace with the following code:
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
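To check that the single call computes the same result as the four lines it replaces, the sketch below runs both on small random tensors. The sizes are made up so that it runs quickly on a CPU, and `mask` plays the role of `self.bias` from the model. Note that PyTorch chooses the backend for this call itself; a fused Flash attention kernel is only used on supported GPUs and data types, so this checks equivalence rather than speed.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, nh, T, hs = 2, 4, 8, 16  # small, made-up sizes
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# Manual causal attention, as in the original forward pass.
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)  # stands in for self.bias
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(mask == 0, float("-inf"))
att = F.softmax(att, dim=-1)
y_manual = att @ v  # (B, nh, T, hs)

# Single-call replacement.
y_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

max_diff = (y_manual - y_fused).abs().max().item()
print(f"largest absolute difference: {max_diff:.2e}")
```

Any remaining difference is floating-point rounding from the different order of operations.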
Note that since Flash attention was introduced, Flash attention 2 and 3 have also been released, offering further improvements.
Number Choice - why powers of 2?
Strangely, some number choices can also impact training efficiency. For example batch sizes that are powers of 2, [2, 4, 8, 16, 32, 64, 128], tend to work much better than, say, a batch size of 13 or 17. This is because most things in CUDA work in powers of two, and many kernels are also written around powers of two, particularly 16, 32 and 64.
Therefore we should scan our code for ‘ugly’ numbers and try to optimise them. For example, our vocab size of 50257 is quite an ugly number. We would rather have a number like 50304, which is divisible by 2, 4, 8, 16, 32, 64 and 128, making it highly desirable.
model = GPT(GPTConfig(vocab_size=50304))
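The padded value does not have to be found by hand. A minimal sketch, with `round_up` as a hypothetical helper name, rounds the vocab size up to the next multiple of 128:

```python
def round_up(value, multiple):
    """Smallest multiple of `multiple` that is greater than or equal to `value`."""
    return ((value + multiple - 1) // multiple) * multiple


padded = round_up(50257, 128)
print(padded, padded - 50257)  # padded vocab size and number of unused token IDs
```

The same helper can be applied to any other ‘ugly’ dimension in the model.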
This is a bit like adding fake tokens, and strangely it also adds computation. But when we run this on the GPU we find that it actually improves the runtime by a small 3-4%. When we increase the vocab size, we add a few rows to the embedding matrix. These rows will never be used as inputs, because the tokeniser only produces token IDs up to 50256. But the input embedding matrix is also used by the output layer, so the network will learn that these outputs are never seen, driving their logits towards -inf. This may seem unusual, but when training on some texts you may find that some real tokens are never observed either, so their logits are driven towards -inf in the same way.
In CUDA many kernels use ‘block tiles’ whose sizes are powers of two. When an input does not fit cleanly into these tiles, a second, slower kernel is launched to handle the leftover part, which often slows down the whole process. This is why a nicer number can be faster even though we have to add extra padding.
Conclusion
In this post we have worked on optimising our model and training process to efficiently work on GPUs. We did this using:
- Mixed Precision Data Types
- Model Compilation
- Flash Attention
- Cleaning up ‘Ugly Numbers’
In the next post in this series we will work on optimising the transformer algorithm itself, tune the hyperparameters, and utilise distributed training to speed things up.