Single-replica checkpoint loading: load once, broadcast to all
Imagine training a new AI / ML model like Gemma 3 or Llama 3.3 across hundreds of powerful accelerators like TPUs or GPUs to achieve a scientific breakthrough. You might have a team of powerful computers working in sync, constantly learning and refining. But every so often, they need to save their progress — a “checkpoint” — and then pick up from this known state in the case of an interruption.
With the traditional approach, each device independently reads the same checkpoint from central storage such as Google Cloud Storage (GCS), resulting in duplicate data transfers. When a project's GCS bandwidth is fully utilized, this causes significant delays before training even begins. This bottleneck isn't just an inconvenience; it cuts productivity and increases cost (remember that you're paying for all the accelerators that sit idle while the checkpoint is being saved or restored).
Today, we’ll explore how you can deliver efficient checkpoint loading and unlock faster, more cost-efficient, and more impactful AI development.
Google engineers found a way to optimize checkpoint loading with Orbax and JAX so developers can get back to the work that matters. Orbax is a toolkit designed to streamline saving and loading checkpoints for machine learning models, particularly those built with JAX, a Python library for high-performance numerical computing.
The core idea behind our solution in Orbax is simple yet powerful: instead of every device fetching the entire checkpoint, only one replica (copy) downloads it. This replica then broadcasts the checkpoint data to all other replicas in the training setup. The approach leverages high-speed interconnects like DCN (Data Center Network) and inter-chip interconnect (ICI) to facilitate rapid data transfer between machines.
There's a trade-off between the time saved reading data from external storage like GCS and the time spent replicating/broadcasting. Broadcasting requires compiling Python code into low-level instructions for your hardware (like GPUs or TPUs), which takes time. However, when data-reading bandwidth becomes a bottleneck, this compilation time is negligible compared to the time wasted on redundant reads. This is when reading the checkpoint on one replica and then broadcasting it to the other devices significantly speeds up the process. The solution achieves this by quickly reading the checkpoint from GCS on a single replica and efficiently distributing the data to the other devices.
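To make the idea concrete, here is a minimal single-host sketch in JAX. It is not Orbax's actual broadcast code, and the mesh, axis names, and array are purely illustrative: the checkpoint data is first placed on one "replica" sub-mesh, then re-sharded onto the full mesh so every replica ends up with a copy over the device interconnect instead of re-reading it from storage.

```python
import os
# Pretend this host has 8 devices so the sketch runs on CPU.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(2, 4)            # 2 "replicas" x 4 devices each
global_mesh = Mesh(devices, axis_names=("replica", "data"))
replica0_mesh = Mesh(devices[:1], axis_names=("replica", "data"))

params = np.arange(16.0).reshape(8, 2)                     # stand-in for one checkpointed array

# 1) Only replica 0 reads the array from storage, sharded over its own devices.
on_replica0 = jax.device_put(params, NamedSharding(replica0_mesh, P("data", None)))

# 2) Re-shard onto the global mesh: still sharded over "data", now replicated over
#    "replica". The transfer runs over the interconnect rather than from GCS.
on_all_replicas = jax.device_put(on_replica0, NamedSharding(global_mesh, P("data", None)))
print(on_all_replicas.sharding)
```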
To evaluate the efficiency of our optimized checkpoint loading technique, we ran benchmarks on different hardware devices.
We saw a 6.8x speedup on a CPU cluster with 2048 VMs (32 slices of 64 `n2-standard-32` VMs). Note: Due to observed performance fluctuations during CPU testing at this scale, the reported CPU benchmark is an average of several runs.
On a TPU cluster with 13 slices of v5e-256, checkpoint loading completed successfully, whereas the standard approach frequently timed out. On a smaller scale of 5 slices of v5e-256, we observed a speedup of over 2.6x. Results are shown in the table below.
| Device Type | Number of VMs | Model Size | Speedup |
| --- | --- | --- | --- |
| TPU | 5 x v5e-256 = 320 | 80B | 2.6x |
| TPU | 13 x v5e-256 = 832 | 80B | Standard approach frequently failed |
| CPU | 32 x 64 = 2048 | 78B | 6.8x |
Although this optimized checkpoint loading was effective, it required additional High Bandwidth Memory (HBM) for the broadcasting function. In some cases, this meant only a third of the available HBM could be used for the checkpoint data, which led to memory challenges for large-scale checkpoint loading and emergency checkpointing.
To solve this, Orbax got a smart upgrade. Instead of broadcasting the entire checkpoint at once, it now breaks it into smaller, manageable chunks and broadcasts them sequentially. You can either tell Orbax how much memory it’s allowed to use for broadcasting, or it can estimate the HBM available for broadcasting based on the accelerator and model size.
This approach offers two important benefits:
- Flexibility: Users can tailor broadcasting to their specific hardware and model sizes.
- Minimized OOM errors: The solution respects user-defined memory limits, eliminating the risk of out-of-memory (OOM) errors during broadcasting.
The user can optionally specify a `memory_limit` parameter. If not provided, a utility function determines the device type and uses a predefined mapping to estimate the HBM associated with that device. The memory limit is then estimated as `HBM - 2 * pytree_memory_per_device`, incorporating a scaling factor to prevent OOM errors in edge cases.
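As a rough sketch of how such an estimate could look (the helper name and the per-device HBM table below are hypothetical, not Orbax's actual code):

```python
import jax

# Hypothetical per-device HBM table in bytes; real values depend on your accelerators.
HBM_BYTES_BY_PLATFORM = {
    "tpu": 16 * 1024**3,   # e.g. a TPU v5e chip with 16 GB of HBM
    "gpu": 80 * 1024**3,   # e.g. an 80 GB GPU
    "cpu": 32 * 1024**3,   # treat a host RAM budget similarly
}

def estimate_broadcast_memory_limit(pytree_bytes_per_device: int,
                                    memory_limit: int | None = None) -> int:
  """Returns the broadcast memory budget in bytes.

  If the user supplies `memory_limit`, it is used directly; otherwise the budget
  is estimated as HBM - 2 * pytree_memory_per_device, where the factor of 2 acts
  as a scaling factor to guard against OOM in edge cases.
  """
  if memory_limit is not None:
    return memory_limit
  platform = jax.devices()[0].platform          # "tpu", "gpu", or "cpu"
  hbm = HBM_BYTES_BY_PLATFORM[platform]
  return max(hbm - 2 * pytree_bytes_per_device, 0)
```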
This enhancement in Orbax ensures efficient memory utilization during checkpoint loading, particularly for large-scale models and in scenarios where memory constraints are critical. It empowers users to optimize the broadcasting process based on their specific hardware and model requirements, further enhancing the performance and reliability of Orbax’s checkpoint loading capabilities.
To illustrate how to use this optimized checkpoint loading feature, let's refer to an example from MaxText. In this example, you enable the feature by setting `enable_single_replica_ckpt_restoring=True`.
Let's break down this example in reverse order to understand the process. In Orbax, restoring a checkpoint involves providing `restore_args` to the checkpoint manager, for example:
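A restore call of roughly this shape, using Orbax's PyTree restore API; the names `abstract_state`, `mesh`, `pspec`, `checkpoint_dir`, and `step` are illustrative placeholders, and the exact manager setup should follow your existing checkpointing code and Orbax version:

```python
import jax
import orbax.checkpoint as ocp

# Illustrative placeholders: `abstract_state` is a pytree of jax.ShapeDtypeStruct
# describing the train state, `mesh` is the global device mesh, `pspec` is the
# PartitionSpec applied to every array (a real model maps per-parameter specs),
# and `checkpoint_dir` / `step` point at an existing checkpoint.
restore_args = jax.tree_util.tree_map(
    lambda x: ocp.type_handlers.ArrayRestoreArgs(
        sharding=jax.sharding.NamedSharding(mesh, pspec),
        global_shape=x.shape,
        dtype=x.dtype,
    ),
    abstract_state,
)

mngr = ocp.CheckpointManager(checkpoint_dir)
restored_state = mngr.restore(
    step,
    args=ocp.args.PyTreeRestore(item=abstract_state, restore_args=restore_args),
)
```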
Here, `pspec` tells Orbax how to arrange the pieces of your model across all the available devices in your setup (the global mesh). For example, given 5 slices of v5e-256 machines (that's 1,280 TPU v5e chips in total), `pspec` may specify that the 80B-parameter model is sharded across a single v5e-256 slice and then replicated on each of the other slices. A simple example can be found here.
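As a rough illustration of what such a setup can look like (a schematic sketch, not the linked example; axis names and the parameter shape are assumptions):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 5 slices of v5e-256: axis 0 ("replica") indexes the slices, axis 1 ("data")
# indexes the 256 chips within a slice (requires 1,280 devices to run as-is).
devices = np.array(jax.devices()).reshape(5, 256)
mesh = Mesh(devices, axis_names=("replica", "data"))

# Shard each (2-D) parameter over the chips of one slice and replicate it across
# slices: because "replica" does not appear in the spec, that axis is replicated.
pspec = P(None, "data")
sharding = NamedSharding(mesh, pspec)
```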
To restore using the optimized checkpoint loading method, we first need to register the `SingleReplicaArrayHandler`, optionally specifying the memory limit (in bytes) used for broadcasting:
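In MaxText this registration looks roughly like the following; argument names may differ slightly between Orbax releases, and the 1 GiB limit is an illustrative value rather than a recommendation:

```python
import jax
import orbax.checkpoint as ocp

# Route restoration of jax.Array values through the single-replica handler.
# Omit broadcast_memory_limit_bytes to let Orbax estimate it as described above.
ocp.type_handlers.register_type_handler(
    jax.Array,
    ocp.type_handlers.SingleReplicaArrayHandler(
        replica_axis_index=0,                        # currently the only supported value
        broadcast_memory_limit_bytes=1024 * 1024 * 1024,
    ),
    override=True,
)
```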
Next, we need to provide `restore_args` for each pytree element as follows:
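A sketch modeled approximately on MaxText's helper (names shortened; `_replica_devices` is the utility shown in the next snippet, and `abstract_state` is the sharding-annotated abstract train state from the earlier example):

```python
import jax
import orbax.checkpoint as ocp

def map_to_single_replica_restore_args(data):
  # `data` is an abstract array (e.g. jax.ShapeDtypeStruct) annotated with a
  # NamedSharding on the global mesh.
  pspec = data.sharding.spec
  global_mesh = data.sharding.mesh

  # Mesh containing only the devices of the replica this host belongs to.
  replica_devices = _replica_devices(global_mesh.devices, replica_axis_index=0)
  replica_mesh = jax.sharding.Mesh(replica_devices, global_mesh.axis_names)
  single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)

  return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
      sharding=jax.sharding.NamedSharding(global_mesh, pspec),
      single_replica_sharding=single_replica_sharding,
      global_shape=data.shape,
      dtype=data.dtype,
  )

restore_args = jax.tree_util.tree_map(map_to_single_replica_restore_args, abstract_state)
```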
Here, `sharding` specifies the sharding on the global mesh, while `single_replica_sharding` represents the sharding on the replica slice to which the current host belongs. These can be obtained using the following utility functions:
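Approximately as in MaxText (see the MaxText source for the authoritative version):

```python
import jax
import numpy as np

def _find_idx(device_array: np.ndarray, replica_axis_index: int) -> int:
  """Returns the index along the replica axis that the current host belongs to."""
  idx = None
  for idx, device in np.ndenumerate(device_array):
    if device.process_index == jax.process_index():
      break
  return idx[replica_axis_index]

def _replica_devices(device_array: np.ndarray, replica_axis_index: int) -> np.ndarray:
  """Returns the sub-array of devices forming the current host's replica."""
  idx = _find_idx(device_array, replica_axis_index)
  replica_devices = np.take(device_array, idx, axis=replica_axis_index)
  return np.expand_dims(replica_devices, axis=replica_axis_index)
```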
Thus, setting `enable_single_replica_ckpt_restoring=True` triggers the creation and registration of a `SingleReplicaArrayHandler`, which is responsible for restoring the checkpoint. The example also demonstrates how to obtain the necessary `single_replica_sharding` for each host, a crucial requirement for this method.
Important note: currently, the broadcasting functionality works only when `replica_axis_index=0`.
Unlock the full potential of ML models: learn more about Orbax and its features here. Compare checkpoint loading times with the new method to see if it helps. To get started:
1. Follow the instructions to install xpk and MaxText. XPK, the Accelerated Processing Kit, helps developers orchestrate ML workloads on GKE accelerator clusters. MaxText offers highly scalable reference LLM implementations in JAX.
2. Train and save a checkpoint: use your existing setup or adapt the example below.
3. Benchmark by comparing checkpoint loading times with and without `enable_single_replica_ckpt_restoring=true`.
Example command (adjust for your needs):
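Something along these lines, run from the MaxText repository root; flag names follow MaxText's `base.yml`, while the run name, bucket, and model-size values are placeholders to adapt (and verify against your MaxText version):

```bash
python3 MaxText/train.py MaxText/configs/base.yml \
  run_name=single-replica-ckpt-test \
  base_output_directory=gs://YOUR_BUCKET/maxtext-output \
  dataset_type=synthetic \
  base_num_query_heads=64 base_emb_dim=8192 \
  steps=5 checkpoint_period=2 \
  enable_checkpointing=true async_checkpointing=false \
  enable_single_replica_ckpt_restoring=true
```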
In this example, we adjust some parameters like `base_num_query_heads` to increase the model’s size. We set the number of training steps to 5 and the checkpointing frequency to every 2 steps so we can quickly observe the effects of our checkpointing changes. Note that you may need to run this command twice if no checkpoint has been saved yet. To explore the full range of configuration options, refer to the available configuration files or check out the code.
This work is a collaboration between multiple teams within Google. We especially want to thank Colin Gaffney, Rafi Witten, and Yash Katariya for their invaluable contributions. We also extend our sincere appreciation to Roshani Narasimhan, Matt Davidow, Vaibhav Singh, Niranjan Hira, Shivani Matta and Andi Gavrilescu for their guidance and support throughout this project.