Maximize performance and reduce your deep learning training cost with AWS Trainium and Amazon SageMaker

Today, tens of thousands of customers are building, training, and deploying machine learning (ML) models using Amazon SageMaker to power applications that have the potential to reinvent their businesses and customer experiences. These ML models have been increasing in size and complexity over the last few years, which has led to state-of-the-art accuracies across a range of tasks but has also pushed the time to train from days to weeks. As a result, customers must scale their models across hundreds to thousands of accelerators, which makes them more expensive to train.

SageMaker is a fully managed ML service that helps developers and data scientists easily build, train, and deploy ML models. SageMaker already provides the broadest and deepest choice of compute offerings featuring hardware accelerators for ML training, including G5 (Nvidia A10G) instances and P4d (Nvidia A100) instances.

Growing compute requirements call for faster and more cost-effective processing power. To further reduce model training times and enable ML practitioners to iterate faster, AWS has been innovating across chips, servers, and data center connectivity. The new Trn1 instances powered by AWS Trainium chips offer the best price-performance and the fastest ML model training on AWS, providing up to 50% lower cost to train deep learning models over comparable GPU-based instances without any drop in accuracy.

In this post, we show how you can maximize your performance and reduce cost using Trn1 instances with SageMaker.

Solution overview

SageMaker training jobs support ml.trn1 instances, powered by Trainium chips, which are purpose built for high-performance ML training applications in the cloud. You can use ml.trn1 instances on SageMaker to train natural language processing (NLP), computer vision, and recommender models across a broad set of applications, such as speech recognition, recommendation, fraud detection, image and video classification, and forecasting. The ml.trn1 instances feature up to 16 Trainium chips; Trainium is the second-generation ML chip built by AWS, after AWS Inferentia. ml.trn1 instances are the first Amazon Elastic Compute Cloud (Amazon EC2) instances with up to 800 Gbps of Elastic Fabric Adapter (EFA) network bandwidth. For efficient data and model parallelism, each ml.trn1.32xlarge instance has 512 GB of high-bandwidth memory, delivers up to 3.4 petaflops of FP16/BF16 compute power, and features NeuronLink, an intra-instance, high-bandwidth, nonblocking interconnect.

Trainium is available in two configurations and can be used in the US East (N. Virginia) and US West (Oregon) Regions.

The following table summarizes the features of the Trn1 instances.

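| Instance Size | Trainium Accelerators | Accelerator Memory (GB) | vCPUs | Instance Memory (GiB) | Network Bandwidth (Gbps) | EFA and RDMA Support |
|---|---|---|---|---|---|---|
| trn1.2xlarge | 1 | 32 | 8 | 32 | Up to 12.5 | No |
| trn1n.32xlarge (coming soon) | 16 | 512 | 128 | 512 | 1600 | Yes |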

Let’s walk through how to use Trainium with SageMaker with a simple example. We train a text classification model with SageMaker training and PyTorch using the Hugging Face Transformers library.

We use the Amazon Reviews dataset, which consists of reviews from Amazon. The data spans a period of 18 years, comprising approximately 35 million reviews up to March 2013. Reviews include product and user information, ratings, and a plaintext review. The following code is an example from the AmazonPolarity test set:

'title':'Great CD',
'content':"My lovely Pat has one of the GREAT voices of her generation. I have listened to this CD for YEARS and I still LOVE IT. When I'm in a good mood it makes me feel better. A bad mood just evaporates like sugar in the rain. This CD just oozes LIFE. Vocals are jusat STUUNNING and lyrics just kill. One of life's hidden gems. This is a desert isle CD in my book. Why she never made it big is just beyond me. Everytime I play this, no matter black, white, young, old, male, female EVERYBODY says one thing ""Who was that singing ?""",

For this post, we only use the content and label fields. The content field is a free text review, and the label field is a binary value containing 1 or 0 for positive or negative reviews, respectively.
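As a minimal sketch of that field selection, the following helper keeps only the content and label fields of a record. The helper name, the label names, and the sample record (including its label value) are hypothetical and only serve as illustration; they are not taken from the example notebook.

```python
# Hypothetical helper: keep only the two fields used for training.
LABEL_NAMES = {0: "negative", 1: "positive"}


def to_training_example(record):
    """Return a dict with only the review text and its binary label."""
    return {"content": record["content"], "label": int(record["label"])}


# Illustrative record shaped like an AmazonPolarity row; the label value
# here is an assumption made for the sake of the example.
sample = {
    "title": "Great CD",
    "content": "This CD just oozes LIFE.",
    "label": 1,
}

example = to_training_example(sample)
print(example["content"], "->", LABEL_NAMES[example["label"]])
```

Dropping unused columns early keeps the tokenization and batching code independent of the rest of the record.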

For our algorithm, we use BERT, a transformer model pre-trained on a large corpus of English data in a self-supervised fashion. This model is primarily aimed at being fine-tuned on tasks that use the whole sentence (potentially masked) to make decisions, such as sequence classification, token classification, or question answering.

Implementation details

Let’s begin by taking a closer look at the different components involved in training the model:

  • AWS Trainium – At its core, each Trn1 instance has Trainium devices built into it. trn1.2xlarge has 1 Trainium device, and trn1.32xlarge has 16 Trainium devices. Each Trainium device consists of compute (2 NeuronCore-v2), 32 GB of HBM device memory, and NeuronLink for fast inter-device communication. Each NeuronCore-v2 consists of a fully independent heterogeneous compute unit with separate engines (Tensor/Vector/Scalar/GPSIMD). The GPSIMD engines are fully programmable general-purpose processors that you can use to implement custom operators and run them directly on the NeuronCore engines.
  • Amazon SageMaker Training – SageMaker provides a fully managed training experience to easily train models without having to worry about infrastructure. When you use SageMaker Training, it runs everything needed for a training job, such as code, container, and data, in a compute infrastructure separate from the invocation environment. This allows us to run experiments in parallel and iterate fast. SageMaker provides a Python SDK to launch training jobs. The example in this post uses the SageMaker Python SDK to trigger the training job using Trainium.
  • AWS Neuron – Because each Trainium NeuronCore has its own compute engines, we need a mechanism to compile our training code. The AWS Neuron compiler takes the code written in PyTorch/XLA and optimizes it to run on Neuron devices. The Neuron compiler is integrated as part of the Deep Learning Container we use for training our model.
  • PyTorch/XLA – This Python package uses the XLA deep learning compiler to connect the PyTorch deep learning framework and cloud accelerators like Trainium. Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. We will see for our use case what changes we need to make.
  • Distributed training – To run the training efficiently on multiple NeuronCores, we need a mechanism to distribute the training across the available NeuronCores. SageMaker supports torchrun with Trainium instances, which can be used to run as many processes as there are NeuronCores in the cluster. This is done by passing the distribution parameter to the SageMaker estimator as follows, which starts data parallel distributed training in which the same model is loaded onto different NeuronCores that process separate data batches:
distribution={"torch_distributed": {"enabled": True}}
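To put the distribution setting in context, the following is a minimal sketch of how the estimator arguments might be assembled with the SageMaker Python SDK. The entry point name, source directory, role, and S3 URI are hypothetical placeholders, and the container arguments (for example framework_version and py_version, or an image_uri for a Neuron-enabled container) are left to the caller because they depend on your SDK version. The worker count follows from the device description above (2 NeuronCores per Trainium device).

```python
NEURON_CORES_PER_DEVICE = 2  # from the Trainium device description above
DEVICES_PER_INSTANCE = {"ml.trn1.2xlarge": 1, "ml.trn1.32xlarge": 16}


def worker_count(instance_type, instance_count=1):
    """Number of data parallel processes: one per NeuronCore in the cluster."""
    return DEVICES_PER_INSTANCE[instance_type] * NEURON_CORES_PER_DEVICE * instance_count


def estimator_kwargs(role, instance_type="ml.trn1.32xlarge", instance_count=1,
                     **container_args):
    """Build keyword arguments for sagemaker.pytorch.PyTorch (a sketch)."""
    kwargs = {
        "entry_point": "train.py",   # hypothetical script name
        "source_dir": "scripts",     # hypothetical directory
        "role": role,
        "instance_type": instance_type,
        "instance_count": instance_count,
        # Launches the script with torchrun, one process per NeuronCore.
        "distribution": {"torch_distributed": {"enabled": True}},
    }
    kwargs.update(container_args)
    return kwargs


def launch(role, train_s3_uri, **container_args):
    """Start the training job; requires the sagemaker package and AWS credentials."""
    from sagemaker.pytorch import PyTorch  # imported lazily on purpose

    estimator = PyTorch(**estimator_kwargs(role, **container_args))
    estimator.fit({"train": train_s3_uri})
    return estimator


print(worker_count("ml.trn1.32xlarge"))
```

With one ml.trn1.32xlarge instance, this configuration corresponds to 32 processes, each training a replica of the model on its own share of the batches.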

Script changes needed to run on Trainium

Let’s look at the code changes needed to adapt a regular GPU-based PyTorch script to run on Trainium. At a high level, we need to make the following changes:

  1. Replace GPU devices with PyTorch/XLA devices. Because we use torch.distributed, we need to initialize the training with XLA as the device as follows:
    device = "xla"

  2. We use the PyTorch/XLA distributed backend to bridge the PyTorch distributed APIs to XLA communication semantics.
  3. We use PyTorch/XLA MpDeviceLoader for the data ingestion pipelines. MpDeviceLoader helps improve performance by overlapping three steps: tracing, compilation, and data batch loading to the device. We need to wrap the PyTorch dataloader with the MpDeviceLoader as follows:
    train_device_loader = pl.MpDeviceLoader(train_loader, "xla")

  4. Run the optimization step using the XLA-provided API, xm.optimizer_step(optimizer) from torch_xla.core.xla_model. This consolidates the gradients between cores and issues the XLA device step computation.

  5. Map CUDA APIs (if any) to generic PyTorch APIs.
  6. Replace CUDA fused optimizers (if any) with generic PyTorch alternatives.
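The steps above can be sketched together as follows. This is an illustrative outline, not the training script from the repository: the function names, the learning rate, and the assumption that each batch is an (inputs, labels) pair are ours. The loop itself is device-agnostic, and the XLA-specific pieces (the "xla" device, the "xla" process group backend, MpDeviceLoader, and xm.optimizer_step) are supplied in run_on_trainium, which needs torch and torch_xla installed.

```python
def train_epoch(model, device_loader, optimizer, loss_fn, optimizer_step):
    """Run one epoch; `optimizer_step` is xm.optimizer_step on Trainium."""
    model.train()
    steps = 0
    for inputs, labels in device_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer_step(optimizer)  # step 4: gradient all-reduce + update
        steps += 1
    return steps


def run_on_trainium(model, train_loader, loss_fn, lr=5e-5):
    """Wire the XLA-specific pieces; intended to be launched through torchrun."""
    import torch
    import torch.distributed as dist
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_backend  # noqa: F401 (registers "xla" backend)

    device = "xla"                              # step 1: XLA instead of CUDA
    dist.init_process_group("xla")              # step 2: XLA distributed backend
    model = model.to(device)
    # Step 6: a generic PyTorch optimizer, created after moving the model.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Step 3: overlap tracing, compilation, and batch transfer to the device.
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    return train_epoch(model, train_device_loader, optimizer, loss_fn,
                       xm.optimizer_step)
```

Keeping the loop free of device-specific calls means the same train_epoch can run on CPU or GPU by passing `lambda opt: opt.step()` as the optimizer step. In a real data parallel job, train_loader would typically use a distributed sampler so that each process sees a different shard of the data.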

The entire example, which trains a text classification model using SageMaker and Trainium, is available in the following GitHub repo. The notebook file Fine tune Transformers for building classification models using SageMaker and Trainium.ipynb is the entrypoint and contains step-by-step instructions to run the training.

Benchmark tests

In the test, we ran two training jobs: one on ml.trn1.32xlarge, and one on ml.p4d.24xlarge, with the same batch size, training data, and other hyperparameters. During the training jobs, we measured the billable time of the SageMaker training jobs, and calculated the training cost by multiplying the time required to run the training job in hours by the price per hour for the instance type. We selected the best result for each instance type out of multiple job runs.

The following table summarizes our benchmark findings.

| Model | Instance Type | Price (per node * hour) | Throughput (iterations/sec) | Validation Accuracy | Billable Time (sec) | Training Cost in $ |
|---|---|---|---|---|---|---|
| BERT base classification | ml.trn1.32xlarge | 24.725 | 6.64 | 0.984 | 6033 | 41.47 |
| BERT base classification | ml.p4d.24xlarge | 37.69 | 5.44 | 0.984 | 6553 | 68.6 |

The results show that the Trainium instance cost less than the P4d instance while delivering higher throughput and the same validation accuracy when training the same model with the same input data and training parameters. This means that the Trainium instance delivers better price-performance than the GPU-based P4d instance. In this simple example, Trainium delivered about 22% higher throughput (6.64 vs. 5.44 iterations/sec), about 8% shorter billable time, and about 40% lower training cost ($41.47 vs. $68.60) than the P4d instance. AWS cites savings of up to 50% for Trn1 over comparable GPU-based instances.
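The cost calculation described for the benchmark (billable time in hours multiplied by the hourly instance price) can be reproduced with a few lines of Python. The function names are ours, and the inputs are the prices, billable times, and throughputs listed in the benchmark table. The recomputed costs differ from the table's cost column by a few cents, presumably because the listed values are rounded.

```python
def training_cost(billable_seconds, price_per_hour):
    """Training cost in dollars: billable hours times hourly instance price."""
    return billable_seconds / 3600 * price_per_hour


def relative_saving(cost, baseline_cost):
    """Fraction of the baseline cost that is saved (0.25 means 25% cheaper)."""
    return 1 - cost / baseline_cost


trn1_cost = training_cost(6033, 24.725)   # ml.trn1.32xlarge
p4d_cost = training_cost(6553, 37.69)     # ml.p4d.24xlarge

print(f"trn1: ${trn1_cost:.2f}, p4d: ${p4d_cost:.2f}")
print(f"cost saving: {relative_saving(trn1_cost, p4d_cost):.1%}")
print(f"throughput gain: {6.64 / 5.44 - 1:.1%}")
```

For these two runs, the saving works out to roughly 40% and the throughput gain to roughly 22%.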

Deploy the trained model

After we train the model, we can deploy it to various instance types such as CPU, GPU, or AWS Inferentia. The key point to note is that the trained model doesn’t depend on specialized hardware for deployment and inference. SageMaker provides mechanisms to deploy a trained model for either real-time or batch inference. The notebook example in the GitHub repo contains code to deploy the trained model as a real-time endpoint using an ml.c5.xlarge (CPU-based) instance.


Conclusion

In this post, we looked at how to use Trainium and SageMaker to quickly set up and train a classification model that gives up to 50% cost savings without compromising on accuracy. You can use Trainium for a wide range of use cases that involve pre-training or fine-tuning Transformer-based models. For more information about support of various model architectures, refer to Model Architecture Fit Guidelines.

About the Authors

Arun Kumar Lokanatha is a Senior ML Solutions Architect with the Amazon SageMaker Service team. He focuses on helping customers build, train, and migrate ML production workloads to SageMaker at scale. He specializes in deep learning, especially in the areas of NLP and CV. Outside of work, he enjoys running and hiking.

Mark Yu is a Software Engineer in AWS SageMaker. He focuses on building large-scale distributed training systems, optimizing training performance, and developing high-performance ML training hardware, including SageMaker Trainium. Mark also has in-depth knowledge of machine learning infrastructure optimization. In his spare time, he enjoys hiking and running.

Omri Fuchs is a Software Development Manager at AWS SageMaker. He is the technical leader responsible for the SageMaker training job platform, focusing on optimizing SageMaker training performance and improving the training experience. He has a passion for cutting-edge ML and AI technology. In his spare time, he likes cycling and hiking.

Gal Oshri is a Senior Product Manager on the Amazon SageMaker team. He has 7 years of experience working on Machine Learning tools, frameworks, and services.