Training PyTorch on Cloud TPUs

PyTorch/XLA on TPU

Basic Concepts

Library Elements

import torch
import torch.nn as nn
import torchvision  # datasets and transforms used in the data-loading snippets below

# GPU path: PyTorch multiprocessing and distributed (DDP) primitives
import torch.multiprocessing as mp
import torch.distributed as dist

# TPU path: PyTorch/XLA device handling, multiprocessing, and data loading
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

Model Code

A simple convolutional classifier:

class ToyModel(nn.Module):
    """ Toy Classifier """
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.mp1 = nn.MaxPool2d(2)
        # 1440 = 10 channels * 12 * 12 after the conv and pooling layers on 28x28 MNIST input
        self.fc1 = nn.Linear(1440, 10)

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = self.mp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.Softmax(dim=-1)(x)
        return x

A second variant that randomly skips applications of a fully connected layer (layer-drop), so the network's structure changes from step to step:

class ToyModel(nn.Module):
    """ Toy Classifier """
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1440, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 10)
        self.layerdrop_prob = 0.5

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = self.mp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        # Data-dependent control flow: each step can trace a different graph,
        # which may force XLA to recompile on TPUs.
        for i in range(50):
            if torch.rand(1, 1) > self.layerdrop_prob:
                x = self.fc2(x)
        x = self.fc3(x)
        x = nn.Softmax(dim=-1)(x)
        return x

Training Method

On GPUs, the train method initializes an NCCL process group and moves the model to the local CUDA device:

def train(rank, FLAGS):
    print("Starting train method on rank: {}".format(rank))
    dist.init_process_group(
        backend='nccl', world_size=FLAGS['world_size'], init_method='env://',
        rank=rank)
    model = ToyModel()
    torch.cuda.set_device(rank)
    model.cuda(rank)
    ...
    ...

On TPUs, the same method acquires the XLA device and its ordinal (rank) from xm and moves the model to that device:

def train(rank, FLAGS):
    device = xm.xla_device()
    rank = xm.get_ordinal()
    print("Starting train method on rank: {}".format(rank))
    model = ToyModel()
    model = model.to(device)
    ...
    ...

The GPU data pipeline shards MNIST across processes with a DistributedSampler:

def train(rank, FLAGS):
    ...
    ...
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ]
    )
    train_dataset = torchvision.datasets.MNIST(
        '/tmp/', train=True, download=True, transform=transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=FLAGS['world_size'], rank=rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=FLAGS['batch_size'], shuffle=False,
        num_workers=0, sampler=train_sampler)
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            ...
            ...

The TPU version builds the same dataset, sampler, and DataLoader, and then wraps the loader in pl.MpDeviceLoader so that each batch is moved to the TPU device:

def train(rank, FLAGS):
    ...
    ...
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ]
    )
    train_dataset = torchvision.datasets.MNIST(
        '/tmp/', train=True, download=True, transform=transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=FLAGS['world_size'], rank=rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=FLAGS['batch_size'], shuffle=False,
        num_workers=0, sampler=train_sampler)
    train_loader = pl.MpDeviceLoader(train_loader, device)
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            ...
            ...

The GPU training loop is the usual forward / backward / update sequence:

def train(rank, FLAGS):
    ...
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward
            optimizer.zero_grad()
            loss.backward()
            # Parameter Update
            optimizer.step()

On TPUs, gradients are all-reduced across cores with xm.reduce_gradients before the parameter update:

def train(rank, FLAGS):
    ...
    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward
            optimizer.zero_grad()
            loss.backward()
            # Reduce gradients across TPU cores
            xm.reduce_gradients(optimizer)
            # Parameter Update
            optimizer.step()
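
torch_xla also provides xm.optimizer_step, which combines the gradient reduction and the parameter update in a single call. A minimal sketch of the same inner loop using it, reusing the names from the snippets above:

    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward
            optimizer.zero_grad()
            loss.backward()
            # All-reduce gradients and apply the update in one call
            xm.optimizer_step(optimizer)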

Multiprocessing Spawn

On GPUs, training is launched with torch.multiprocessing:

mp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,))

On TPUs, xmp.spawn plays the same role (note the 'fork' start method):

xmp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,), start_method='fork')
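
For context, a minimal sketch of the surrounding driver code. The FLAGS values below are illustrative assumptions; only the keys (world_size, batch_size, epochs) come from the snippets above:

# Illustrative values; world_size=8 assumes a single TPU v2-8/v3-8 host.
FLAGS = {
    'world_size': 8,
    'batch_size': 128,
    'epochs': 10,
}

if __name__ == '__main__':
    xmp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,), start_method='fork')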

Saving Checkpoints
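
On TPUs, checkpoints are typically written with xm.save, which moves XLA tensors to the CPU and by default writes only from the master ordinal, so the replicas do not race on the same file. A minimal sketch; the checkpoint path and the save_checkpoint helper are illustrative assumptions:

def save_checkpoint(model, optimizer, epoch):
    # xm.save gathers tensors to CPU and writes only on the master ordinal.
    xm.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        '/tmp/toy_model_checkpoint.pt')  # illustrative path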

Conclusion and Next Steps

Resources

Resources to find out more details

Acknowledgments
