Training PyTorch on Cloud TPUs

PyTorch/XLA on TPU

Basic Concepts

Other scenarios with effects similar to mark_step are: 1) a tensor value is accessed on the CPU (e.g. a print statement, a .item() call, or an if (..).any() statement); 2) an operation is encountered for which no XLA lowering exists. In both scenarios the subgraph leading to the input of the said op/tensor must be materialized, i.e. all the steps from conversion to HLO, through compilation, to execution on the TPU are performed, and the output of the execution is fetched from the TPU to the CPU. When an op has no XLA lowering, the IR graph is truncated at that op: the inputs leading to the op are computed, and the CPU implementation of the op is used to produce the inputs for the truncated portion of the graph. Evidently, this has a negative impact on training performance, so it is recommended to file a request for the missing XLA lowering on the PyTorch/XLA GitHub project.
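
To make this concrete, here is a minimal sketch (the tensor shapes and the print call are purely illustrative) of how accessing a value on the CPU forces the pending graph to be materialized:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

a = torch.randn(2, 2, device=device)
b = torch.randn(2, 2, device=device)
c = a @ b                 # recorded in the IR graph, nothing has executed yet

# Accessing the value on the CPU materializes the subgraph leading to `c`:
# it is lowered to HLO, compiled, executed on the TPU, and the result is
# fetched back to the host.
print(c.sum().item())

# Alternatively, xm.mark_step() cuts and executes the pending graph without
# transferring any values back to the CPU.
xm.mark_step()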

Library Elements

The following snippet shows the usual PyTorch imports:

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist

From PyTorch XLA we will use the following three modules:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

The xla_model module provides the abstraction for a TPU core and methods for common operations on the core. A TPU core is the smallest hardware unit usable in the programming model.

The xla_multiprocessing module provides methods to set up and execute parallel operations on multiple cores. A TPU device consists of 8 TPU cores, and xla_multiprocessing lets you work with either a single TPU core or all 8 cores.

The parallel_loader module provides methods to augment PyTorch dataloaders so that data-loading operations overlap with the execution on the TPU cores in the data pipeline.

Please note that the modules mentioned here are the minimum set. For more details on the PyTorch/XLA API, please refer to the API guide.

Model Code

class ToyModel(nn.Module):
    """ Toy Classifier """
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1440, 10)

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = self.mp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.Softmax(dim=-1)(x)
        return x

In this example, you will notice that there are no dynamic elements in the model. In such scenarios, no changes to the model code are required to work with Cloud TPUs.

A dynamic element in the model is any construct, or group thereof, that causes the model graph to vary across training steps. A variation in the model graph triggers XLA recompilation. For large-scale models these recompilations are expensive and may result in less than optimal training speed.

Here is an example where the model graph changes frequently. The scale of variation has been made synthetically large to illustrate the concept.

class ToyModel(nn.Module):
    """ Toy Classifier """
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1440, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 10)
        self.layerdrop_prob = 0.5

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = self.mp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        for i in range(50):
            if torch.rand(1, 1) > self.layerdrop_prob:
                x = self.fc2(x)
        x = self.fc3(x)
        x = nn.Softmax(dim=-1)(x)
        return x

(*Layer drop, shown here, is a commonly used technique in models with multiple transformer layers. Understandably it may help the generalization performance of the model.[1])

The aspect to note here is the range of variation. In this example, with a loop size of 50 and a choice to keep or drop each layer, and since the layers are identical, the graph depends only on how many layers are kept, giving on the order of 50 distinct possible graphs. If the layers were distinct, the space would be 2⁵⁰ possible graphs. This means a potentially new graph at every training step. In such scenarios, reducing the space of variation is recommended.

PyTorch/XLA uses a compilation cache which prevents recompilation of graphs with shapes seen earlier. If the variation space is limited, this confines the compilation penalty to the initial few steps. However, if the variation space is too large, the compilation penalty is not contained.
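
One possible remedy, sketched below under the assumption that blending a layer's output with its input is acceptable for the model, is to replace the Python-level branch with tensor arithmetic that stays on the device; the traced graph is then identical at every step:

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = self.mp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        for i in range(50):
            # Instead of a Python `if` (which reads the random value back to
            # the CPU and changes the traced graph), compute a 0/1 mask on the
            # device and blend the layer output with its input.
            keep = (torch.rand(1, device=x.device) > self.layerdrop_prob).float()
            x = keep * self.fc2(x) + (1.0 - keep) * x
        x = self.fc3(x)
        x = nn.Softmax(dim=-1)(x)
        return x

Since the graph seen by XLA now has a fixed shape, the compilation cache is hit after the first few steps.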

There are, of course, more possible scenarios of dynamism in the model graph. In future articles we will explore more of these scenarios and ways to remedy them.

Training Method

In a typical distributed GPU training setup, the training method begins by initializing the process group and moving the model replica to the GPU assigned to this rank:

def train(rank, FLAGS):
    print("Starting train method on rank: {}".format(rank))
    dist.init_process_group(
        backend='nccl', world_size=FLAGS['world_size'],
        init_method='env://', rank=rank)
    model = ToyModel()
    torch.cuda.set_device(rank)
    model.cuda(rank)
    ...
    ...

With PyTorch/XLA this becomes:

def train(rank, FLAGS):
    device = xm.xla_device()
    rank = xm.get_ordinal()
    print("Starting train method on rank: {}".format(rank))
    model = ToyModel()
    model = model.to(device)
    ...
    ...

xm in this snippet refers to torch_xla.core.xla_model (described here). xm.xla_device() provides the device abstraction. The rank variable is not required to transfer the model to the device; it is illustrated here (via the xm.get_ordinal() method) for any potential use in your training code.
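
As an aside (the logging below is only an illustration; the three helpers come from xla_model), the ordinal is commonly used to restrict side effects such as logging or checkpointing to a single core:

rank = xm.get_ordinal()             # index of this core among all cores
world_size = xm.xrt_world_size()    # total number of participating cores
if xm.is_master_ordinal():
    print("Training on {} cores".format(world_size))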

The element examined next is the dataloader. In a typical PyTorch training method it might be used as follows:

def train(rank, FLAGS):
    ...
    ...
    transform = transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ]
    )
    train_dataset = torchvision.datasets.MNIST(
        '/tmp/', train=True, download=True, transform=transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=FLAGS['world_size'], rank=rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=FLAGS['batch_size'], shuffle=False,
        num_workers=0, sampler=train_sampler)
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            ...
            ...

The PyTorch/XLA `parallel_loader` module introduced earlier provides a wrapper around the PyTorch DataLoader that optimizes the data pipeline by overlapping data loading with computation on the device. The following snippet illustrates its usage:

def train(rank, FLAGS):
    ...
    ...
    transform = transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ]
    )
    train_dataset = torchvision.datasets.MNIST(
        '/tmp/', train=True, download=True, transform=transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=FLAGS['world_size'], rank=rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=FLAGS['batch_size'], shuffle=False,
        num_workers=0, sampler=train_sampler)
    train_loader = pl.MpDeviceLoader(train_loader, device)
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            ...
            ...

Here `pl` refers to the torch_xla.distributed.parallel_loader module introduced earlier, and `device` refers to the xm.xla_device() introduced above.
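
As an additional note, MpDeviceLoader is essentially a convenience wrapper around ParallelLoader; a sketch of the equivalent, more explicit form (following the per-epoch pattern used in the official examples) looks like this:

for epoch in range(FLAGS['epochs']):
    # ParallelLoader prefetches batches to the device in background threads;
    # per_device_loader yields batches already placed on the given device.
    para_loader = pl.ParallelLoader(train_loader, [device])
    for i, (images, labels) in enumerate(para_loader.per_device_loader(device)):
        ...
        ...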

The final element of a typical training method is the training loop, where we iterate for the required number of updates over the dataloader. Within this loop, we perform the forward and backward passes and the parameter update. Here is an example snippet:

def train(rank, FLAGS):
    ...
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward
            optimizer.zero_grad()
            loss.backward()
            # Parameter Update
            optimizer.step()

With PyTorch/XLA, data-parallel training works much as it does on GPUs: the training method is executed on each core, on a replica of the model. The forward and backward passes run on the model replica, and at the end of the backward pass an ALL_REDUCE operation is performed across the cores before the parameter update. Here is an illustration:

def train(rank, FLAGS):
    ...
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward
            optimizer.zero_grad()
            loss.backward()
            # Reduce gradients
            xm.reduce_gradients(optimizer)
            # Parameter Update
            optimizer.step()

xm (xla_model) also provides an optimizer_step method which performs the last two steps in a single call. These methods can also be used to perform reduction operations over custom core groups (often used in model-parallel settings).
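
A sketch of that shorthand (the surrounding training loop is elided) looks like:

optimizer.zero_grad()
loss.backward()
# ALL_REDUCE of gradients across cores plus optimizer.step() in one call
xm.optimizer_step(optimizer)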

Multiprocessing Spawn

In a typical GPU data-parallel setup, the training method is launched with torch.multiprocessing:

mp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,))

With the PyTorch/XLA multiprocessing module, this becomes:

xmp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,), start_method='fork')

Notice the start_method='fork'. It is not required by PyTorch/XLA itself, but it is shown here as a good practice: it enables the spawned processes to access the memory of the parent process. Particularly for larger models, this helps avoid keeping redundant replicas of the model in memory.
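
Related to this, torch_xla provides xmp.MpModelWrapper to keep a single host-side copy of the model weights shared across the forked processes; a minimal sketch (the WRAPPED_MODEL global is only for illustration) is:

# Instantiate the model once in the parent process, before spawning, so the
# forked child processes share the host memory holding the weights.
WRAPPED_MODEL = xmp.MpModelWrapper(ToyModel())

def train(rank, FLAGS):
    device = xm.xla_device()
    # Moves the shared host copy of the weights onto this process's TPU core.
    model = WRAPPED_MODEL.to(device)
    ...

xmp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,), start_method='fork')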

Saving Checkpoints

PyTorch/XLA provides a method to ease this process: `xm.save`. This method automatically recognizes whether it is called from the master process. It moves the device data to the CPU in all processes, but only the master process saves the data with torch.save. It also performs the necessary synchronization across TPU cores. Since xm.save performs this synchronization internally, in a distributed setting it must be called from every process; not doing so causes the master core to wait for a synchronization that was never triggered on the other cores, and training gets stuck.
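
A minimal sketch of this pattern inside the training method (the checkpoint path is illustrative):

# Called on every process; only the master ordinal actually writes the file.
xm.save(model.state_dict(), '/tmp/toy_model_checkpoint.pt')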

Refer to this example for more details on model saving and loading process for PyTorch/XLA.

Conclusion and Next Steps

Resources

Resources to find out more details

Acknowledgments

