This post is co-written with Meta’s PyTorch team.
In today’s rapidly evolving AI landscape, businesses are constantly seeking ways to use advanced large language models (LLMs) for their specific needs. Although foundation models (FMs) offer impressive out-of-the-box capabilities, true competitive advantage often lies in deep model customization through fine-tuning. However, fine-tuning LLMs for complex tasks typically requires advanced AI expertise to align and optimize them effectively. Recognizing this challenge, Meta developed torchtune, a PyTorch-native library that simplifies authoring, fine-tuning, and experimenting with LLMs, making it more accessible to a broader range of users and applications.
In this post, AWS collaborates with Meta’s PyTorch team to showcase how you can use Meta’s torchtune library to fine-tune Meta Llama-like architectures in a fully managed environment provided by Amazon SageMaker Training. We demonstrate this through a step-by-step implementation of model fine-tuning, inference, quantization, and evaluation. We perform the steps on a Meta Llama 3.1 8B model using the LoRA fine-tuning strategy on a single p4d.24xlarge worker node, which provides 8 NVIDIA A100 GPUs.
Before we dive into the step-by-step guide, we first explored the performance of our technical stack by fine-tuning a Meta Llama 3.1 8B model across various configurations and instance types.
As the following chart shows, we found that a single p4d.24xlarge delivers 70% higher performance than two g5.48xlarge instances (each with 8 NVIDIA A10G GPUs) at a price almost 47% lower. We therefore optimized the example in this post for a p4d.24xlarge configuration. However, you can use the same code to run single-node or multi-node training on different instance configurations by changing the parameters passed to the SageMaker estimator. You can further reduce the training times shown in the following graph by using a SageMaker managed warm pool and by accessing pre-downloaded models on Amazon Elastic File System (Amazon EFS).
Generative AI models offer many promising business use cases. However, to keep these LLMs factually accurate and relevant to specific business domains, fine-tuning is often required. Because modern LLMs have a growing number of parameters and increasingly long context lengths, this process is memory intensive. To address these challenges, fine-tuning strategies like LoRA (Low-Rank Adaptation) and QLoRA (Quantized Low-Rank Adaptation) limit the number of trainable parameters by adding low-rank parallel structures to the transformer layers. This enables you to train LLMs even on systems with limited memory, such as commodity GPUs. However, it also increases complexity: new dependencies have to be handled, and training recipes and hyperparameters need to be adapted to the new techniques.
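To make the parameter savings concrete, the following sketch counts trainable parameters for a single weight matrix. LoRA freezes the pretrained weight W (shape d_out × d_in) and learns a low-rank update B·A, where B has shape d_out × r and A has shape r × d_in, so only r·(d_in + d_out) parameters are trained for that matrix. The 4096 × 4096 projection size and rank 8 are illustrative assumptions, not values read from this post's config.

```python
def full_params(d_in: int, d_out: int) -> int:
    """Trainable parameters when fully fine-tuning a d_out x d_in weight matrix."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters for a LoRA update B @ A, with B: d_out x r and A: r x d_in."""
    return rank * (d_in + d_out)


if __name__ == "__main__":
    d_in, d_out, rank = 4096, 4096, 8  # illustrative sizes, not taken from the post's config
    full = full_params(d_in, d_out)
    lora = lora_params(d_in, d_out, rank)
    print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
```

With these sizes, LoRA trains 65,536 parameters for the matrix instead of 16,777,216, or about 0.39%. The frozen weights still have to be held in memory, which is why QLoRA additionally stores them in a quantized form.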
What businesses need today are user-friendly training recipes for these popular fine-tuning techniques: recipes that abstract the end-to-end tuning process and address its common pitfalls in an opinionated way.
torchtune is a PyTorch-native library that aims to democratize and streamline the fine-tuning process for LLMs. By doing so, it makes it straightforward for researchers, developers, and organizations to adapt these powerful LLMs to their specific needs and constraints. It provides training recipes for a variety of fine-tuning techniques, which can be configured through YAML files. The recipes implement common fine-tuning methods (full-weight, LoRA, QLoRA) as well as other common tasks like inference and evaluation. They automatically apply a set of important features (FSDP, activation checkpointing, gradient accumulation, mixed precision) and are specific to a given model family (such as Meta Llama 3/3.1 or Mistral) as well as compute environment (single-node vs. multi-node).
Additionally, torchtune integrates with major libraries and frameworks like Hugging Face datasets, EleutherAI’s Eval Harness, and Weights & Biases. This helps address the requirements of the generative AI fine-tuning lifecycle, from data ingestion and multi-node fine-tuning to inference and evaluation. The following diagram shows a visualization of the steps we describe in this post.
Refer to the installation instructions and PyTorch documentation to learn more about torchtune and its concepts.
This post demonstrates the use of SageMaker Training for running torchtune recipes through task-specific training jobs on separate compute clusters. SageMaker Training is a comprehensive, fully managed ML service that enables scalable model training. It provides flexible compute resource selection, support for custom libraries, a pay-as-you-go pricing model, and self-healing capabilities. By managing workload orchestration, health checks, and infrastructure, SageMaker helps reduce training time and total cost of ownership.
The solution architecture incorporates the following key components to enhance security and efficiency in fine-tuning workflows:
The following diagram illustrates the solution architecture. Users initiate the process by calling the SageMaker control plane through APIs or command line interface (CLI) or using the SageMaker SDK for each individual step. In response, SageMaker spins up training jobs with the requested number and type of compute instances to run specific tasks. Each step defined in the diagram accesses torchtune recipes from an Amazon Simple Storage Service (Amazon S3) bucket and uses Amazon EFS to save and access model artifacts across different stages of the workflow.
By decoupling every torchtune step, we balance flexibility and integration: each step can run independently, and the steps can later be automated as a pipeline.
In this use case, we fine-tune a Meta Llama 3.1 8B model with LoRA. Subsequently, we run model inference, and optionally quantize and evaluate the model using torchtune and SageMaker Training.
Recipes, configs, datasets, and prompt templates are completely configurable and allow you to align torchtune to your requirements. To demonstrate this, we use a custom prompt template in this use case and combine it with the open source dataset Samsung/samsum from the Hugging Face hub.
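The post does not reproduce its custom template, so the following is a hypothetical sketch of how a summarization prompt template could map a Samsung/samsum record (which has dialogue and summary fields) to a prompt and a training target. The template wording and the build_example function are assumptions for illustration; torchtune's own prompt template classes have their own interface.

```python
from typing import Dict, Tuple

# Hypothetical template wording; the post's actual custom template is not shown here.
SUMMARY_TEMPLATE = (
    "Summarize this dialogue:\n"
    "{dialogue}\n"
    "---\n"
    "Summary:\n"
)


def build_example(record: Dict[str, str]) -> Tuple[str, str]:
    """Turn one samsum-style record into (prompt, target) strings for fine-tuning."""
    prompt = SUMMARY_TEMPLATE.format(dialogue=record["dialogue"].strip())
    target = record["summary"].strip()
    return prompt, target


if __name__ == "__main__":
    sample = {  # made-up record following the samsum field layout
        "dialogue": "Amanda: I baked cookies. Do you want some?\nJerry: Sure!",
        "summary": "Amanda baked cookies and will bring Jerry some.",
    }
    prompt, target = build_example(sample)
    print(prompt + target)
```

Keeping the template in one place makes it easy to reuse the exact same wording at inference and evaluation time, which matters because a fine-tuned model tends to expect the prompt format it was trained on.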
We fine-tune the model using torchtune’s multi-device LoRA recipe (lora_finetune_distributed) and a version of the default Meta Llama 3.1 8B config (llama3_1/8B_lora) customized for SageMaker.
You need to complete the following prerequisites before you can run the SageMaker Jupyter notebooks:
The following figure illustrates the steps in our workflow.
You can look up the torchtune configs for your use case directly with the tune CLI. For this post, we provide modified config files aligned with the SageMaker directory structure:
torchtune uses these config files to select and configure the components (such as models and tokenizers) during the execution of the recipes.
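torchtune configs name each component with a `_component_` key that holds a dotted import path, and the recipe builds the object from that path plus the remaining keys (torchtune exposes this as `torchtune.config.instantiate`). The following simplified sketch reimplements the idea with the standard library only so that it runs without torchtune. The `instantiate` function here is a hypothetical stand-in, and `fractions.Fraction` is used only as a convenient importable class.

```python
import importlib
from typing import Any, Dict


def instantiate(node: Dict[str, Any]) -> Any:
    """Build an object from a config node that has a '_component_' dotted path.

    Simplified stand-in for torchtune-style component instantiation: it does not
    handle nested components, positional args, or config interpolation.
    """
    params = dict(node)  # copy so the caller's config node is not mutated
    path = params.pop("_component_")
    module_name, _, attr = path.rpartition(".")
    target = getattr(importlib.import_module(module_name), attr)
    return target(**params)


if __name__ == "__main__":
    # Toy config node; a real torchtune config would point at a model builder or tokenizer.
    node = {"_component_": "fractions.Fraction", "numerator": 3, "denominator": 4}
    print(instantiate(node))
```

This indirection is what lets you swap a model, tokenizer, or dataset by editing one line of YAML instead of changing recipe code.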
As part of our example, we create a custom container that provides custom libraries such as PyTorch nightly builds and torchtune. Complete the following steps:
Run the 1_build_container.ipynb notebook up to the following command, which pushes the container image to your Amazon ECR repository:
sm-docker is a CLI tool designed for building Docker images in SageMaker Studio using AWS CodeBuild. We install the library as part of the notebook.
Next, we run the 2_torchtune-llama3_1.ipynb notebook for all fine-tuning workflow tasks.
For every task, we review three artifacts:
In this section, we walk through the steps to run and monitor the fine-tuning task.
The following code shows a shortened torchtune recipe configuration highlighting a few key components of the file for a fine-tuning job:
We use Weights & Biases for logging and monitoring our training jobs, which helps us track our model’s performance:
Next, we define a SageMaker task that is passed to our create_pytorch_estimator utility, which creates the PyTorch estimator with all the defined parameters.
In the task, we use the lora_finetune_distributed torchrun recipe with the config config-l3.1-8b-lora.yaml on an ml.p4d.24xlarge instance. Make sure the base model is downloaded from Hugging Face before it’s fine-tuned; the use_downloaded_model parameter controls this. The image_uri parameter defines the URI of the custom container.
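The post's create_pytorch_estimator utility is not reproduced here. As a rough, hypothetical sketch of the idea, the following function maps a task dictionary to keyword arguments for the SageMaker Python SDK's PyTorch estimator (sagemaker.pytorch.PyTorch). The task keys, the entry script and source directory names, and the hyperparameter names are assumptions. The estimator arguments themselves (role, image_uri, instance_type, instance_count, hyperparameters, and keep_alive_period_in_seconds, which requests a SageMaker managed warm pool) are part of the SDK. The function only builds a dictionary, so it runs without AWS credentials.

```python
from typing import Any, Dict


def build_estimator_kwargs(task: Dict[str, Any], role: str, image_uri: str) -> Dict[str, Any]:
    """Translate a hypothetical task definition into PyTorch estimator keyword arguments.

    The task keys used here are illustrative assumptions, not the post's actual schema.
    """
    kwargs: Dict[str, Any] = {
        "entry_point": "launcher.py",  # hypothetical entry script that invokes the torchtune recipe
        "source_dir": "scripts",       # hypothetical local directory holding the script and configs
        "role": role,
        "image_uri": image_uri,        # custom container with torchtune installed
        "instance_type": task["instance_type"],
        "instance_count": task.get("instance_count", 1),
        "hyperparameters": {           # hypothetical names read by the entry script
            "recipe": task["recipe"],
            "config": task["config"],
            "use_downloaded_model": str(task.get("use_downloaded_model", False)).lower(),
        },
    }
    # A positive keep-alive period asks SageMaker to retain the instances in a warm pool.
    if task.get("warm_pool_seconds"):
        kwargs["keep_alive_period_in_seconds"] = task["warm_pool_seconds"]
    return kwargs


if __name__ == "__main__":
    finetune_task = {
        "recipe": "lora_finetune_distributed",
        "config": "config-l3.1-8b-lora.yaml",
        "instance_type": "ml.p4d.24xlarge",
        "use_downloaded_model": True,
    }
    kwargs = build_estimator_kwargs(
        finetune_task,
        role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",  # placeholder ARN
        image_uri="<your-ecr-image-uri>",                           # placeholder URI
    )
    print(kwargs)
    # With the SageMaker Python SDK installed and credentials configured, you could then run:
    # from sagemaker.pytorch import PyTorch
    # PyTorch(**kwargs).fit()
```

Because image_uri is supplied, the estimator does not need framework_version or py_version. Describing each step as a small task dictionary keeps fine-tuning, inference, quantization, and evaluation jobs uniform, so only the recipe, config, and instance settings change between them.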
To create and run the task, run the following code:
The following code shows the task output and reported status:
The final model is saved to Amazon EFS, which makes it available without download time penalties.
You can monitor various metrics, such as loss and learning rate, for your training run through the Weights & Biases dashboard. The following figures show the results of the training run, where we tracked GPU utilization, GPU memory utilization, and the loss curve.
As the following graph shows, torchtune optimizes memory usage by initially loading the model into CPU memory only on rank 0. Rank 0 is therefore responsible for loading the model weights from the checkpoint.
The example is optimized to use GPU memory to its maximum capacity. Increasing the batch size further will lead to CUDA out-of-memory (OOM) errors.
One epoch took about 13 minutes to complete, resulting in the loss curve shown in the following graph.
In the next step, we use the previously fine-tuned model weights to generate the answer to a sample prompt and compare it to the base model.
The following code shows the configuration of the generate recipe config_l3.1_8b_gen_trained.yaml. A key parameter is the fine-tuned model checkpoint meta_model_0.pt, which the recipe loads from Amazon EFS.
Next, we configure the SageMaker task to run on a single ml.g5.2xlarge instance:
In the output of the SageMaker task, we see the model summary output and some stats like tokens per second:
We can also run inference on the base model using the original model artifact consolidated.00.pth:
The following code shows the comparison output from the base model run with the SageMaker task (generate_inference_on_original). The fine-tuned model performs subjectively better than the base model: it also mentions that Amanda baked the cookies.
To speed up the inference and decrease the model artifact size, we can apply post-training quantization. torchtune relies on torchao for post-training quantization.
We configure the recipe to use Int8DynActInt4WeightQuantizer, which refers to int8 dynamic per token activation quantization combined with int4 grouped per axis weight quantization. For more details, refer to the torchao implementation.
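To illustrate what 8-bit dynamic activation, 4-bit weight (8da4w) quantization means, here is a simplified pure-Python sketch of the arithmetic. Weights are quantized to 4-bit integers with a separate scale for each group of columns, activations are quantized to 8-bit integers with a scale computed at runtime for each token, and the dot product is accumulated in integers and rescaled once per group. The sketch uses symmetric quantization throughout; it is not torchao's implementation, which uses its own kernels and group sizes and may handle zero points differently. The group size of 4 and the sample values are arbitrary.

```python
from typing import List, Tuple


def quantize_weights_int4_grouped(row: List[float], group_size: int) -> Tuple[List[int], List[float]]:
    """Symmetric 4-bit quantization of one weight row, with one scale per group of columns."""
    q: List[int] = []
    scales: List[float] = []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        max_abs = max(abs(w) for w in group)
        scale = max_abs / 7 if max_abs > 0 else 1.0  # map the group into [-7, 7]
        scales.append(scale)
        q.extend(max(-8, min(7, round(w / scale))) for w in group)
    return q, scales


def quantize_activation_int8_per_token(token: List[float]) -> Tuple[List[int], float]:
    """Dynamic symmetric 8-bit quantization: the scale is computed at runtime for this token."""
    max_abs = max(abs(x) for x in token)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return [max(-128, min(127, round(x / scale))) for x in token], scale


def quantized_dot(token: List[float], row: List[float], group_size: int) -> float:
    """Approximate dot(token, row) using int8 activations and int4 grouped weights."""
    qx, sx = quantize_activation_int8_per_token(token)
    qw, sw = quantize_weights_int4_grouped(row, group_size)
    total = 0.0
    for g, scale_w in enumerate(sw):
        lo, hi = g * group_size, (g + 1) * group_size
        # Integer multiply-accumulate inside the group, then rescale once per group.
        acc = sum(a * b for a, b in zip(qx[lo:hi], qw[lo:hi]))
        total += acc * sx * scale_w
    return total


if __name__ == "__main__":
    token = [0.5, -1.2, 0.3, 2.0, -0.7, 0.1, 1.5, -0.4]
    row = [0.02, -0.05, 0.01, 0.04, -0.025, 0.06, -0.01, 0.02]
    exact = sum(a * b for a, b in zip(token, row))
    approx = quantized_dot(token, row, group_size=4)
    print(f"exact={exact:.5f} approx={approx:.5f}")
```

Per-group weight scales keep one large weight from crushing the precision of every other weight in the row, and per-token activation scales adapt to each input at runtime. The stored 4-bit weights are what shrink the model artifact.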
We again use a single ml.g5.2xlarge instance and use SageMaker warm pool configuration to speed up the spin-up time for the compute nodes:
In the output, we see the location of the quantized model and how much memory we saved due to the process:
You can run model inference on the quantized model meta_model_0-8da4w.pt by updating the inference-specific configurations.
Finally, let’s evaluate our fine-tuned model in an objective manner by running an evaluation on the validation portion of our dataset.
torchtune integrates with EleutherAI’s evaluation harness and provides the eleuther_eval recipe.
For our evaluation, we use a custom evaluation harness task that scores the dialogue summaries with ROUGE metrics.
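The custom task's scoring code is not shown in the post. For intuition, the following sketch computes ROUGE-1, ROUGE-2, and ROUGE-L F1 scores from scratch. It is a simplification: it lowercases and splits on whitespace, with no stemming or punctuation handling, so its numbers will not exactly match a reference implementation such as the rouge_score package.

```python
from collections import Counter
from typing import List


def _f1(overlap: int, cand_total: int, ref_total: int) -> float:
    """Harmonic mean of precision and recall; 0.0 when nothing overlaps."""
    if overlap == 0:
        return 0.0
    precision = overlap / cand_total
    recall = overlap / ref_total
    return 2 * precision * recall / (precision + recall)


def rouge_n(candidate: str, reference: str, n: int) -> float:
    """ROUGE-N F1 from clipped n-gram overlap (lowercased whitespace tokens, no stemming)."""
    def ngrams(tokens: List[str]) -> Counter:
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    cand = ngrams(candidate.lower().split())
    ref = ngrams(reference.lower().split())
    overlap = sum((cand & ref).values())  # & keeps the minimum count of each shared n-gram
    return _f1(overlap, sum(cand.values()), sum(ref.values()))


def rouge_l(candidate: str, reference: str) -> float:
    """ROUGE-L F1 based on the longest common subsequence (LCS) of tokens."""
    c, r = candidate.lower().split(), reference.lower().split()
    # Classic dynamic-programming LCS table.
    dp = [[0] * (len(r) + 1) for _ in range(len(c) + 1)]
    for i, ct in enumerate(c, 1):
        for j, rt in enumerate(r, 1):
            dp[i][j] = dp[i - 1][j - 1] + 1 if ct == rt else max(dp[i - 1][j], dp[i][j - 1])
    return _f1(dp[len(c)][len(r)], len(c), len(r))


if __name__ == "__main__":
    reference = "Amanda baked cookies"
    candidate = "Amanda baked the cookies"
    scores = {
        "rouge1": rouge_n(candidate, reference, 1),
        "rouge2": rouge_n(candidate, reference, 2),
        "rougeL": rouge_l(candidate, reference),
    }
    for name, score in scores.items():
        print(f"{name}: {100 * score:.2f}")  # scaled by 100, as in the evaluation output
```

For this pair, the sketch reports rouge1 85.71, rouge2 40.00, and rougeL 85.71: the extra word "the" barely affects unigram and subsequence overlap but breaks one of the two reference bigrams.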
The recipe configuration points the evaluation harness to our custom evaluation task:
The following code is the SageMaker task that we run on a single ml.p4d.24xlarge instance:
Run the model evaluation on ml.p4d.24xlarge:
The following tables show the task output for the fine-tuned model as well as the base model.
The following output is for the fine-tuned model.
| Tasks | Version | Filter | n-shot | Metric | Direction | Value | ± | Stderr |
|---|---|---|---|---|---|---|---|---|
| samsum | 2 | none | None | rouge1 | ↑ | 45.8661 | ± | N/A |
| | | none | None | rouge2 | ↑ | 23.6071 | ± | N/A |
| | | none | None | rougeL | ↑ | 37.1828 | ± | N/A |
The following output is for the base model.
| Tasks | Version | Filter | n-shot | Metric | Direction | Value | ± | Stderr |
|---|---|---|---|---|---|---|---|---|
| samsum | 2 | none | None | rouge1 | ↑ | 33.6109 | ± | N/A |
| | | none | None | rouge2 | ↑ | 13.0929 | ± | N/A |
| | | none | None | rougeL | ↑ | 26.2371 | ± | N/A |
Our fine-tuned model achieves a ROUGE-1 score of approximately 46 on the summarization task, which is approximately 12 points (about 36% relative) better than the baseline.
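The following short calculation derives the absolute and relative gains for all three metrics from the scores in the evaluation tables above. It adds no new data; it only distinguishes points gained from percentage improvement.

```python
# Scores copied from the evaluation tables above (ROUGE reported on a 0 to 100 scale).
FINE_TUNED = {"rouge1": 45.8661, "rouge2": 23.6071, "rougeL": 37.1828}
BASE = {"rouge1": 33.6109, "rouge2": 13.0929, "rougeL": 26.2371}


def improvements(tuned: dict, base: dict) -> dict:
    """Return (absolute point gain, relative gain) for each metric."""
    return {
        metric: (tuned[metric] - base[metric], (tuned[metric] - base[metric]) / base[metric])
        for metric in base
    }


if __name__ == "__main__":
    for metric, (points, relative) in improvements(FINE_TUNED, BASE).items():
        print(f"{metric}: +{points:.2f} points ({relative:.1%} relative)")
```

This prints gains of +12.26 points (36.5%) for rouge1, +10.51 points (80.3%) for rouge2, and +10.95 points (41.7%) for rougeL.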
Complete the following steps to clean up your resources:
In this post, we discussed how you can fine-tune Meta Llama-like architectures using various fine-tuning strategies on your preferred compute and libraries, using custom dataset prompt templates with torchtune and SageMaker. This architecture gives you a flexible way of running fine-tuning jobs that are optimized for GPU memory and performance. We demonstrated this by fine-tuning a Meta Llama 3.1 model using P4 and G5 instances on SageMaker, and we used observability tools like Weights & Biases to monitor the loss curve as well as GPU utilization and GPU memory utilization.
We encourage you to use SageMaker training capabilities and Meta’s torchtune library to fine-tune Meta Llama-like architectures for your specific business use cases. To stay informed about upcoming releases and new features, refer to the torchtune GitHub repo and the official Amazon SageMaker training documentation.
Special thanks to Kartikay Khandelwal (Software Engineer at Meta), Eli Uriegas (Engineering Manager at Meta), Raj Devnath (Sr. Product Manager Technical at AWS), and Arun Kumar Lokanatha (Sr. ML Solution Architect at AWS) for their support in launching this post.