Training PyTorch on Cloud TPUs

Vaibhav Singh
Dec 4, 2020

PyTorch/XLA on TPU

You have read that PyTorch/XLA on TPU is now generally available (GA). Perhaps you have run through some of the example Colabs as well. But now you want to understand how to run your own PyTorch model on TPUs. You are in the right place. This article deconstructs a simple example in terms of key concepts. The goal is to equip you with the fundamentals you need to update your own training code to train on TPUs.

Basic Concepts

PyTorch/XLA introduces a new tensor abstraction called the XLA tensor. Any operation performed on an XLA tensor is traced into an IR (Intermediate Representation) graph. In the simplest scenario, tracing continues until it encounters an `xm.mark_step()` call. At that point, the IR graph is converted into the XLA HLO (High Level Optimization) format and sent to the TPU. The TPU runtime compiles the HLO and further optimizes it into the LLO (Low Level Optimized) format, from which the TPU instructions are generated and executed on the TPUs.

Two other scenarios have an effect similar to `mark_step`: 1) a tensor value is accessed on the CPU (e.g. a print statement, an `.item()` call, or an `if (...).any():` condition), and 2) an operation is encountered for which no XLA lowering is available. In both scenarios, the subgraph leading up to the tensor or op must be materialized, i.e. all the steps from conversion to HLO through compilation and execution on the TPU are performed, and the output of the execution is fetched from the TPU to the CPU. When the op has no XLA lowering, the IR graph is truncated at the op: its inputs are computed, the op's CPU implementation is run on them, and the result becomes the input to the portion of the graph that follows. This round trip clearly hurts training performance, so if you hit a missing lowering, it is recommended to open a request for it on the PyTorch/XLA GitHub project.
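To make the tracing behavior concrete, here is a toy, pure-Python sketch of a lazily evaluated tensor. It is an illustration under simplifying assumptions, not how PyTorch/XLA is implemented: `ToyRuntime`, `LazyTensor`, `data` and their methods are hypothetical names. Arithmetic only records IR nodes; `mark_step()` "compiles" a graph (cached by its structure, not its input values) and executes it; and `item()` forces the pending graph to run early, splitting the step into two graphs.

```python
class ToyRuntime:
    """Hypothetical stand-in for a lazy-tensor runtime (not the torch_xla API)."""

    def __init__(self):
        self.cache = set()     # structures of graphs already "compiled"
        self.compiles = 0      # cache misses
        self.executions = 0    # graphs actually executed

    def run(self, tensor):
        if tensor.op == "data":            # already materialized
            return tensor.value
        sig = tensor.signature()
        if sig not in self.cache:          # unseen graph structure: compile once
            self.cache.add(sig)
            self.compiles += 1
        self.executions += 1
        value = tensor.evaluate()
        # The result becomes plain device data that later graphs take as input
        tensor.op, tensor.inputs, tensor.value = "data", (), value
        return value

    def mark_step(self, *tensors):
        for t in tensors:
            self.run(t)


class LazyTensor:
    def __init__(self, rt, op, inputs=(), value=None):
        self.rt, self.op, self.inputs, self.value = rt, op, inputs, value

    def __add__(self, other):
        return LazyTensor(self.rt, "add", (self, other))   # records an IR node only

    def __mul__(self, other):
        return LazyTensor(self.rt, "mul", (self, other))   # records an IR node only

    def signature(self):
        # Structure only: input values are runtime arguments, so a new batch
        # traced through the same ops reuses the cached program
        if self.op == "data":
            return "data"
        return (self.op,) + tuple(i.signature() for i in self.inputs)

    def evaluate(self):
        if self.op == "data":
            return self.value
        a, b = (i.evaluate() for i in self.inputs)
        return a + b if self.op == "add" else a * b

    def item(self):
        # Reading a value on the host forces the pending graph to run now
        return self.rt.run(self)


def data(rt, v):
    return LazyTensor(rt, "data", value=v)


# Same graph every step: compiled once, executed three times
rt = ToyRuntime()
for step in range(3):
    x, w = data(rt, float(step)), data(rt, 2.0)
    y = x * w + x            # traced only; y.value is still None here
    rt.mark_step(y)
print(rt.compiles, rt.executions, y.value)       # 1 3 6.0

# A host read mid-step cuts the trace into two separate graphs
rt2 = ToyRuntime()
a, b = data(rt2, 3.0), data(rt2, 2.0)
prod = a * b
if prod.item() > 5:          # host-side branch, like `if (...).any():` above
    out = prod + a
else:
    out = prod * a
rt2.mark_step(out)
print(rt2.compiles, rt2.executions, out.value)   # 2 2 9.0
```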

Library Elements

Now let's walk through a simple PyTorch program, noting the elements needed to work with Cloud TPUs. We will examine the commonly used PyTorch modules, the model code, the training loop and finally the multiprocessing spawn, commenting on the syntactic additions that enable training to run on TPUs. The changes needed to enable sophisticated tooling such as fairseq may seem more complex than the example presented here. However, the concepts explained here apply equally to any implementation.

The following snippet shows the usual PyTorch modules:

```python
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
```

From PyTorch/XLA we will use the following three modules:

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
```

The xla_model module provides the abstraction for a TPU core and methods for common operations on the core. A TPU core is the smallest hardware unit usable in the programming model.

The xla_multiprocessing module provides methods to set up and run parallel processes on multiple cores. A TPU device consists of 8 TPU cores, and xla_multiprocessing lets you work with either a single TPU core or all 8 cores.

The parallel_loader module provides methods to augment PyTorch dataloaders so that data loading overlaps with execution on the TPU cores.

Please note that the modules mentioned here are the minimum set. For more details on the PyTorch/XLA API, please refer to the API guide.
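The overlap that parallel_loader provides can be pictured with a stdlib-only sketch: a background thread prepares upcoming batches in a bounded queue while the main loop works on the current one. `PrefetchLoader` and the simulated timings are assumptions for illustration; the real loader also moves batches onto the XLA device, which this toy does not model. With 50 ms to load and 50 ms to process each of 5 batches, the serial loop should take roughly 0.5 s and the overlapped one roughly 0.3 s.

```python
import queue
import threading
import time


class PrefetchLoader:
    """Hypothetical sketch: produce batches in a background thread so that
    data preparation overlaps with the consumer's computation."""

    _DONE = object()  # sentinel marking the end of the data

    def __init__(self, batches, prefetch=2):
        self._batches = batches
        self._prefetch = prefetch

    def __iter__(self):
        q = queue.Queue(maxsize=self._prefetch)   # at most `prefetch` batches ahead

        def producer():
            for batch in self._batches:
                q.put(batch)                       # blocks while the queue is full
            q.put(self._DONE)

        threading.Thread(target=producer, daemon=True).start()
        while True:
            batch = q.get()
            if batch is self._DONE:
                return
            yield batch


def slow_batches(n, load_s):
    for i in range(n):
        time.sleep(load_s)    # simulated I/O and preprocessing
        yield i


def consume(loader, step_s):
    seen = []
    for batch in loader:
        time.sleep(step_s)    # simulated computation on the device
        seen.append(batch)
    return seen


t0 = time.perf_counter()
serial = consume(slow_batches(5, 0.05), 0.05)
serial_s = time.perf_counter() - t0

t0 = time.perf_counter()
overlapped = consume(PrefetchLoader(slow_batches(5, 0.05)), 0.05)
overlapped_s = time.perf_counter() - t0

print(serial == overlapped, f"serial {serial_s:.2f}s, overlapped {overlapped_s:.2f}s")
```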

Model Code

The following snippet shows a simple convolutional neural network model:

```python
class ToyModel(nn.Module):
    """ Toy Classifier """
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1440, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.mp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.softmax(x, dim=-1)
        return x
```
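The 1440 input features of fc1 follow from the layer shapes, assuming single-channel 28×28 inputs such as the MNIST images used later in this article. The hypothetical helpers below apply the standard output-size formulas for an unpadded, stride-1 convolution and for max pooling, whose stride defaults to its kernel size:

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # Output size of a convolution along one spatial dimension
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    # nn.MaxPool2d(k) uses stride=k by default (no padding, no dilation)
    return (size - kernel) // kernel + 1


h = w = 28                                    # input: 1 x 28 x 28
h, w = conv2d_out(h, 5), conv2d_out(w, 5)     # after conv1: 10 x 24 x 24
h, w = pool_out(h, 2), pool_out(w, 2)         # after mp1:   10 x 12 x 12
fc1_in = 10 * h * w                           # flattened features per sample
print(fc1_in)  # 1440
```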

Notice that this model has no dynamic elements. In such cases, no changes to the model code are required to work with Cloud TPUs.

A dynamic element in a model is any construct, or group of constructs, that causes the model graph to vary across training steps. Each variation in the model graph triggers an XLA recompilation. For large models these recompilations are expensive and can noticeably slow down training.

Here is an example where the model graph changes frequently. The scale of variation has been made synthetically large to illustrate the concept.

```python
class ToyModel(nn.Module):
    """ Toy Classifier """
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1440, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 10)
        self.layerdrop_prob = 0.5

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.mp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        # Layer drop: each of the 50 iterations keeps fc2 with probability
        # 1 - layerdrop_prob, so the traced graph differs from step to step.
        # torch.rand is created on the CPU, so the branch itself does not
        # force a TPU sync.
        for _ in range(50):
            if torch.rand(1, 1) > self.layerdrop_prob:
                x = self.fc2(x)
        x = self.fc3(x)
        x = torch.softmax(x, dim=-1)
        return x
```

(Layer drop, shown here, is a commonly used technique in models with many transformer layers; it can improve the model's generalization performance [1].)

The aspect to note here is the range of variation. In this example, with a loop size of 50 and a choice to keep or drop each layer application, there are 51 possible graphs because the layers are identical: the graph depends only on how many times fc2 is applied (anywhere from 0 to 50). If the layers were distinct, the space would be 2⁵⁰ possible graphs, which means a potentially new graph on every training step. In such scenarios, reducing the space of variation is recommended.

PyTorch/XLA uses a compilation cache that avoids recompiling graphs it has already seen. If the variation space is limited, this confines the compilation penalty to the first few steps. However, if the variation space is too large, the compilation penalty is not contained.
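The size of the variation space, and why a cache only helps when that space is small, can be checked with a stdlib-only simulation. It assumes, as in the toy model above, that a traced graph is fully determined by which fc2 applications survive layer drop; `graph_key`, `count_graphs` and `count_compilations` are hypothetical helpers, and a dict stands in for the PyTorch/XLA compilation cache. Enumerating 2⁵⁰ masks is infeasible, so the exhaustive count uses 10 layers, where the same reasoning gives 11 versus 2¹⁰ = 1024 graphs.

```python
import itertools
import random


def graph_key(keep_mask, identical_layers):
    # With identical layers only the *number* of fc2 applications shapes the
    # graph; with distinct layers, *which* layers were kept matters too.
    return sum(keep_mask) if identical_layers else tuple(keep_mask)


def count_graphs(n_layers, identical_layers):
    # Exhaustive enumeration over all keep/drop masks (small n_layers only)
    masks = itertools.product((False, True), repeat=n_layers)
    return len({graph_key(m, identical_layers) for m in masks})


def count_compilations(n_layers, identical_layers, steps, drop_prob=0.5, seed=0):
    rng = random.Random(seed)
    cache = {}                      # graph key -> "compiled program"
    for _ in range(steps):
        mask = tuple(rng.random() > drop_prob for _ in range(n_layers))
        key = graph_key(mask, identical_layers)
        if key not in cache:        # cache miss: pay the compilation cost
            cache[key] = "compiled"
    return len(cache)


print(count_graphs(10, True), count_graphs(10, False))   # 11 1024
print(count_compilations(50, True, 1000))    # never more than 51
print(count_compilations(50, False, 1000))   # a new graph on almost every step
```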

There are, of course, other possible sources of dynamism in the model graph. Future articles will explore more of these scenarios and ways to remedy them.

Training Method

A typical training method consists of a device abstraction, model transfer to this device, dataset creation, a dataloader, a random sampler and a training loop (forward and backward passes followed by a parameter update). Let's begin with the device abstraction. On GPUs, the start of the training method might look like this:

```python
def train(rank, FLAGS):
    print("Starting train method on rank: {}".format(rank))
    dist.init_process_group(
        backend='nccl', world_size=FLAGS['world_size'], init_method='env://',
        rank=rank)
    model = ToyModel()
    torch.cuda.set_device(rank)
    model.cuda(rank)
    ...
    ...
```
...

With PyTorch/XLA, this method becomes:

```python
def train(rank, FLAGS):
    device = xm.xla_device()
    rank = xm.get_ordinal()
    print("Starting train method on rank: {}".format(rank))
    model = ToyModel()
    model = model.to(device)
    ...
    ...
```

Here xm refers to torch_xla.core.xla_model, described above, and xm.xla_device() provides the device abstraction. The rank variable is not needed to transfer the model to the device; it is obtained here (through the xm.get_ordinal() method) only to illustrate how to get it for any other use in your training code.

The element examined next is the dataloader. In a typical PyTorch training method, it might be used as follows:

```python
import torchvision
from torchvision import transforms

def train(rank, FLAGS):
    ...
    ...
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]
    )
    train_dataset = torchvision.datasets.MNIST(
        '/tmp/', train=True, download=True, transform=transform)
    # Each replica gets its own shard of the dataset
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=FLAGS['world_size'], rank=rank)
    # shuffle must stay False when a sampler is passed; the sampler decides the order
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=FLAGS['batch_size'], shuffle=False,
        num_workers=0, sampler=train_sampler)
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            ...
            ...
```
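The DistributedSampler in this snippet is what gives each replica its own slice of the dataset. The stdlib sketch below mimics, as we understand it, the sampler's behavior with shuffle=False and drop_last=False: pad the index list by wrapping around until its length divides evenly by the number of replicas, then give replica `rank` every num_replicas-th index starting at `rank`. Note that DistributedSampler shuffles by default; `shard_indices` is a hypothetical helper that leaves shuffling out.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    per_replica = math.ceil(dataset_len / num_replicas)
    total = per_replica * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every replica gets equal work
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    # Strided split: replica `rank` takes rank, rank + num_replicas, ...
    return indices[rank:total:num_replicas]


shards = [shard_indices(10, 4, r) for r in range(4)]
print(shards)                          # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
print(len(shard_indices(60000, 8, 0)))  # 7500 MNIST training samples per core
```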

The PyTorch/XLA `parallel_loader` module, introduced earlier, provides a wrapper around the PyTorch DataLoader that optimizes the data pipeline by overlapping data loading with computation on the device. The following snippet illustrates its usage:

```python
import torchvision
from torchvision import transforms

def train(rank, FLAGS):
    ...
    ...
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]
    )
    train_dataset = torchvision.datasets.MNIST(
        '/tmp/', train=True, download=True, transform=transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=FLAGS['world_size'], rank=rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=FLAGS['batch_size'], shuffle=False,
        num_workers=0, sampler=train_sampler)
    # Wrap the loader so batches are loaded and sent to the XLA device in the background
    train_loader = pl.MpDeviceLoader(train_loader, device)
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            ...
            ...
```

Here `pl` refers to the torch_xla.distributed.parallel_loader module, and `device` refers to the device returned by xm.xla_device(); both were introduced earlier.

The final element of a typical training method is the training loop, where we iterate for the required number of updates over the dataloader. Within this loop, we perform the forward and backward passes and the parameter update. Here is an example snippet:

```python
def train(rank, FLAGS):
    ...
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            # Parameter update
            optimizer.step()
```

With PyTorch/XLA data parallel training, as on GPUs, the training method is executed on each core, each with its own replica of the model, and the forward and backward passes run on that replica. At the end of the backward pass, an ALL_REDUCE operation is performed across cores to combine the gradients before the parameter update. Here is an illustration:

```python
def train(rank, FLAGS):
    ...
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            # Reduce gradients across cores (ALL_REDUCE)
            xm.reduce_gradients(optimizer)
            # Parameter update
            optimizer.step()
```

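xm (xla_model) also provides an optimizer_step method that executes the last two steps, gradient reduction and parameter update, in one call: `xm.optimizer_step(optimizer)`. Both methods accept an optional groups argument to perform the reduction over custom groups of cores (often used in model parallel settings). Note that no explicit `xm.mark_step()` appears in the loop: when batches come from the MpDeviceLoader, the loader issues the step marker each time it yields a new batch.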
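The effect of the gradient reduction can be seen in a stdlib-only simulation of data parallel SGD on a one-parameter model. Each simulated core computes a gradient on its own shard; `all_reduce_mean` plays the role of the cross-core ALL_REDUCE (to our understanding, xm.reduce_gradients sums the gradients and scales them by 1/world_size, i.e. averages them). The data, model and helper names are assumptions for illustration. With the reduction, the replicas stay bit-identical and converge to the true slope; without it, they drift apart after a handful of steps.

```python
def grad(w, shard):
    # d/dw of the mean squared error of the model y = w * x on one shard
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    # After the reduction every replica holds the same averaged value
    avg = sum(values) / len(values)
    return [avg] * len(values)


def data_parallel_sgd(shards, steps, lr, reduce_gradients):
    weights = [0.0] * len(shards)                     # identical initial replicas
    for _ in range(steps):
        grads = [grad(w, s) for w, s in zip(weights, shards)]   # per-core backward
        if reduce_gradients:
            grads = all_reduce_mean(grads)            # role of xm.reduce_gradients
        weights = [w - lr * g for w, g in zip(weights, grads)]  # optimizer.step()
    return weights


# Samples from y = 3x, split across 4 simulated cores
samples = [(i / 10, 3 * i / 10) for i in range(1, 17)]
shards = [samples[r::4] for r in range(4)]

synced = data_parallel_sgd(shards, steps=200, lr=0.1, reduce_gradients=True)
unsynced = data_parallel_sgd(shards, steps=5, lr=0.1, reduce_gradients=False)
print(synced)     # four identical values, all approximately 3.0
print(unsynced)   # four different values: the replicas have drifted apart
```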

Multiprocessing Spawn

To execute data parallel training, PyTorch/XLA provides a spawn method, a wrapper around PyTorch multiprocessing spawn. The method performs the necessary TPU device/mesh configuration before calling PyTorch multiprocessing spawn. The following snippet shows the torch multiprocessing spawn call:

```python
mp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,))
```

With the PyTorch/XLA multiprocessing module, this becomes:

```python
xmp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,), start_method='fork')
```

Notice the start_method='fork'. It is not a PyTorch/XLA requirement, but it is shown here as a good practice: with fork, the spawned processes can access the memory of the parent process. Particularly for larger models, this helps avoid keeping a redundant replica of the model in memory for every process, provided the model is created in the parent before spawning.
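The memory-sharing effect of fork can be reproduced with the standard library alone, assuming a Unix-like OS where the 'fork' start method is available. Here `BIG_TABLE` is a hypothetical stand-in for model weights created in the parent before the workers start; each forked child reads it directly from memory inherited from the parent, so nothing is pickled and sent to it. Under 'spawn', each child would start a fresh interpreter and rebuild its own copy.

```python
import multiprocessing as mp

# Created once in the parent, before any worker starts (stands in for model weights)
BIG_TABLE = list(range(100_000))


def worker(rank, world_size, out):
    # With 'fork', the child sees BIG_TABLE in memory inherited from the parent
    shard = BIG_TABLE[rank::world_size]
    out.put((rank, sum(shard)))


def run_workers(world_size=4):
    ctx = mp.get_context("fork")              # unavailable on Windows
    out = ctx.Queue()
    procs = [ctx.Process(target=worker, args=(r, world_size, out))
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = dict(out.get() for _ in range(world_size))   # drain before joining
    for p in procs:
        p.join()
    return results


partial_sums = run_workers()
print(sum(partial_sums.values()) == sum(BIG_TABLE))  # True
```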

Saving Checkpoints

Saving checkpoints with PyTorch/XLA is, in principle, similar to saving them on other devices: the tensors/data to be saved in the checkpoint need to be transferred from the device to the CPU before they can be written into the checkpoint file.

PyTorch/XLA provides a method to ease this process: `xm.save`. This method automatically recognizes whether it is called from the master process. It transfers the device data to the CPU in all processes, but only the master process saves the data using torch.save. It also provides the necessary synchronization across TPU cores. Because xm.save synchronizes internally, in a distributed setting it must be called in every process. Otherwise, the master core waits for a synchronization that was never triggered on the other cores, and training gets stuck.
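The synchronization requirement can be reproduced with threads standing in for the TPU processes. `toy_save` is a hypothetical function that only mirrors the behavior described above (every rank copies its data to the host, rank 0 alone writes the file, then all ranks meet at a rendezvous); it is not the torch_xla implementation. When one rank skips the call, the others block at the rendezvous; a timeout on the barrier turns what would be a hang into a visible error.

```python
import json
import os
import tempfile
import threading


def toy_save(rank, state, path, rendezvous):
    cpu_state = {k: [float(v) for v in vals] for k, vals in state.items()}  # "device -> CPU"
    if rank == 0:                      # only the master writes the file
        with open(path, "w") as f:
            json.dump(cpu_state, f)
    rendezvous.wait()                  # every rank must reach this point


def run(world_size, skip_rank=None, timeout=2.0):
    rendezvous = threading.Barrier(world_size, timeout=timeout)
    path = os.path.join(tempfile.mkdtemp(), "checkpoint.json")
    stuck = []

    def process(rank):
        if rank == skip_rank:
            return                     # this rank never calls toy_save
        try:
            toy_save(rank, {"weight": [1, 2, 3]}, path, rendezvous)
        except threading.BrokenBarrierError:
            stuck.append(rank)         # would hang forever without the timeout

    threads = [threading.Thread(target=process, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return os.path.exists(path), sorted(stuck)


print(run(4))                             # (True, [])
print(run(4, skip_rank=2, timeout=0.5))   # (True, [0, 1, 3])
```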

Refer to this example for more details on the model saving and loading process for PyTorch/XLA.

Conclusion and Next Steps

This article summarizes the PyTorch/XLA constructs that help you update your model and training code to run on Cloud TPUs. It can be viewed as a written companion to the first part of my talk at Cloud Next. Here is the full code of the snippets shared in the preceding sections. For a review of PyTorch/XLA internals, the reader is encouraged to watch this talk from PyTorch Developer Day 2020. In the next article we will dive into performance debugging concepts and tools.

Resources

The following table summarizes a few good resources for PyTorch on Cloud TPUs:

Resources to find out more details

Acknowledgments

A big thank you to my outstanding colleagues Daniel Sohn, Isaack Karanja, Jack Cao and Taylan Bilal for their generous help with the content and review, and to the open source community of PyTorch/XLA for making this amazing project possible. Special thanks to Zak Stone and Joe Spisak for their leadership and encouragement.


Vaibhav Singh

Machine Learning Engineer @Google, Hardware Design Staff Engineer in past life.