Intro to Distributed LLM Training, Part 1: Orchestration & Fault Tolerance

Mar 29, 2024

Gradient Team

Training LLMs at scale is complex and can take thousands or even millions of GPU hours, putting real strain on your resources. Take a look at how Gradient thinks about infrastructure and efficiency optimizations as we dive into our proprietary distributed training platform, focusing on orchestration and fault tolerance.

Quick Primer

While OpenAI’s GPT-4 and Mistral’s Mixtral-8x7B have been widely adopted for AI use cases, training custom Large Language Models (LLMs) enables enterprises to tailor a model to domain-specific needs. Generalist LLMs are great for prototypes and general capabilities, but we find they underperform on industry-specific tasks in production settings. Using the finance industry as an example, current models are often suboptimal at tabular understanding, tasks requiring deep domain knowledge, and other specialized workloads.

At Gradient, we’ve trained a suite of domain-specific models for healthcare and financial services. (Read the research innovations from a previous version of our finance LLM in our white paper here.) Within our suite of finance models specifically, we further pretrained Llama-2-70b with 1.1B tokens of curated and preprocessed financial data, requiring distributed training across multiple nodes. As we continuously train more large-scale models, our in-house distributed training platform plays a huge role.

In this blog post series, we dive deeper into the infrastructure and efficiency optimizations of our own proprietary distributed training platform.

Distributed Training

Powerful LLMs typically consist of billions of parameters trained over huge datasets, which can take thousands or millions of GPU hours. (Llama-2-70b took over 1.7M GPU hours to train!) LLMs require a ton of GPU vRAM to load the model weights – a 70B model in BF16 would require close to 140GB of memory. On top of that, additional memory is then needed for the optimizer states and gradients. Training over a large dataset also requires significant amounts of GPU compute power (FLOPs) to process.
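As a rough back-of-the-envelope illustration (a minimal sketch: the 16-bytes-per-parameter figure assumes standard mixed-precision training with Adam, and the 8x80GB node is our assumption for comparison):

```python
# Back-of-the-envelope memory estimate for a 70B-parameter model.
params = 70e9

# BF16 weights: 2 bytes per parameter.
weights_gb = params * 2 / 1e9
print(f"weights only: {weights_gb:.0f} GB")  # ~140 GB

# Typical mixed-precision training with Adam adds BF16 gradients (2 bytes),
# FP32 master weights (4 bytes), and two FP32 optimizer moments (8 bytes).
train_state_gb = params * (2 + 2 + 4 + 8) / 1e9
print(f"weights + gradients + optimizer states: {train_state_gb:.0f} GB")  # ~1,120 GB

# For comparison, a single node with 8x 80GB GPUs offers 640 GB of VRAM,
# before even accounting for activations, so sharding across many GPUs is required.
print(f"one 8x80GB node: {8 * 80} GB")
```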

Distributed training enables the user to split workloads across multiple GPUs to drastically accelerate the training process. Typically each server contains 8 GPUs. One could simply distribute training across the 8 GPUs in one server – multi-GPU, single-node training. However, training 70B+ parameter models on billions of tokens of data requires thousands of GPU hours, so multi-node training, across multiple servers, is needed.

Orchestration

The first step in setting up our in-house distributed training platform was to simplify the orchestration setup between nodes. Axolotl, a popular open-source training framework, supports distributed training via its integrations with various distributed training libraries, such as HuggingFace Accelerate and DeepSpeed. However, tuning parameters across the different libraries is tricky since many can interfere with each other in unexpected ways during training (e.g. resulting in NaN weights), and there are dozens of parameters to configure.

Through deep dives into the open-source code (and a few painful training errors), our engineering team mapped out the sets of parameters to expose in our training platform for each kind of distributed parallelism (data, model, tensor, or pipeline parallelism). Now configuring distributed training is guesswork-free for our researchers.
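For illustration only, the exposed surface ends up looking like a small, validated set of parallelism knobs. The names below are hypothetical stand-ins, not our platform's actual API:

```python
from dataclasses import dataclass

@dataclass
class ParallelismConfig:
    """Hypothetical example of the handful of knobs a researcher sets;
    library-specific flags are derived from these internally."""
    data_parallel_degree: int = 1       # replicas of the model across GPUs/nodes
    tensor_parallel_degree: int = 1     # split individual layers across GPUs
    pipeline_parallel_degree: int = 1   # split stacks of layers across GPUs
    gradient_accumulation_steps: int = 1

    def world_size(self) -> int:
        # Total number of GPUs implied by the chosen parallelism layout.
        return (self.data_parallel_degree
                * self.tensor_parallel_degree
                * self.pipeline_parallel_degree)

# e.g. 2 nodes x 8 GPUs = 16 GPUs: 4-way tensor parallel, 4-way data parallel
cfg = ParallelismConfig(data_parallel_degree=4, tensor_parallel_degree=4)
assert cfg.world_size() == 16
```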

Distributed Orchestration Refresher

A quick refresher on how distributed training is set up: distributing the work across multiple nodes means the nodes need to be able to communicate with each other. With PyTorch, a “master node” is established to coordinate the distributed operations (a minimal code sketch follows the list below).

  • One node is spun up first as the “master node”, with its URL provided to all the other nodes.

  • Each node runs a Kubernetes (k8s) job that executes its own training process.

  • All nodes communicate with the master node to synchronize distributed training values (e.g. optimizer states, gradients, etc) and gather system-level information. Direct peer-to-peer communication is also leveraged.
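Putting the refresher together, here is a minimal PyTorch sketch of what each node's training process does at startup. It assumes only the standard torch.distributed environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, LOCAL_RANK), which are typically populated by the launcher or the k8s job spec:

```python
import os

import torch
import torch.distributed as dist

def init_distributed():
    # MASTER_ADDR/MASTER_PORT point every node at the master node;
    # RANK and WORLD_SIZE identify this process within the whole job.
    dist.init_process_group(
        backend="nccl",        # GPU-to-GPU communication (collectives + peer-to-peer)
        init_method="env://",  # read MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE from env
    )
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    return dist.get_rank(), dist.get_world_size()

if __name__ == "__main__":
    rank, world_size = init_distributed()
    print(f"rank {rank}/{world_size} ready")
    # ... wrap the model (DistributedDataParallel, DeepSpeed, etc.) and train ...
    dist.destroy_process_group()
```

Launched with a tool like torchrun (e.g. `torchrun --nnodes=4 --nproc_per_node=8 --node_rank=<i> --master_addr=<master node URL> train.py` on each node), every process joins the same process group, rendezvousing through the master node and exchanging gradients and optimizer states over NCCL.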

Fault Tolerance

Large distributed training jobs often last days or weeks across dozens to hundreds of nodes. Node failure can be extremely costly if it interrupts the job and results in progress being lost.

The underlying hardware itself needs to be reliable, but hardware is not the only risk: many companies resort to spot GPU instances from hyperscalers, either due to cost or GPU quota limitations. Spot instances are typically much cheaper and easier to acquire, but access can be revoked without warning (and any processes on the instance terminated) if another user bids a higher price and wins the capacity.

Our platform leverages two techniques together to ensure distributed training jobs are fault tolerant: checkpointing and automated recovery.

Checkpointing

Checkpoints are intermediate states (such as model weights and the number of steps or epochs run) saved during model training. Axolotl provides the ability to log regular checkpoints to Weights & Biases (wandb) throughout the training process, which then enables the user to load the intermediate model weights as a starting point to resume training.
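As a simplified sketch of that resume path (the artifact name below is hypothetical and depends on how the run logs checkpoints to wandb):

```python
import wandb

# Pull the most recent checkpoint artifact logged during training.
api = wandb.Api()
artifact = api.artifact("my-team/finance-llm/model-checkpoint:latest")  # hypothetical name
checkpoint_dir = artifact.download()

# The downloaded directory holds the intermediate weights and trainer state
# (steps/epochs completed); training resumes by pointing the trainer at it,
# e.g. via a resume-from-checkpoint setting.
print(f"resuming from {checkpoint_dir}")
```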

Automated recovery

While Axolotl makes it easy to configure and save regular checkpoints, recovering after a node failure still requires monitoring the job and resuming training from the checkpointed model.

Gradient’s in-house distributed training platform optimizes the recovery process in a few different ways.

  • An independent job is kicked off in tandem with the training itself. This job automates the monitoring and graceful resumption process to minimize training downtime. If training has failed, the job will automatically load the most recent checkpoint from wandb and restart the training (a simplified sketch follows after this list).

  • The tokenized training data is cached at the beginning for faster startup times.

  • Use of high-performance mounted volumes for checkpoints and tokenized data reduces data transfer time, which can be non-trivial when dealing with billions of tokens of training data.
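A simplified sketch of that watchdog loop, with hypothetical placeholder helpers rather than our actual platform code:

```python
import time

POLL_INTERVAL_S = 60

def training_job_failed() -> bool:
    """Hypothetical: query the k8s API for the training job's status."""
    ...

def latest_wandb_checkpoint() -> str:
    """Hypothetical: locate the most recent checkpoint artifact in wandb."""
    ...

def relaunch_training(resume_from: str) -> None:
    """Hypothetical: re-submit the k8s training job, resuming from a checkpoint."""
    ...

def watchdog() -> None:
    # Runs alongside the training job; on failure, resume from the latest checkpoint.
    while True:
        if training_job_failed():
            relaunch_training(resume_from=latest_wandb_checkpoint())
        time.sleep(POLL_INTERVAL_S)
```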

Conclusion

With orchestration and fault tolerance handled by our in-house distributed training platform, the main operational aspects of multi-node training require zero extra effort for each additional model our research team experiments with. These implementations are also agnostic to model architecture, enabling us to train large-scale models on any underlying architecture, from transformers to state space models to MoEs.

Coming up in Part 2 of this series, we’ll bring you on our journey into the efficiency optimizations we explored and implemented.