Lets build GPT-3: Hyperparameters, Algorithms, Distributed Training  (part 3)

Tags
AI
LLM
NLP
Published
January 15, 2025
Author
Philip Redford

Intro

This article is part 3 in a three-part series on building GPT-2 from scratch. You can find part one here and part two here. In this post we will focus on optimising our model implementation, concentrating on the algorithms and hyperparameters used. We will also explore distributed training to speed up the training process. We will assume you are familiar with the base code, which can be found here.
Note that for this post in particular we will be using Lambda Labs for GPUs.

Algorithm Optimisations

Up until this point we have been following the approach used in GPT-2. GPT-3 is very similar to GPT-2, except that it is trained for much longer and with a larger context window (2048 tokens rather than 1024). GPT-3 was also trained with slightly better-tuned hyperparameters, so we can add some of these optimisations to our model.

Learning Rate Scheduler

GPT-3 uses the Adam optimiser with β1 = 0.9, β2 = 0.95 and ε = 1e-8. In PyTorch, using the AdamW variant, this is:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
They also clip the global norm of the gradient at 1.0, where the global norm is the square root of the sum of the squares of every gradient entry across all parameters. This guards against an unlucky batch with a very high loss producing a very large gradient, and hence an unexpectedly large change in the model parameters.
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
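To make the "global norm" concrete, here is a small check on an arbitrary toy model (not our GPT) that the value returned by `clip_grad_norm_` is the square root of the sum of squares of every gradient element, measured before clipping, and that afterwards the global norm is at most 1.0. The toy model and the inflated loss are assumptions chosen only to produce large gradients.

```python
import torch

torch.manual_seed(0)
# Toy model and a deliberately large loss so the gradients are big
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 1))
x = torch.randn(16, 4)
loss = (model(x) * 100).pow(2).mean()
loss.backward()

# Global norm: sqrt of the sum of squared gradient entries across ALL parameters
manual_norm = torch.sqrt(sum((p.grad ** 2).sum() for p in model.parameters()))

# clip_grad_norm_ returns the total norm measured *before* clipping
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# After clipping, the gradients have been rescaled so the global norm is ~1.0
clipped_norm = torch.sqrt(sum((p.grad ** 2).sum() for p in model.parameters()))
print(f"norm before clipping: {norm.item():.2f}, after: {clipped_norm.item():.4f}")
```

Logging the returned `norm` each step is a cheap way to spot unstable batches.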
GPT-3 does not use a constant learning rate throughout training. Instead, it uses a cosine decay schedule that gradually reduces the learning rate to 10% of its original value over 260 billion tokens. It also uses a linear warmup that gradually increases the learning rate over the first 375M tokens.
[Figure: Learning Rate Schedule]
We can implement this in Python:
```python
import math

max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10  # small values for now; updated later
max_steps = 30

def get_lr(it: int) -> float:
    # 1. Linear warmup for the first warmup_steps steps
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # 2. If fully decayed, return min_lr
    if it >= max_steps:
        return min_lr
    # 3. Otherwise, cosine decay from max_lr down to min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + (max_lr - min_lr) * coeff
```
We then need to update our training loop to include this update before the optimiser step:
```python
for step in range(max_steps):
    ...
    # Determine and set the learning rate for this step
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    optimizer.step()
```
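PyTorch also provides built-in schedulers, which are advanced by calling scheduler.step() once per optimiser step. Note that CosineAnnealingLR on its own has no warmup, and its eta_min is an absolute learning rate rather than a fraction of the starting one:
```python
from torch.optim.lr_scheduler import CosineAnnealingLR

# Decay from the optimizer's initial lr down to min_lr over max_steps steps (no warmup)
scheduler = CosineAnnealingLR(optimizer, T_max=max_steps, eta_min=min_lr)
for step in range(max_steps):
    ...
    optimizer.step()
    scheduler.step()  # call after optimizer.step()
```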
There are many other learning rate schedules and there is no definitive conclusion on what the optimal learning rate schedule should be.
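Because CosineAnnealingLR has no warmup, one option, sketched here under the assumption that you want to stay within torch.optim.lr_scheduler, is LambdaLR: it multiplies the optimiser's initial learning rate by a factor you compute for each step. Here the factor reuses the same warmup plus cosine schedule as get_lr, and a toy Linear layer stands in for the GPT model.

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR

max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 30

def get_lr(it: int) -> float:
    # Same linear warmup + cosine decay as the hand-written schedule
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    if it >= max_steps:
        return min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + (max_lr - min_lr) * coeff

model = torch.nn.Linear(4, 4)  # toy stand-in for the GPT model
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, betas=(0.9, 0.95), eps=1e-8)
# LambdaLR sets lr = initial_lr * factor(step), so divide by max_lr to get a factor
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: get_lr(step) / max_lr)

lrs = []
for step in range(max_steps):
    lrs.append(optimizer.param_groups[0]["lr"])  # lr used for this step
    optimizer.step()  # forward/backward omitted in this sketch
    scheduler.step()  # advance the schedule after the optimizer step
```

The recorded learning rates match get_lr step for step, so either approach gives the same schedule; the manual version is simply more explicit.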

Batch Size Scheduler

GPT-3 also uses batch size scheduling: the batch size is increased linearly from 32k tokens to the full batch size over the first 4-12 billion tokens, depending on model size. This generally offers only a minor improvement and is mainly a systems (training speed) optimisation rather than an algorithmic one.
The intuition is that early gradient updates mainly focus on obvious, easy patterns, such as learning to ignore tokens that are very rare, so a small batch already gives a useful gradient estimate. We will skip this step in this post, but it's worth being aware of.
When loading and batching the data, GPT-3 samples datapoints without replacement (within each epoch) to minimise overfitting.
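Although we skip it here, the ramp itself is simple. The function below is a hypothetical sketch: `batch_size_at` is not part of our code base, and the 4B-token ramp length is an assumption picked from the 4-12B range above. Sizes are rounded down to a multiple of 32k tokens so they still split into whole sequences.

```python
def batch_size_at(tokens_seen: int, full_batch_tokens: int = 2**19,
                  start_batch_tokens: int = 32 * 1024,
                  ramp_tokens: int = 4_000_000_000) -> int:
    """Hypothetical linear batch-size ramp (in tokens) from 32k up to the full batch."""
    if tokens_seen >= ramp_tokens:
        return full_batch_tokens
    frac = tokens_seen / ramp_tokens
    size = start_batch_tokens + frac * (full_batch_tokens - start_batch_tokens)
    # Round down to a multiple of the starting size
    return int(size // start_batch_tokens) * start_batch_tokens

for tokens in (0, 2_000_000_000, 4_000_000_000):
    print(f"{tokens:,} tokens seen -> batch of {batch_size_at(tokens):,} tokens")
```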

Weight Decay

We will use a weight decay of 0.1, as in GPT-3, which adds a form of regularisation to the model. We will add a method to our model that configures the optimiser:
```python
def configure_optimizers(self, weight_decay=0.1, learning_rate=3e-4, device="cpu"):
    # Step 1: get the parameters to optimize (requires_grad=True)
    param_dict = {n: p for n, p in self.named_parameters() if p.requires_grad}
    # Step 2: create optimizer groups. All 2D parameters will be decayed, all others unchanged,
    # i.e. all weight tensors in matmuls + embeddings are decayed; biases and layernorms are not
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    no_decay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    print(
        f"num decayed parameter tensors: {len(decay_params)}, "
        f"with {sum(p.numel() for p in decay_params):,} parameters"
    )
    print(
        f"num non-decayed parameter tensors: {len(no_decay_params)}, "
        f"with {sum(p.numel() for p in no_decay_params):,} parameters"
    )
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8)
    return optimizer
```
We then create the optimiser before starting the training loop:
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=max_lr, device=device)
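To see which tensors the `p.dim() >= 2` rule places in each group, here is the same filtering applied to a small toy module with an embedding, a LayerNorm and a Linear layer (the module is an assumption standing in for the full GPT model):

```python
import torch.nn as nn

# Toy module standing in for the GPT model
toy = nn.ModuleDict({
    "wte": nn.Embedding(10, 8),
    "ln": nn.LayerNorm(8),
    "fc": nn.Linear(8, 8),
})

params = {n: p for n, p in toy.named_parameters() if p.requires_grad}
decay = [n for n, p in params.items() if p.dim() >= 2]     # weight matrices + embeddings
no_decay = [n for n, p in params.items() if p.dim() < 2]   # biases + LayerNorm gains/biases
print("decay:", decay)
print("no decay:", no_decay)
```

Only the embedding table and the Linear weight matrix are decayed; the 1D biases and LayerNorm parameters are left alone.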

FusedAdamW

We can also use a faster version of AdamW, available in PyTorch, that relies on kernel fusion. We do this by replacing the optimiser construction at the end of the method above with the following lines:
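```python
import inspect  # at the top of the file

# Fused AdamW optimizer: only use it if this PyTorch version supports it and we are on a CUDA device
fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device.startswith("cuda")  # also matches "cuda:0", "cuda:1", ... under DDP
logger.info(f"using fused AdamW: {use_fused}")
optimizer = torch.optim.AdamW(
    optimizer_grouped_parameters, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused
)
```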
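Fused AdamW is generally faster than the standard implementation on a CUDA GPU. Rather than looping over the parameters and launching several small kernels for each parameter update, the fused version performs the AdamW update for many parameters in a small number of CUDA kernel launches.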

Gradient Accumulation

GPT-3 (Small/Medium/Large) was trained with a batch size of 0.5M tokens. With our sequence length of 1024, this corresponds to a batch of about 488 sequences. This is far too big to fit on a single GPU, but it is important to keep this batch size because it was tuned alongside all the other hyperparameters.
To 'simulate' this batch size we use gradient accumulation: we process micro-batches in series and add up their gradients until we have covered the desired batch size, then do a single parameter update. The result is equivalent to one large batch, so each update is averaged over a much wider range of sequences and is less noisy.
```python
# Gradient accumulation
total_batch_size = 524288  # 2**19 ~ 0.5M tokens, a "nice" power of two
B = 8   # micro batch size (real is 16)
T = 32  # sequence length (real is 1024)
assert total_batch_size % (B * T) == 0, "Total batch size must be divisible by B * T"
gradient_accumulation_steps = total_batch_size // (B * T)
logger.info(f"Total batch size: {total_batch_size}")
logger.info(f"Gradient accumulation steps: {gradient_accumulation_steps}")
```
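We can implement this in our training loop. We need to adjust the loss, because our loss function returns the mean loss over each micro-batch. Since loss.backward() adds gradients together across calls, summing the gradients of G micro-batch means would give G times the gradient of the full-batch mean, so we divide each micro-batch loss by G = gradient_accumulation_steps: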
```python
for step in range(max_steps):
    t0 = time.time()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_batch in range(gradient_accumulation_steps):
        x, y = train_loader.next_batch()
        x = x.to(device)
        y = y.to(device)
        # Use autocast to run the forward pass in bfloat16 where it is safe to do so
        with torch.autocast(device, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # Scale the loss by the number of micro batches so the summed gradient is the full-batch mean
        loss = loss / gradient_accumulation_steps
        loss_accum += loss.detach()  # for logging
        # Gradients accumulate in .grad across the backward calls of the micro batches
        loss.backward()
    ...
    # Outside the micro batch loop, so all accumulated gradients are applied in one update
    optimizer.step()
```
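As a quick sanity check of the loss scaling, this sketch compares the gradient of a full batch of 8 with the gradient accumulated over 4 micro-batches of 2, each loss divided by 4. The toy Linear model and MSE loss are assumptions standing in for GPT and cross-entropy; both use mean reduction.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()  # mean reduction, like the cross-entropy in our model

# 1) Gradient of the mean loss over the full batch of 8
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# 2) Same data as 4 micro-batches of 2, each loss divided by the number of micro-batches
grad_accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(grad_accum_steps), y.chunk(grad_accum_steps)):
    loss = loss_fn(model(xb), yb) / grad_accum_steps
    loss.backward()  # .grad accumulates (sums) across backward calls
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))
```

Without the division by grad_accum_steps, the accumulated gradient would be four times too large.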

Distributed Training

To scale our training and run our loop faster, we can spread the work across multiple GPUs. There are two main approaches:
  • Distributed Model Parallel: the same model is split across multiple GPUs. This is useful for training large models that will not fit on a single GPU.
  • Distributed Data Parallel: a copy of the model runs on each GPU, each processing different data. This is useful for speeding up model training.
Given our model comfortably fits in GPU memory, we will focus on the latter.

Distributed Data Parallel (DDP)

We can check how many GPUs we have available with the nvidia-smi command. This will give us something like below (depending on your GPU configuration).
[Figure: nvidia-smi output on a single-GPU machine]
This is a configuration with a single GPU. With multiple GPUs you will see something like below:
[Figure: nvidia-smi output on a multi-GPU machine]
To use data parallelism in PyTorch we can use DistributedDataParallel (DDP), with one process launched per GPU by torchrun. Essentially, a copy of the training script runs on each GPU, and during the backward pass the gradients from all processes are averaged so that every copy of the model applies the same update. This update needs to be synchronised across all GPUs.
```python
import os

import torch
from torch.distributed import destroy_process_group, init_process_group

# Initialize the distributed process group.
# The torchrun command sets the env variables RANK, LOCAL_RANK and WORLD_SIZE
ddp = int(os.environ.get("RANK", -1)) != -1  # is this a DDP run?
if ddp:
    assert torch.cuda.is_available(), "DDP is only supported on CUDA"
    init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(ddp_local_rank)
    master_process = ddp_rank == 0  # this process does the logging, checkpointing etc.
else:
    # Vanilla, non-DDP run
    ddp_rank = 0
    ddp_world_size = 1
    ddp_local_rank = 0
    master_process = True
    # Auto-detect the device
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
# torch.autocast expects a device type ("cuda"), not a specific device ("cuda:1")
device_type = "cuda" if device.startswith("cuda") else "cpu"
logger.info(f"Using device: {device}")
# ... and at the end of training: if ddp: destroy_process_group()
```
In the above code, the torchrun command sets the relevant environment variables, which are then read by the script. If we have 8 GPUs, the world size will be 8. RANK is the global index of a process (0 for the process on the first GPU), while LOCAL_RANK is the index of the GPU on the current machine; on a single machine the two are the same. Logging is done only on the first process, the master process.
We also need to adjust our data loading and gradient accumulation:
```python
total_batch_size = 524288  # 2**19 ~ 0.5M tokens, a "nice" power of two
B = 8   # micro batch size (real is 16)
T = 32  # sequence length (real is 1024)
assert total_batch_size % (B * T * ddp_world_size) == 0, (
    "Total batch size must be divisible by B * T * ddp_world_size"
)
# Each micro step now processes B * T tokens on every GPU at once
gradient_accumulation_steps = total_batch_size // (B * T * ddp_world_size)
```
We also want to update our logging so only the master process logs to stdout. We do this using our master_process flag:
```python
if master_process:
    logger.info(f"Total batch size: {total_batch_size}")
    logger.info(f"Gradient accumulation steps: {gradient_accumulation_steps}")
```
We also need to update our data loader slightly so that each process reads a different chunk of the data:
```python
class DataLoaderLite:
    def __init__(self, B: int, T: int, process_rank: int, num_processes: int):
        ...
        # Each process starts at a different position in the tokens
        self.current_pos = self.B * self.T * process_rank

    def next_batch(self):
        ...
        # Skip over the chunks read by the other processes
        self.current_pos += B * T * self.num_processes
        # If we've reached the end of the tokens, reset the position
        if self.current_pos + (B * T * self.num_processes + 1) >= len(self.tokens):
            self.current_pos = self.B * self.T * self.process_rank
        return x, y
```
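To see why the ranks never read the same tokens, this sketch reproduces only the position arithmetic of DataLoaderLite for two processes. The helper `start_positions` is hypothetical and exists just for illustration:

```python
def start_positions(B: int, T: int, process_rank: int, num_processes: int, n_batches: int):
    """Hypothetical helper: token offsets one rank reads from, using DataLoaderLite's strides."""
    pos = B * T * process_rank
    out = []
    for _ in range(n_batches):
        out.append(pos)
        pos += B * T * num_processes  # jump over the other ranks' chunks
    return out

for rank in range(2):
    print(rank, start_positions(B=4, T=8, process_rank=rank, num_processes=2, n_batches=3))
# 0 [0, 64, 128]
# 1 [32, 96, 160]
```

With B * T = 32 tokens per chunk, rank 0 reads the even chunks and rank 1 the odd ones, so together they cover the data once without overlap.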
We then need to wrap our model in a PyTorch DDP model. In the forward pass, almost nothing changes. In the backward pass, each process shares its gradients using all_reduce, which averages the gradients and deposits the average on every process, so they all apply the same update. DDP can also overlap this communication with the backward pass: gradients for layers that have already been computed are shared while the backward pass is still propagating through the earlier layers.
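```python
from torch.nn.parallel import DistributedDataParallel as DDP

if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
# Keep a handle on the unwrapped model (used later, e.g. for checkpointing)
raw_model = model.module if ddp else model
```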
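The averaging step can be simulated in a single process to show why it works: with equal-sized shards, the mean of the per-rank gradients equals the gradient a single process would compute on the whole batch. This sketch does not start a real process group; the list of per-"rank" gradients and the toy Linear model are assumptions standing in for what all_reduce with ReduceOp.AVG operates on.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()

# Gradient a single process would compute on the whole batch of 8
model.zero_grad()
loss_fn(model(x), y).backward()
single_process_grad = model.weight.grad.clone()

# Simulate 2 DDP ranks: each computes a gradient on its own half of the batch...
world_size = 2
rank_grads = []
for xb, yb in zip(x.chunk(world_size), y.chunk(world_size)):
    model.zero_grad()
    loss_fn(model(xb), yb).backward()
    rank_grads.append(model.weight.grad.clone())

# ...and an AVG all-reduce leaves every rank holding the mean gradient
averaged_grad = torch.stack(rank_grads).mean(dim=0)
print(torch.allclose(single_process_grad, averaged_grad, atol=1e-6))
```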
We also need to update our training loop. Because we are using gradient accumulation, we do not want to synchronise the gradients after every micro step, as this would be wasteful. Instead, we synchronise only on the final micro step. There are a few ways to achieve this, for example DDP's no_sync() context manager, but the simplest is to set the model's require_backward_grad_sync attribute to False except on the final micro step.
```python
...
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    logits, loss = model(x, y)
loss = loss / gradient_accumulation_steps
if ddp:
    # Only all-reduce the gradients on the last micro step
    model.require_backward_grad_sync = micro_batch == gradient_accumulation_steps - 1
loss.backward()
...
```
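For comparison, here is a sketch of the no_sync() alternative mentioned above. DDP models provide a no_sync() context manager that suppresses the gradient all-reduce inside the block; plain modules do not have it, so the sketch falls back to a no-op context when not running under DDP. The helper name `accumulate_gradients` and the MSE loss are placeholders, not part of our training script.

```python
import contextlib

import torch
import torch.nn.functional as F

def accumulate_gradients(model, micro_batches, ddp: bool) -> float:
    """Hypothetical helper: backward over every micro-batch, syncing DDP only on the last one."""
    n = len(micro_batches)
    loss_accum = 0.0
    for i, (x, y) in enumerate(micro_batches):
        is_last = i == n - 1
        # no_sync() exists only on DDP-wrapped models; use a no-op context otherwise
        ctx = model.no_sync() if (ddp and not is_last) else contextlib.nullcontext()
        with ctx:
            loss = F.mse_loss(model(x), y) / n
            loss.backward()
        loss_accum += loss.item()
    return loss_accum
```

Note that the backward call has to sit inside the no_sync() block for the all-reduce to be skipped.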
When logging, we want to report the loss averaged across all processes rather than the loss on just the master process. We can do this as follows:
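```python
import torch.distributed as dist

if ddp:
    # Average loss_accum across all processes (ReduceOp.AVG requires the NCCL backend)
    dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
```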
When we launch our training script, we no longer use a simple python train_gpt.py call. Instead we use torchrun:
torchrun --standalone --nproc_per_node={n_GPU} train_gpt.py

Data

Datasets used (FineWeb EDU)

GPT-3 uses a filtered Common Crawl dataset (60%), which is a mostly random subset of internet data. Because the raw data is mostly low quality, a lot of work goes into cleaning and filtering it to keep the high-quality data. Additionally, WebText2, data derived from highly ranked Reddit outbound links, is used (22%). Two books datasets are used (8% + 8%), and the final 3% comes from Wikipedia. Unfortunately, OpenAI never released the datasets used for training.
RedPajama and SlimPajama are open datasets that are similar to those described above.
A great modern alternative is the FineWeb-Edu dataset released by HuggingFace, which contains about 1.3 trillion high-quality tokens filtered for educational content. You can find more information here. For now we will use a 10B-token sample of this dataset.
You can download the dataset using the HuggingFace datasets library as below. The full script for loading and processing the dataset can be found here. It tokenises each document in the dataset and writes the tokens to files in shards of shard_size tokens.
```python
from datasets import load_dataset

remote_name = "sample-10BT"  # the 10B-token sample of FineWeb-Edu
fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")
```
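The sharding step itself can be sketched independently of the tokeniser. The function below is a hypothetical illustration, not the linked script: it assumes each document has already been turned into a list of GPT-2 token ids, packs them into fixed-size uint16 shards (GPT-2's 50,257 token ids fit in 16 bits), labels the first shard "val" and the rest "train", and uses an assumed file-naming scheme that contains the split name so a loader can filter on it.

```python
import os

import numpy as np

def write_shards(docs_tokens, shard_size: int, out_dir: str, prefix: str = "edufineweb"):
    """Hypothetical sharder: pack token lists into fixed-size uint16 .npy shards."""
    os.makedirs(out_dir, exist_ok=True)
    buf = np.empty(shard_size, dtype=np.uint16)
    filled, shard_index, paths = 0, 0, []

    def flush(n_tokens: int):
        nonlocal shard_index
        split = "val" if shard_index == 0 else "train"  # first shard is the validation split
        path = os.path.join(out_dir, f"{prefix}_{split}_{shard_index:06d}.npy")
        np.save(path, buf[:n_tokens])
        paths.append(path)
        shard_index += 1

    for tokens in docs_tokens:
        tokens = np.asarray(tokens, dtype=np.uint16)
        while len(tokens) > 0:
            take = min(shard_size - filled, len(tokens))
            buf[filled:filled + take] = tokens[:take]
            filled += take
            tokens = tokens[take:]
            if filled == shard_size:  # shard full: write it and start a new one
                flush(shard_size)
                filled = 0
    if filled > 0:  # leftover tokens become a final, shorter shard
        flush(filled)
    return paths
```

Documents are split across shard boundaries here, which is harmless because training reads the token stream contiguously anyway.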

Cleanup

With these changes we also need to update our main training script slightly:
```python
import os

import numpy as np
import torch

def load_tokens(file_path: str) -> torch.Tensor:
    npt = np.load(file_path)
    npt = npt.astype(np.int32)  # shards are stored as uint16; convert before creating the tensor
    ptt = torch.tensor(npt, dtype=torch.long)
    return ptt

class DataLoaderLite:
    def __init__(self, B: int, T: int, process_rank: int, num_processes: int, split: str):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {"train", "val"}

        # Get the shard filenames for this split
        data_root = "edu_fineweb10B"
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(data_root, shard) for shard in shards]
        self.shards = shards
        assert len(self.shards) > 0, f"No shards found for split: {split}"
        if master_process:
            logger.info(f"Loaded {len(self.shards)} shards for split: {split}")
        self.reset()

    def reset(self):
        # Start from the first shard; each process starts at a different position in the tokens
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_pos = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_pos : self.current_pos + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)  # targets (inputs shifted by one token)
        self.current_pos += B * T * self.num_processes
        # If we've reached the end of the current shard, move to the next shard
        if self.current_pos + (B * T * self.num_processes + 1) >= len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_pos = self.B * self.T * self.process_rank
        return x, y
```
The main change is that we now load from the shards created in the previous step and, when we reach the end of a shard, move on to the next one rather than looping over the same data.
We will also update our parameters:
```python
warmup_steps = 715
max_steps = 19073
```
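These values appear to come from dividing GPT-3's 375M warmup tokens and our ~10B-token dataset by the 2**19-token batch size, so that training covers roughly one epoch of the sample:

```python
total_batch_size = 2**19          # 524,288 tokens per optimizer step
warmup_tokens = 375_000_000       # GPT-3 warms up over the first 375M tokens
dataset_tokens = 10_000_000_000   # the ~10B-token FineWeb-Edu sample

warmup_steps = warmup_tokens // total_batch_size  # 715
max_steps = dataset_tokens // total_batch_size    # 19073, roughly one epoch
print(warmup_steps, max_steps)
```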

Train/Validation Split

After our setup in the previous step it is simple to load the val dataset alongside the training set. This validation dataset is about 10% of the full data.
```python
train_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train")
val_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val")
```
We can then add an evaluation step to our training loop:
```python
for step in range(max_steps):
    t0 = time.time()
    last_step = step == max_steps - 1

    # Validate every 100 steps
    if step % 100 == 0:
        model.eval()
        val_loader.reset()
        with torch.no_grad():
            val_loss_accum = 0.0
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = val_loader.next_batch()
                x = x.to(device)
                y = y.to(device)
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(x, y)
                # Average the loss over the validation batches
                loss = loss / val_loss_steps
                val_loss_accum += loss.detach()
        if ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if master_process:
            logger.info(f"Validation loss: {val_loss_accum.item():.4f}")

    # Training step: switch back to training mode first
    model.train()
    ...
```
Optionally, we can also add a text generation step to sample outputs and qualitatively assess the results. However, this was causing problems with torch.compile, so for now the step is skipped during training when the model is compiled. You can see this in the code base here.

Evaluation

HellaSwag

HellaSwag is a useful benchmark for evaluating the model during training and spotting problems early. Each HellaSwag example provides a context and four candidate endings, and the model is scored on whether it assigns the highest likelihood to the correct ending. For now we will skip this step but may return to it in a later post.
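One common way to pick an ending, sketched here as an assumption rather than our implementation, is to compute the model's average cross-entropy over the tokens of each candidate ending and choose the lowest. The sketch assumes the logits for all four (context + ending) sequences have already been computed and padded to the same length, with a mask marking the ending tokens; `pick_ending` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def pick_ending(logits: torch.Tensor, tokens: torch.Tensor, mask: torch.Tensor) -> int:
    """Hypothetical helper: index of the candidate ending with the lowest average loss.

    logits: (num_candidates, T, vocab) model outputs for context + ending
    tokens: (num_candidates, T) token ids for context + ending
    mask:   (num_candidates, T) 1 where the token belongs to the ending, else 0
    """
    # Position t predicts token t + 1
    shift_logits = logits[:, :-1, :]
    shift_tokens = tokens[:, 1:]
    shift_mask = mask[:, 1:].float()
    losses = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_tokens.reshape(-1),
        reduction="none",
    ).view(shift_tokens.shape)
    # Average the loss over the ending tokens only
    avg_loss = (losses * shift_mask).sum(dim=1) / shift_mask.sum(dim=1)
    return int(avg_loss.argmin())
```

Averaging rather than summing the loss avoids penalising longer endings simply for having more tokens.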

Training Run

We are now almost ready to run our full training loop for real. For tracking purposes we will add some improved logging:
```python
import os

log_dir = "log"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"log_{ddp_rank}.txt")
with open(log_file, "w") as f:  # open in "w" mode to clear the file
    pass
```
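Then we can log our training and validation loss:
```python
# In the training step, after the optimizer update
with open(log_file, "a") as f:
    f.write(f"{step} train {loss_accum.item():.6f}\n")

# In the validation block
with open(log_file, "a") as f:
    f.write(f"{step} val {val_loss_accum.item():.4f}\n")
```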
We can also checkpoint our model during training so we can recover our best model for inference.
```python
if master_process:
    # Save a checkpoint every 5000 steps and at the end of training
    if step > 0 and (step % 5000 == 0 or last_step):
        checkpoint_path = os.path.join(log_dir, f"checkpoint_{step:05d}.pth")
        checkpoint = {
            "model": raw_model.state_dict(),
            "config": raw_model.config,
            "step": step,
            "val_loss": val_loss_accum,
            "optimizer": optimizer.state_dict(),
        }
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Saved checkpoint to {checkpoint_path}")
```
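Recovering a model from such a checkpoint is the reverse operation. In this sketch a toy Linear layer stands in for the GPT model; with the real model you would rebuild it from checkpoint["config"] before loading the weights. Because our checkpoint stores a config object, not just tensors, recent PyTorch versions need weights_only=False when loading, which should only be used for files you trust.

```python
import os
import tempfile

import torch

# Toy model and optimizer standing in for raw_model and the AdamW optimizer
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
checkpoint_path = os.path.join(tempfile.mkdtemp(), "checkpoint_05000.pth")
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 5000}, checkpoint_path)

# Later, e.g. for inference or to resume training
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
restored = torch.nn.Linear(4, 4)  # real code: rebuild the model from checkpoint["config"]
restored.load_state_dict(checkpoint["model"])
restored_optimizer = torch.optim.AdamW(restored.parameters(), lr=6e-4)
restored_optimizer.load_state_dict(checkpoint["optimizer"])  # only needed to resume training
restored.eval()  # inference mode
```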
If you're looking to save time and cost on tokenising and creating dataset shards, you can download pre-tokenised, sharded files from this HuggingFace Hub repo created by jfzhang. To use them, we just need to add the following lines to our code; everything else stays the same:
```python
from huggingface_hub import snapshot_download

repo_id = "jfzhang/edu_fineweb10B_tokens_npy_files"
local_dir = "./edu_fineweb10B/"
snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=local_dir)
```

Launch Training

We can now launch our training, in my case with 2 GPUs and a training script called train_gpt2.py:
torchrun --standalone --nproc_per_node=2 train_gpt2.py
If you are downloading the pre-tokenised data as described above, the download will take a couple of minutes; tokenising from scratch could take up to 30 minutes. Your training run should log results at each step, as below:
[Figure: per-step training log output]

Results

For my run, to keep costs down I ran with:
  • Total Batch Size = 262144
  • Micro Batch Size = 32
  • Sequence Length = 1024
  • Warmup Steps = 250
  • Max Steps = 5000
  • Max Learning Rate = 6e-4
  • Min Learning Rate = Max Learning Rate * 0.1
```python
total_batch_size = 262144  # 2**18 used for nice number, real GPT is 2**19
B = 32  # micro batch size (real is 16/32/64 - depends on GPU size)
T = 1024  # sequence length
# Learning rate and optimizer parameters
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 250  # real GPT is 715
max_steps = 5000  # real GPT is 19073
```
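A little arithmetic shows what these settings imply: how many micro steps each optimizer step takes on one or two GPUs, and how many tokens the run sees in total (about 13% of the 10B-token sample):

```python
total_batch_size = 262144  # 2**18 tokens per optimizer step
B, T = 32, 1024            # micro batch size and sequence length
max_steps = 5000

for world_size in (1, 2):
    grad_accum_steps = total_batch_size // (B * T * world_size)
    print(f"{world_size} GPU(s): {grad_accum_steps} micro steps per optimizer step")

tokens_seen = max_steps * total_batch_size
print(f"tokens seen: {tokens_seen:,}")  # tokens seen: 1,310,720,000
```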
Below you can find the training curve, along with the validation loss compared with the fully trained GPT-2 model.
[Figure: training curve and validation loss compared with GPT-2]
If trained for longer, with a larger micro batch and total batch size, this model would most likely outperform the GPT-2 benchmark on this dataset.

Summary

In this series we have built an almost exact copy of GPT-2 from scratch. We have also almost built a GPT-3 clone, as GPT-3 is very similar apart from the training data used and the length of training. In fact, on unseen evaluations such as HellaSwag, this model actually outperforms GPT-3.
To convert our model into a ChatGPT-style chatbot, we would need to fine-tune it on a labelled question-answer dataset. This is different from the dataset we used here, which is essentially a text-completion dataset rather than a text-response one. We could do this with a supervised approach, where the target next tokens are the response the model is expected to generate. We could also fine-tune with Reinforcement Learning from Human Feedback (RLHF), which uses a reward model trained on human preference data to score the model's responses and then updates the agent policy (the network weights) to maximise that score. This is outside the scope of this series, but we may explore it in future articles.