Training PyTorch on Cloud TPUs

PyTorch/XLA on TPU

You have read that PyTorch/XLA on TPU is now generally available (GA). Perhaps you have run through some example Colabs as well. But you want to understand how to run your own PyTorch model on TPUs. You are in the right place. This article deconstructs a simple example in terms of its key concepts. The goal is to equip you with the fundamentals needed to update your own training code to train on TPUs.

Basic Concepts

PyTorch/XLA introduces a new tensor abstraction called the XLA tensor. Any operation performed on an XLA tensor is traced into an IR (Intermediate Representation) graph. In the simplest scenario, the tracing continues until it encounters an `xm.mark_step()` call. At this stage the IR graph is converted into the XLA HLO (High Level Operations) format and sent to the TPU. The HLO graph is compiled by the TPU runtime and further lowered into LLO (Low Level Optimizer) format, from which the TPU instructions are generated and executed on the TPUs.
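As a concrete illustration, the following minimal sketch (assuming a Cloud TPU environment with `torch_xla` installed) shows this lazy execution model: the tensor operations are only recorded into the IR graph, and the graph is compiled and executed on the TPU when `xm.mark_step()` is called.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()                 # acquire an XLA (TPU) device
a = torch.randn(2, 2, device=device)
b = torch.randn(2, 2, device=device)
c = a + b                                # traced into the IR graph, not executed yet
xm.mark_step()                           # IR graph -> HLO -> compiled and run on the TPU
print(c)                                 # fetches the materialized result back to the host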

Library Elements

Now let’s walk through a simple piece of PyTorch code and observe the elements needed to work with Cloud TPUs. We will examine the commonly used PyTorch modules, the model code, the training loop, and finally the multiprocessing spawn, and comment on the syntactical additions which enable the training to run on TPUs. The changes needed to enable sophisticated toolkits such as fairseq may seem more complex than the example presented here; however, the concepts explained here map one-to-one onto any implementation.

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
import torch_xla.core.xla_model as xm                    # core XLA device and utility APIs
import torch_xla.distributed.xla_multiprocessing as xmp  # multiprocessing spawn for TPU cores
import torch_xla.distributed.parallel_loader as pl       # device data loader wrappers

Model Code

The following snippets show two versions of a simple convolutional neural network model: a basic classifier, followed by a deeper variant that randomly drops repeated layers during the forward pass. Note that the model code itself requires no TPU-specific changes:

class ToyModel(nn.Module):
    """ Toy Classifier """

    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1440, 10)

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = self.mp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.Softmax(dim=-1)(x)
        return x
class ToyModel(nn.Module):
    """ Toy Classifier """

    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1440, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 10)
        self.layerdrop_prob = 0.5

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = self.mp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        # Randomly drop repeated fc2 layers during the forward pass
        for i in range(50):
            if torch.rand(1, 1) > self.layerdrop_prob:
                x = self.fc2(x)
        x = self.fc3(x)
        x = nn.Softmax(dim=-1)(x)
        return x

Training Method

A typical training method consists of a device abstraction, transferring the model to that device, dataset creation, a dataloader with a distributed sampler, and a training loop (forward and backward pass followed by a parameter update). In each of the snippets below, the first version shows a typical GPU (distributed data parallel) implementation and the second shows its PyTorch/XLA equivalent. Let’s begin with the device abstraction:

# GPU (DistributedDataParallel) version:
def train(rank, FLAGS):
    print("Starting train method on rank: {}".format(rank))
    dist.init_process_group(
        backend='nccl', world_size=FLAGS['world_size'], init_method='env://',
        rank=rank)
    model = ToyModel()
    torch.cuda.set_device(rank)
    model.cuda(rank)
    ...
    ...
# PyTorch/XLA (TPU) version:
def train(rank, FLAGS):
    device = xm.xla_device()
    rank = xm.get_ordinal()
    print("Starting train method on rank: {}".format(rank))
    model = ToyModel()
    model = model.to(device)
    ...
    ...
# GPU version:
def train(rank, FLAGS):
    ...
    ...
    transform = transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ]
    )
    train_dataset = torchvision.datasets.MNIST(
        '/tmp/', train=True, download=True, transform=transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=FLAGS['world_size'], rank=rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=FLAGS['batch_size'], shuffle=False,
        num_workers=0, sampler=train_sampler)
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            ...
            ...
# PyTorch/XLA (TPU) version:
def train(rank, FLAGS):
    ...
    ...
    transform = transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ]
    )
    train_dataset = torchvision.datasets.MNIST(
        '/tmp/', train=True, download=True, transform=transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=FLAGS['world_size'], rank=rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=FLAGS['batch_size'], shuffle=False,
        num_workers=0, sampler=train_sampler)
    # Wrap the loader so batches are preloaded onto the TPU device
    train_loader = pl.MpDeviceLoader(train_loader, device)
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            ...
            ...
# GPU version:
def train(rank, FLAGS):
    ...
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            # Parameter update
            optimizer.step()
# PyTorch/XLA (TPU) version:
def train(rank, FLAGS):
    ...
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            # Reduce (all-reduce) gradients across TPU cores
            xm.reduce_gradients(optimizer)
            # Parameter update
            optimizer.step()
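As a side note, PyTorch/XLA also provides `xm.optimizer_step()`, which performs the gradient all-reduce and the parameter update in a single call; the last two calls of the inner loop above could be replaced with:

# Equivalent to xm.reduce_gradients(optimizer) followed by optimizer.step()
xm.optimizer_step(optimizer)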

Multiprocessing Spawn

In order to execute data parallel training, PyTorch/XLA provides a spawn method, which is a wrapper around the PyTorch multiprocessing spawn. The method performs the necessary TPU device/mesh configuration before calling the PyTorch multiprocessing spawn. The following snippets show the regular PyTorch spawn call and its PyTorch/XLA counterpart:

# GPU: torch.multiprocessing spawn
mp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,))

# TPU: PyTorch/XLA multiprocessing spawn
xmp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,), start_method='fork')
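For completeness, here is a minimal sketch of an entry point that ties the pieces together. The FLAGS values below are illustrative assumptions, not values taken from the example above:

if __name__ == '__main__':
    # Hypothetical hyperparameters, for illustration only
    FLAGS = {
        'world_size': 8,     # e.g. the 8 cores of a v3-8 Cloud TPU
        'batch_size': 128,
        'epochs': 10,
    }
    # Each spawned process receives its index as the first argument (rank)
    xmp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,), start_method='fork')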

Saving Checkpoints

Saving checkpoints with PyTorch/XLA is, in principle, similar to the same operation on other devices. The tensors/data to be saved in the checkpoint need to be transferred from the device to the CPU before they can be written to the checkpoint file.
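For example, the `xm.save()` helper moves the data to the CPU and, by default, writes the checkpoint only from the master ordinal. A minimal sketch (the checkpoint path is illustrative):

# Transfers the state dict to CPU and saves it from the master ordinal only
xm.save(model.state_dict(), '/tmp/toy_model_checkpoint.pt')

# Later, load as usual and move the model back to an XLA device
state_dict = torch.load('/tmp/toy_model_checkpoint.pt')
model.load_state_dict(state_dict)
model = model.to(xm.xla_device())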

Conclusion and Next Steps

This article attempts to summarize the PyTorch/XLA constructs that help you update your model and training code to run on Cloud TPUs. It can be viewed as a written companion to the first part of my talk at Cloud Next. Here is the full code of the snippets shared in the preceding sections. For a review of PyTorch/XLA internals, the reader is encouraged to watch this talk from PyTorch Developer Day 2020. In the next article we will dive into performance debugging concepts and tools.

Resources

The following table summarizes a few good resources to find out more details about PyTorch on Cloud TPUs.

Acknowledgments

A big thank you to my outstanding colleagues Daniel Sohn, Isaack Karanja, Jack Cao and Taylan Bilal for their generous help with the content and review, and to the open source community of PyTorch/XLA for making this amazing project possible. Special thanks to Zak Stone and Joe Spisak for their leadership and encouragement.
