Foundation models are large deep learning models trained on vast quantities of data. They can be fine-tuned to perform a variety of downstream tasks and form the backbone of many AI applications. The most prominent category is large language models (LLMs), including auto-regressive models such as GPT variants trained to complete natural text. LLMs typically contain billions of parameters, so they rarely fit on a single accelerator and require model parallelism techniques. Another category is diffusion models, notably Stable Diffusion, which have pushed AI image generation to an unprecedented milestone where remarkable visuals can be generated from a simple text description. Diffusion models are typically much smaller than LLMs, but distributed training still plays a critical role in facilitating their development.
The SageMaker model parallel (SMP) library is a large-model training solution available on the Amazon SageMaker platform. It can be integrated with PyTorch models to easily apply a range of state-of-the-art large-model distributed training techniques to train at scale. Earlier this year, SMP launched sharded data parallelism, a distributed training technique powered by Amazon in-house MiCS technology under the hood. Sharded data parallelism shards model parameters, gradients, and optimizer states across data-parallel workers. MiCS performs a number of optimizations, including scale-aware partitioning, to provide near-linear scalability. In Train gigantic models with near-linear scaling using sharded data parallelism, we shared that sharded data parallelism in SMP achieved a 39.7% speedup compared to DeepSpeed ZeRO-3 on a 30B-parameter GPT-2 model with sequence length 2048.
To help our customers further minimize training costs and accelerate time-to-market, we are thrilled to introduce two new performance improvements in SageMaker model parallel: SMDDP Collectives and FlashAttention. SMDDP Collectives is the most performant collective library on AWS infrastructure for large model training, offered by the SageMaker distributed data parallel library. FlashAttention, introduced in Dao et al., re-implements the attention mechanism in an IO-aware manner, reducing the memory bandwidth requirement, speeding up attention, and shrinking its memory footprint. Together, these two components make our sharded data parallel technique 30.58% faster when training a 100B-parameter GPT-NeoX model on 32 p4d.24xlarge instances. For customers who are already using sharded data parallel on supported models, no code changes are necessary to benefit from the performance boost offered by these latest features. Stability AI, the inventor of the Stable Diffusion family of models that showed unparalleled image generation abilities, chose to use SMP to build foundation models. With SMP, Stability AI achieved 163 TFLOPs per GPU for a 13B-parameter GPT-NeoX on 32 p4d.24xlarge instances, a 58% speedup compared to DeepSpeed. You can learn more about Stability AI’s mission and partnership with AWS in the talk of Stability AI CEO at AWS re:Invent 2022 or in this blog post.
“Our mission at Stability AI is to build the foundation to activate humanity’s potential through AI. To achieve this mission, we need to efficiently train open-source foundation models on hundreds of accelerated compute instances. We rely on SageMaker and its distributed training libraries to optimize performance and implement state-of-the-art strategies to shard models and data across our training cluster. These optimizations reduce our training costs, help us meet customer needs faster, and speed up the development of new models.”
— Emad Mostaque, Founder and CEO of Stability AI.
In this blog post, we’ll first present our latest performance improvements in the SageMaker model parallel library. Then, we’ll revisit how to train foundation models using sharded data parallel. Finally, we’ll benchmark performance of 13B, 50B, and 100B parameter auto-regressive models and wrap up with future work.
Starting from AWS Deep Learning Containers (DLC) PyTorch 1.12.1, SageMaker model parallel library v1.13 comes with two new components that are critical in improving training performance: an AWS-optimized AllGather from SMDDP Collectives, and FlashAttention. They are currently available on ml.p4d.24xlarge instances with Elastic Fabric Adapter (EFA) enabled.
In sharded data parallel, since only a shard of the model state is present on a GPU, an AllGather collective is needed to gather the full set of parameters from across all GPUs in the sharding group during forward or backward pass computations. In the previous versions of SageMaker model parallel, we used NVIDIA Collective Communications Library (NCCL) for these collectives. However, NCCL is a general purpose collective communications library not designed for AWS infrastructure, which leads to sub-optimal performance even with EFA enabled.
Previously, we had developed the SMDDP Collectives library, which provided an AWS-optimized implementation of the All-Reduce collective to speed up pure data parallel training. To improve the performance of large model training with sharded data parallelism, we expanded the SMDDP Collectives library to include an optimized implementation of the AllGather collective. The key advantage of SMDDP Collectives AllGather is that it adopts an all-to-all-type communication pattern for inter-node communication, enabling our collective to achieve high throughput and be less latency-sensitive. In addition, our AllGather collective offloads the communication-related processing to the CPU, thereby freeing up valuable GPU cycles for gradient computation, which leads to significant performance improvement, especially on large models.
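To make the role of AllGather concrete, here is a minimal single-process sketch of what the collective does logically: each worker holds one shard of the parameters, and after the operation every worker holds the full set. The helper names `shard` and `all_gather` are hypothetical and only simulate the data movement with Python lists; they are not the SMDDP Collectives or NCCL API.

```python
def shard(params, num_workers):
    """Split a flat parameter list into contiguous shards, one per worker."""
    shard_size = -(-len(params) // num_workers)  # ceiling division
    return [params[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)]


def all_gather(shards):
    """Simulate AllGather: every worker ends up with all shards concatenated."""
    full = [p for worker_shard in shards for p in worker_shard]
    # Each worker receives its own copy of the full parameter list.
    return [list(full) for _ in shards]


# Illustrative values only: 8 parameters sharded across 4 simulated workers.
params = [0.1 * i for i in range(8)]
shards = shard(params, num_workers=4)
gathered = all_gather(shards)

print("shard held by worker 0:", shards[0])
print("worker 0 after AllGather holds", len(gathered[0]), "parameters")
```

In a real training job, this exchange happens over the network on every forward and backward pass, which is why the throughput of the collective matters so much for overall training speed.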
In modern transformer architecture, one of the largest sources of memory consumption is the activation footprint in the self-attention layer. This is because each attention head computes an SxS attention matrix for each input, where S is the sequence length, and this matrix goes through several operations, such as dropout, softmax, and matrix multiplication, with each intermediate output requiring memory space for use in back-propagation.
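As a rough illustration of why the SxS matrix dominates, the following sketch computes the size of a single intermediate attention tensor. The layer shape used here (40 heads, micro-batch of 4, 2 bytes per element) is a hypothetical example, not a configuration taken from the benchmarks in this post.

```python
def attention_matrix_bytes(seq_len, num_heads, batch_size, bytes_per_element=2):
    """Bytes needed to hold one S x S attention matrix per head, per sample."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


# Hypothetical layer shape: 40 heads, micro-batch of 4, FP16 (2 bytes per element).
for seq_len in (1024, 2048, 4096):
    gib = attention_matrix_bytes(seq_len, num_heads=40, batch_size=4) / 2**30
    print(f"S={seq_len}: {gib:.2f} GiB per intermediate S x S tensor")
```

Because several such intermediates (for example, the outputs of softmax and dropout) are kept per layer for back-propagation, the total grows quickly, and doubling the sequence length quadruples each of them.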
FlashAttention (Dao et al.) is a recent innovation from HazyResearch at Stanford that re-implements the self-attention mechanism in an I/O-aware manner. The main insight behind FlashAttention is that the self-attention mechanism is bottlenecked by memory bandwidth to and from GPU high bandwidth memory (HBM). To reduce this traffic, the self-attention layer can be computed in chunks across the sequence dimension, with each chunk going through the entire self-attention pipeline at a time. The intermediate results for a chunk are stored in the high-bandwidth SRAM, avoiding the expensive round-trip to the HBM for every iteration. Although a naive implementation would run into the issue of the cross-chunk dependency at the softmax layer, FlashAttention introduces a clever implementation that side-steps this dependency. Combined with re-computation in the backward pass, FlashAttention results in substantial memory savings and performance improvement (25% faster training for GPT-NeoX 13B over 16 p4d nodes), because it avoids the HBM round-trip and the storage of SxS matrices. You can find visuals and more explanations in HazyResearch’s FlashAttention repository.
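The cross-chunk softmax dependency can be side-stepped by keeping a running maximum and a running normalizer and rescaling them as new chunks arrive. The sketch below shows that idea for a single query with scalar values, in plain Python. It is an illustrative simplification under those assumptions, not the fused GPU kernel that FlashAttention actually uses, and the function names are hypothetical.

```python
import math


def attention_row_reference(scores, values):
    """Standard softmax-weighted sum for one query: needs every score at once."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total


def attention_row_chunked(scores, values, chunk_size):
    """Same result, visiting the keys one chunk at a time.

    Only a running max, a running normalizer, and a running weighted sum are
    kept; they are rescaled whenever a new chunk raises the max.
    """
    running_max = -math.inf
    running_sum = 0.0  # sum of exp(score - running_max) seen so far
    running_out = 0.0  # sum of exp(score - running_max) * value seen so far
    for start in range(0, len(scores), chunk_size):
        chunk_scores = scores[start:start + chunk_size]
        chunk_values = values[start:start + chunk_size]
        new_max = max(running_max, max(chunk_scores))
        # Rescale the old accumulators to the new reference max.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        running_out *= scale
        for s, v in zip(chunk_scores, chunk_values):
            w = math.exp(s - new_max)
            running_sum += w
            running_out += w * v
        running_max = new_max
    return running_out / running_sum


# Illustrative inputs: attention scores of one query against 7 keys.
scores = [0.3, 2.5, -1.0, 4.2, 0.0, 3.9, -2.7]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
print(attention_row_reference(scores, values))
print(attention_row_chunked(scores, values, chunk_size=3))
```

The two functions agree up to floating-point rounding, while the chunked version never needs the full row of scores in memory at once.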
To train foundation models with SMP powered by SMDDP Collectives, no additional changes are required in your sharded data parallel training jobs. If you’re new to using sharded data parallel, follow this complete tutorial notebook and blog post that will walk you through the entire process, from data processing, defining and submitting training jobs, to monitoring training logs. A ready-to-use training script for the GPT-2 model can be found at train_gpt_simple.py. For training a different model type, you can follow the API document to learn how to apply SMP APIs.
The key hyperparameter to highlight in the PyTorch Estimator of a sharded data parallel training job is ddp_dist_backend in smp_options, which now has a new option, "auto", as its default value. With "auto", SMP will use AWS-optimized AllGather for sharded data parallelism jobs and fall back to NCCL otherwise. You can refer to this document for supported configurations. If you want to run sharded data parallel in SMP specifically with NCCL as the communication backend of choice, you can set "ddp_dist_backend" to "nccl" in smp_options.
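As a minimal sketch of how these options might be assembled, the snippet below builds the configuration as plain Python dictionaries. Only `ddp_dist_backend`, its values "auto" and "nccl", `smp_options`, and `sharded_data_parallel_degree` come from this post; the helper `build_smp_options`, the `enabled`, `parameters`, and `ddp` keys, the `smdistributed`/`modelparallel`/`mpi` nesting, and the degree of 64 are assumptions made for illustration, so check them against the SageMaker documentation before use.

```python
def build_smp_options(sharded_data_parallel_degree, ddp_dist_backend="auto"):
    """Build an smp_options dictionary (hypothetical helper, assumed layout)."""
    # The post only mentions these two values; rejecting others is an
    # assumption of this sketch, not documented library behavior.
    if ddp_dist_backend not in ("auto", "nccl"):
        raise ValueError(f"unsupported ddp_dist_backend: {ddp_dist_backend!r}")
    return {
        "enabled": True,  # assumed key
        "parameters": {  # assumed key
            "ddp": True,  # assumed key
            # "auto" selects AWS-optimized AllGather when supported,
            # "nccl" forces NCCL as the communication backend.
            "ddp_dist_backend": ddp_dist_backend,
            "sharded_data_parallel_degree": sharded_data_parallel_degree,
        },
    }


smp_options = build_smp_options(sharded_data_parallel_degree=64)

# Assumed shape of the `distribution` argument passed to a PyTorch Estimator.
distribution = {
    "smdistributed": {"modelparallel": smp_options},
    "mpi": {"enabled": True, "processes_per_host": 8},  # 8 GPUs per p4d.24xlarge
}
print(distribution)
```

Passing `ddp_dist_backend="nccl"` to the helper would produce the NCCL-only variant described above.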
With the latest SMP v1.13 release, the sharded data parallel training technique supports FlashAttention for popular models including BERT, RoBERTa, GPT-2, GPT-J, GPT-Neo and GPT-NeoX out-of-the-box. This is enabled by passing tensor_parallelism=True during model creation without setting tensor_parallel_degree. You can find an example in the same training script train_gpt_simple.py.
We benchmarked sharded data parallelism in the SageMaker model parallel library on three different scales of models to understand how the two new features, FlashAttention and AWS-optimized AllGather, contribute to performance improvement. A placement group is not required to reproduce these benchmarks on SageMaker.
In this setting, we focus on understanding the performance gain contributed by FlashAttention and we leave AWS-optimized AllGather out of the picture. Using FlashAttention saves substantial GPU memory, which helps us increase batch size or reduce sharding degree, thereby improving performance. As the results below show, we observed an average speedup of about 20.4% in SMP with FlashAttention for a 13B-parameter GPT-NeoX model on various configurations across 16 to 64 p4d nodes. Memory usage during standard attention computation scales quadratically with sequence length, whereas FlashAttention's memory usage is linear in sequence length. Hence FlashAttention is even more helpful as sequence length increases and makes it possible to use larger sequence lengths. Being memory-efficient without trading off model quality, FlashAttention has gained traction quickly in the large model training community in the past months, including integration with Hugging Face Diffusers and Mosaic ML.
| Model/Training | Cluster | SMP | Without FlashAttention (TFLOPs/GPU) | With FlashAttention (TFLOPs/GPU) | % Speedup |
| --- | --- | --- | --- | --- | --- |
| 13B GPT-NeoX, seq length: 2048, global batch size: 1024, FP16 | 16 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 64, gradient_accumulation: 1 | 130 | 159 | 22.31 |
| 13B GPT-NeoX, seq length: 2048, global batch size: 2048, FP16 | 32 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 64, gradient_accumulation: 1 | 131 | 157 | 19.85 |
| 13B GPT-NeoX, seq length: 2048, global batch size: 4096, FP16 | 64 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 64, gradient_accumulation: 1 | 131 | 156 | 19.08 |
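The speedup column and the roughly 20.4% average can be reproduced from the TFLOPs/GPU figures in the table. The short sketch below does that arithmetic; the function name `percent_speedup` is a hypothetical helper, and the inputs are copied from the table above.

```python
def percent_speedup(baseline_tflops, improved_tflops):
    """Relative throughput gain of the improved run over the baseline, in percent."""
    return (improved_tflops / baseline_tflops - 1.0) * 100.0


# (without FlashAttention, with FlashAttention) TFLOPs/GPU, keyed by node count.
results = {16: (130, 159), 32: (131, 157), 64: (131, 156)}

speedups = {nodes: percent_speedup(base, flash) for nodes, (base, flash) in results.items()}
average = sum(speedups.values()) / len(speedups)

for nodes, value in speedups.items():
    print(f"{nodes} nodes: {value:.2f}% speedup")
print(f"average: {average:.1f}%")
```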
Now, we look at how AWS-optimized AllGather from SMDDP Collectives speeds up large model training with SMP. We benchmark a 50B-parameter Bloom model and compare the performance with and without the AWS-optimized AllGather collective. We observe that SMDDP Collectives speeds up model training by up to 40% across training jobs on 32 to 64 nodes. SMDDP Collectives achieves better performance through better utilization of the 400 Gbps network bandwidth available with p4d.24xlarge instances. This, coupled with the design choice to offload communication-related processing to the CPU, helps achieve good compute-to-network overlap, leading to optimized performance. Compute-to-network overlap becomes especially important in large models because the size of data communicated across nodes scales linearly with model size.
| Model/Training | Cluster | SMP | Without AWS-optimized AllGather (TFLOPs/GPU) | With AWS-optimized AllGather (TFLOPs/GPU) | % Speedup |
| --- | --- | --- | --- | --- | --- |
| 50B Bloom, seq length: 2048, global batch size: 2048, BF16 | 32 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 128, gradient_accumulation: 1 | 102 | 143 | 40.20 |
| 50B Bloom, seq length: 2048, global batch size: 4096, BF16 | 64 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 128, gradient_accumulation: 1 | 101 | 140 | 38.61 |
Finally, we benchmark SMP with both of the latest features enabled. The results show that this new release, SMP v1.13, is about 30% faster than the previous version on a 100B-parameter GPT-NeoX model.
| Model/Training | Cluster | SMP | Without FlashAttention and without AWS-optimized AllGather (TFLOPs/GPU) | With FlashAttention + AWS-optimized AllGather (TFLOPs/GPU) | % Speedup |
| --- | --- | --- | --- | --- | --- |
| 100B GPT-NeoX, seq length: 2048, global batch size: 2048, FP16 | 32 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 256, offload_activations | 121 | 158 | 30.58 |
| 100B GPT-NeoX, seq length: 2048, global batch size: 4096, FP16 | 64 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 256, offload_activations | 122 | 158 | 29.51 |
For future work, we’ll be working on supporting an AWS-optimized Reduce-Scatter in SMDDP Collectives. The Reduce-Scatter collective is critical in averaging and sharding gradients computed in the backward pass. We expect this to further speed up the SMP library in future releases.
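For readers unfamiliar with Reduce-Scatter, the following single-process sketch shows its logical effect on gradients: the per-worker gradients are averaged element-wise, and each worker keeps only its own shard of the result. The function `reduce_scatter_mean` is a hypothetical simulation with Python lists, assuming the gradient length divides evenly by the number of workers; it does not reflect how SMDDP Collectives or NCCL implement the collective.

```python
def reduce_scatter_mean(per_worker_grads):
    """Average gradients across workers, then give each worker one shard."""
    num_workers = len(per_worker_grads)
    length = len(per_worker_grads[0])
    if length % num_workers != 0:
        raise ValueError("gradient length must be divisible by the number of workers")
    shard_size = length // num_workers
    # Reduce: element-wise mean over all workers.
    averaged = [
        sum(grads[i] for grads in per_worker_grads) / num_workers
        for i in range(length)
    ]
    # Scatter: worker w keeps only the w-th contiguous shard.
    return [averaged[w * shard_size:(w + 1) * shard_size] for w in range(num_workers)]


# Illustrative values: 2 simulated workers, each with a full 4-element gradient.
grads = [[1, 2, 3, 4], [3, 4, 5, 6]]
shards = reduce_scatter_mean(grads)
print(shards)
```

Each worker then updates only the parameter shard it owns, which mirrors how sharded data parallelism splits gradients and optimizer states across workers.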
In this post, we discussed the two latest performance improvements for the sharded data parallel technique in the SageMaker model parallel library. LLMs show great promise in improving the quality and re-usability of ML models. AWS teams are working closely with customers to keep reducing their training costs and time-to-market. You can find more SageMaker model parallel examples in the Amazon SageMaker Examples GitHub repo or attend our next distributed training workshops. If you are interested in speeding up large model training, check out these features and let us know what you build!