Understanding the Performance: PyTorch on Cloud TPUs


You are familiar with PyTorch/XLA. You have tested some example code and it works. Encouraged by the quick win, you set out to train your own model. Your model trains, but it is slower than expected, and you want to find out why. This article shares some of the tools that can help you find out why and derive actionable steps to improve training performance.


In the remainder of this article we will walk through a performance debugging case study. A sound understanding of the inner workings of XLA tensors will make the following content more accessible. We encourage the reader to review this talk from PyTorch Developer Day 2020 and this talk from Google Cloud Next for a quick primer. You may also find this article helpful if you are new to PyTorch/XLA. This article also assumes that the reader is familiar with the Google Cloud Platform SDK and has a project with permissions to create resources such as virtual machines and Cloud TPU instances.

Case Study


We will train the fairseq transformer model on the English-to-German translation task. Fairseq supports TPUs, and several of its models, including transformer, roberta and wav2vec, have been tuned for performance. In this example we will take an incremental approach to develop an understanding of the debugging process. We will use a specific commit as a starting point to ensure that the code and concepts in this article are reproducible, and we will make modifications to the command line and source code as we go along.

Environment Setup

The development environment will be set up on a Google Cloud virtual machine. In the following steps we create a virtual machine (we will refer to this machine as the host) and a TPU instance. The following commands can be run from Cloud Shell or from any machine that has the Google Cloud SDK installed with the right credentials.

Notice that we have chosen v3-8 as the accelerator type. This means an instance with 8 TPU v3 cores (4 TPU v3 chips). For benchmarking or establishing a performance baseline, one TPU v3 chip and one V100 GPU should be considered comparable.

When the host (VM) and TPU instance have been created, ssh into the host:

Once on the host, execute the following commands to set up the development environment:


The dataset download and pre-processing steps follow the official fairseq instructions (repeated here for completeness). We will use the WMT17 English-to-German dataset:


In this example we will experiment with a few variations of training options. We begin with the following to start the training:

When we execute the command above we observe the following log:

Notice the XLA compilation warnings. What is XLA compilation?

Quick background on XLA Compilation

PyTorch/XLA uses a lazy tensor paradigm: when you use XLA tensors, any operations performed on them are only recorded in an intermediate representation (IR) graph. When a step is marked (a mark_step() call), this graph is converted to XLA's HLO format and dispatched to the TPU runtime for execution. The graph is compiled and optimized further before execution, and it is this compilation that we refer to as XLA compilation. XLA compilations can be slow, so too many of them can hurt training throughput. Fortunately, compiled graphs are cached based on the graph hash. Therefore, if the model graph (including its input and output shapes) only takes a limited number of variants, training throughput stabilizes at its optimum rate after a few slower initial steps. However, if the graph changes frequently throughout training, the training rate never recovers.
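To make the caching behaviour concrete, here is a toy model in plain Python. It is a hypothetical illustration, not PyTorch/XLA's actual cache: we assume the cache key is simply the op sequence plus the input shape. That assumption is enough to show why a fixed shape compiles once while a shape that changes every step compiles on every step.

```python
class ToyCompileCache:
    """Toy compiled-graph cache keyed by op sequence and input shape.

    Hypothetical illustration only: PyTorch/XLA's real cache key is an
    internal hash of the IR graph.
    """

    def __init__(self):
        self._cache = {}
        self.compiles = 0

    def run(self, ops, input_shape):
        # A new shape produces a new key, i.e. a new graph to compile.
        key = (tuple(ops), tuple(input_shape))
        if key not in self._cache:
            self.compiles += 1  # slow path: stands in for XLA compilation
            self._cache[key] = f"executable#{self.compiles}"
        return self._cache[key]  # fast path: reuse the cached executable


ops = ["embed", "matmul", "softmax"]

static = ToyCompileCache()
for step in range(100):
    static.run(ops, (16, 128))  # same shape on every step

dynamic = ToyCompileCache()
for step in range(100):
    dynamic.run(ops, (16, 10 + step))  # sequence length changes every step

print(static.compiles, dynamic.compiles)  # 1 100
```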

Now let’s investigate the performance further.


In order to understand the slow training we try to answer the following three questions:

  1. Does the number of XLA compilations grow linearly with the number of training steps?
  2. Does the number of device-to-host transfers grow linearly?
  3. Does the model use any op that does not have an XLA lowering?

To answer these questions, PyTorch/XLA provides a few tools. The quickest way to find these metrics is the metrics report (metrics_report), as explained on the PyTorch/XLA troubleshooting page. The three questions we seek to answer correspond to one metric and two counters in the metrics report. Let’s get to know them.

Understanding Debug Metrics

CompileTime Metric

A typical snippet of the CompileTime metric looks similar to the following:

A few important fields to notice here are TotalSamples, Counter and the 50% compilation time. TotalSamples indicates how many times XLA compilation happened, Counter indicates the overall time spent in compilation, and 50%= indicates the median compilation time.

aten::_local_scalar_dense Counter

A typical snippet with this counter looks similar to:

This counter indicates the number of device-to-host transfers. Once XLA compilation is complete, the graph is executed on the device, and the tensors stay on the device until something in the user’s code requires the value of a tensor, which causes a device-to-host transfer. Common examples include .item() calls, or a control structure that requires the value, such as an if (…).any() statement. If compilation and execution have not yet happened when such a call is encountered, the call forces early compilation and evaluation, making training even slower.
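The following toy class mimics this behaviour in plain Python so it can run anywhere. It is a hypothetical stand-in, not the XLA tensor implementation: operations only build a deferred computation, and asking for a Python value (`.item()`, or a truthiness check such as `if mask.any():`) runs it and counts one simulated device-to-host transfer.

```python
class ToyLazyTensor:
    """Toy stand-in for an XLA tensor (hypothetical, illustration only).

    Ops only record a deferred computation; values are produced and
    "copied to host" when Python needs a concrete value.
    """

    transfers = 0  # simulated device-to-host transfers

    def __init__(self, compute):
        self._compute = compute  # deferred computation: the pending graph

    @classmethod
    def from_list(cls, values):
        values = list(values)
        return cls(lambda: values)

    def eq(self, scalar):
        # Recorded only; nothing runs yet.
        return ToyLazyTensor(lambda: [v == scalar for v in self._compute()])

    def any(self):
        return ToyLazyTensor(lambda: [any(self._compute())])

    def item(self):
        # Needs a host value: run the pending graph and copy the result back.
        ToyLazyTensor.transfers += 1
        (value,) = self._compute()
        return value

    def __bool__(self):
        # `if tensor:` and `if tensor.any():` end up here, i.e. in item().
        return bool(self.item())


def mask_with_branch(tokens, pad):
    mask = tokens.eq(pad)
    return mask if mask.any() else None  # data-dependent branch: one transfer


def mask_without_branch(tokens, pad):
    return tokens.eq(pad)  # stays lazy: no transfer


tokens = ToyLazyTensor.from_list([5, 7, 1, 1])
for _ in range(10):
    mask_with_branch(tokens, pad=1)
print(ToyLazyTensor.transfers)  # 10

for _ in range(10):
    mask_without_branch(tokens, pad=1)
print(ToyLazyTensor.transfers)  # still 10
```

A data-dependent branch therefore costs one transfer per step, while the branch-free version leaves the value on the (simulated) device.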

aten::<op_name> Counter

A typical snippet in the log with this counter looks similar to:

This counter indicates the number of times the op was seen. The prefix aten:: indicates that the CPU/ATen default implementation of the op is being used because an XLA implementation is not available. Since the IR graph has to be converted to XLA format and executed on the device, the IR graph must be truncated at each instance of such an op in the forward pass. The inputs to the op are evaluated on the device and brought to the host, the op is executed on those inputs, and its output is plugged back into the remainder of the graph before execution continues. Depending on the number of instances and the location of such ops, the impact on training performance can be significant.
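To see which ops fall back to the CPU without reading the whole report, you can filter the counter names. This sketch uses `counter_names()` and `counter_value()` from `torch_xla.debug.metrics`. `find_aten_fallbacks` is a hypothetical helper, and `aten::_local_scalar_dense` is excluded because it counts transfers rather than missing lowerings. The `met` argument exists only so the function can be exercised without a TPU.

```python
def find_aten_fallbacks(met=None):
    """Return {counter_name: value} for aten:: counters other than
    aten::_local_scalar_dense, i.e. ops that ran on the CPU fallback path.
    Hypothetical helper; `met` defaults to torch_xla.debug.metrics."""
    if met is None:
        import torch_xla.debug.metrics as met  # PyTorch/XLA metrics module

    fallbacks = {}
    for name in met.counter_names():
        # _local_scalar_dense counts device-to-host transfers, not fallbacks.
        if name.startswith("aten::") and name != "aten::_local_scalar_dense":
            fallbacks[name] = met.counter_value(name)
    return fallbacks
```

Calling `find_aten_fallbacks()` after a few training steps returns an empty dict when no other aten:: counters are present.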

Utility Function for Line Debug

Now let’s introduce another brief code snippet that picks up and reports these three kinds of metrics:

The function call can be sprinkled along the code path to identify the regions that cause metric changes. It is particularly useful for detecting device-to-host transfers (aten::_local_scalar_dense calls).
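The original snippet is not reproduced here, so what follows is a hypothetical helper in the same spirit. It records the CompileTime sample count and all aten:: counters using `metric_data()`, `counter_names()` and `counter_value()` from `torch_xla.debug.metrics`, and prints only what changed since the previous call. The module-level snapshot and the injectable `met` argument are design choices of this sketch.

```python
_last_seen = {}  # snapshot from the previous call: name -> value


def line_debug(tag, met=None):
    """Print and return the changes in CompileTime samples and aten::
    counters since the previous call (hypothetical helper)."""
    if met is None:
        import torch_xla.debug.metrics as met  # PyTorch/XLA metrics module

    current = {}
    # metric_data() returns (total_samples, accumulator, samples) or None.
    data = met.metric_data("CompileTime")
    current["CompileTime"] = data[0] if data else 0
    for name in met.counter_names():
        if name.startswith("aten::"):  # _local_scalar_dense and CPU fallbacks
            current[name] = met.counter_value(name) or 0

    # Keep only the entries that moved since the last call.
    deltas = {
        name: value - _last_seen.get(name, 0)
        for name, value in current.items()
        if value != _last_seen.get(name, 0)
    }
    _last_seen.update(current)
    if deltas:
        changes = ", ".join(f"{name} +{d}" for name, d in sorted(deltas.items()))
        print(f"[{tag}] {changes}")
    return deltas
```

For example, calling `line_debug("before")` and `line_debug("after")` around a suspicious line reports an increase in `aten::_local_scalar_dense` if that line forces a device-to-host transfer.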

Generating Detailed Diagnostics

In order to observe the trendlines of these metrics and more detailed information about the IR graph changes, PyTorch/XLA provides a debug_run wrapper. Now let’s use this wrapper for the transformer training under investigation:

You can interrupt the training at any point and examine the debug_run output, which has the following directory structure:

The two most useful items in the directory are metrics_report and graph_report. metrics_report sums up all the metric values seen across steps, and graph_report shows the diff between the graphs of two successive steps in which XLA compilation was performed. Since compilation only happens when the current IR graph has not been seen before, the graph report is very helpful for identifying what is changing.

You will also notice a tensorboard directory, which lets you visualize how the metrics vary throughout the training run.

Baseline (Default) Performance

Now let us use the tools we have learned so far to understand the training performance in our case study. For debugging purposes we recommend using a single TPU core if possible. If the problem under study used model parallelism we would use 8 cores; however, the example under study only uses data parallelism, so for debugging we scale down to 1 core. We will start with PyTorch/XLA’s debug_run wrapper. The full command line is presented here for reproducibility.

After the run finishes, we can start TensorBoard to visualize the CompileTime metric and the aten___.* counters.

This command starts the TensorBoard server on localhost:6006 of the VM. If you started the ssh connection with port forwarding, you can connect to localhost:6006 on your own machine to view TensorBoard. There are other ways to set up port forwarding and ssh tunneling to achieve the same result; however, passing the ssh flag on the command line as explained here is the easiest (recommended) way. Here are the CompileTime metric and the aten local scalar dense counter as seen in TensorBoard.

Tip: You can enter aten___|CompileTime in the tag search field to pull the following two graphs to the top.

First, we observe that CompileTime grows linearly with the training step. This indicates that the caching mechanism is not effective in this setting, and training is (understandably) slow. For details on total and median CompileTime, you can view the /tmp/debug.*/metrics log.

The second observation is the linear growth of local_scalar_dense. Finally, we note that there are no other aten::<op> counters, which indicates that we don’t need to worry about any op without an XLA implementation. More details on other aten counters and how to address them are provided here.

Now let’s find out possible reasons for the linear CompileTime growth. Examining the graph_report, we notice the following snippet.

The graph report shows the diff between graphs along with the stack traces to the mark_step or tensor evaluation that caused the compilation. The trace provides helpful clues for locating device-to-host transfers in the code; in this snippet, however, we will focus on the diff of the graph. Notice the two highlighted lines and the shape of the device data. This snippet indicates dynamic shapes at that location in the model graph. Notice also that it is an internal location, not the input to the model graph; however, changes in input shapes also change the shapes of internal nodes.
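Reading such a diff is easier once you see how little needs to change between two graphs. The sketch below uses Python’s `difflib` on two made-up, simplified IR snippets (real graph dumps contain far more detail) and keeps only the changed lines. `graph_diff` is a hypothetical helper, not part of graph_report.

```python
import difflib


def graph_diff(prev_graph, curr_graph):
    """Return only the changed lines between two IR graph dumps
    (hypothetical helper, similar in spirit to graph_report)."""
    diff = difflib.unified_diff(
        prev_graph.splitlines(), curr_graph.splitlines(),
        fromfile="previous", tofile="current", lineterm="", n=0)
    # Drop the "---"/"+++" file headers and "@@" hunk headers.
    return [line for line in diff
            if line.startswith(("+", "-")) and not line.startswith(("+++", "---"))]


# Made-up, simplified IR lines: only the shape of the device data differs,
# and the change propagates to the internal node that consumes it.
prev = "%0 = f32[8,40]{1,0} xla::device_data()\n%1 = f32[8,40]{1,0} aten::relu(%0)"
curr = "%0 = f32[8,56]{1,0} xla::device_data()\n%1 = f32[8,56]{1,0} aten::relu(%0)"
for line in graph_diff(prev, curr):
    print(line)
```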

Insight and Action

With this insight, we check the input shapes for the model and realize that the input shapes do indeed change too frequently.

Tip: In case of frequent input shape changes, we recommend bucketing the input shapes and padding to the bucket length. This minimizes input shape changes without wasting too many FLOPs (floating point operations).
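Here is a minimal sketch of the idea in plain Python. The bucket boundaries are chosen at evenly spaced quantiles of the observed lengths, which is one simple heuristic and not necessarily what fairseq does; `make_buckets`, `pad_to_bucket` and the pad id of 1 are assumptions for illustration.

```python
import bisect


def make_buckets(lengths, num_buckets):
    """Bucket boundaries at evenly spaced quantiles of the observed lengths.
    The last boundary is always the maximum length, so every sequence fits."""
    ordered = sorted(lengths)
    n = len(ordered)
    bounds = {
        ordered[(i * n + num_buckets - 1) // num_buckets - 1]  # ceil(i*n/k) - 1
        for i in range(1, num_buckets + 1)
    }
    return sorted(bounds)


def pad_to_bucket(tokens, buckets, pad_id=1):
    """Right-pad a token list to the smallest bucket length that fits."""
    idx = bisect.bisect_left(buckets, len(tokens))  # first bucket >= length
    if idx == len(buckets):
        raise ValueError("sequence longer than the largest bucket")
    return tokens + [pad_id] * (buckets[idx] - len(tokens))


lengths = [7, 12, 9, 30, 18, 25, 11, 14, 22, 28]
buckets = make_buckets(lengths, num_buckets=3)
print(buckets)  # [12, 22, 30]
print(sorted({len(pad_to_bucket([4] * n, buckets)) for n in lengths}))  # [12, 22, 30]
```

With three buckets, the ten distinct lengths above collapse to three padded shapes, so at most three graph variants need compiling; the price is the extra padding tokens in the shorter sequences of each bucket.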

Fortunately, fairseq provides an option for input shape bucketing for translation tasks. Let’s experiment with num-input-buckets 3 and observe the results. The updated command line (full snippet) is given here for completeness:

After the run finishes, we examine the CompileTime metric and the aten local scalar dense counter once again.

Outcome of input shape bucketing

The following diagrams show the updated CompileTime metric and local_scalar_dense counter:

We notice that CompileTime no longer grows linearly: it grows initially and then stabilizes. This indicates that we have addressed the dynamic shapes issue. Where there are other sources of dynamism besides the input, input shape bucketing is not enough, and we would need to track down the remaining sources of shape changes.

Notice also that the local_scalar_dense counter still grows linearly. This means that some portion of the code that runs every step is causing device-to-host transfers.

Tip: As described in an earlier section, common constructs that cause device-to-host transfers include .item() calls and if (…).any() statements. Note that the __format__ method of torch tensors also calls .item() implicitly, so for logging purposes you don’t need to convert a tensor to a scalar with an explicit .item() call. The line debug technique introduced earlier is often an effective way to track down the source of device-to-host transfers. In this study, however, we will lean on the graph_report in the debug data and see what we can learn.

Insight and Action

Once again let’s examine the first diff snippet in the graph_report:

Notice that we still see some shape changes in the diff; this is because, even with input shape bucketing, there is still a limited set of different shapes. Another interesting thing to note here is the location. In the stack trace above this diff, you can find the full path to the transformer.py file in question. Now let’s examine the source code around line 770:

In this line the model updates the padding mask based on where the previous output tokens overlap the padding index. There we have an if <tensor>.any() call, which causes the device-to-host transfer. Implementing equivalent logic without the device-to-host transfer would require a few changes to the padding representation, such as binary mask matrices built at the data loading stage, which would let us express this logic as a matrix multiplication. Alternate padding representations are outside the scope of this article. Instead, for study purposes, we let the attention padding mask always be computed from the overlap of the previous output tokens with the padding index (the cross_self_attention scenario).
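The reason this change is safe can be checked in plain Python. In this sketch (hypothetical names, lists instead of tensors), an all-False padding mask masks nothing, so always building the mask gives the same attention scores as the conditional version. Only the data-dependent branch disappears, and on XLA that branch is what forced the device-to-host transfer.

```python
NEG_INF = float("-inf")


def padding_mask(prev_tokens, pad_id):
    # True where the position is padding (mirrors tokens.eq(padding_idx)).
    return [t == pad_id for t in prev_tokens]


def apply_mask(scores, mask):
    # Masked positions get -inf so a softmax would give them zero weight.
    if mask is None:
        return list(scores)
    return [NEG_INF if m else s for s, m in zip(scores, mask)]


def mask_conditional(prev_tokens, pad_id):
    # Original pattern: only build the mask if any padding exists.
    # On XLA, evaluating `any` of a device tensor forces a transfer.
    mask = padding_mask(prev_tokens, pad_id)
    return mask if any(mask) else None


def mask_unconditional(prev_tokens, pad_id):
    # Modified pattern: always build the mask; no data-dependent branch.
    return padding_mask(prev_tokens, pad_id)


scores = [0.5, 1.0, -0.2, 0.3]
for tokens in ([4, 5, 6, 7], [4, 5, 1, 1]):  # without and with padding (pad id 1)
    same = apply_mask(scores, mask_conditional(tokens, 1)) == apply_mask(
        scores, mask_unconditional(tokens, 1))
    print(tokens, same)  # True in both cases
```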

Outcome of if <tensor>.any() resolution

Executing the debug run after the change, we observe the following graphs in TensorBoard for the CompileTime metric and the local scalar dense counter:

Notice that local_scalar_dense (device-to-host transfers) no longer grows linearly. Also note that the step growth corresponds to log_interval, which indicates that device-to-host transfers only happen when we print the log. As explained earlier, given the number of tensors printed, such growth is expected, and any performance penalty can be reduced by increasing the log interval.

Finally, relative to the baseline (default) performance we started with, these simple changes improve performance by a factor of X.

Tip: If the growth at every log interval is higher than the number of tensors printed, examine the metrics computation logic for a computation that may be causing this growth.
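This check can be automated if you record the aten::_local_scalar_dense counter at each log step. The helper below is hypothetical, and the readings in the example are made-up numbers.

```python
def unexpected_transfer_growth(counter_at_logs, tensors_per_log):
    """Given aten::_local_scalar_dense readings taken at successive log steps,
    return the indices of intervals whose growth exceeds the number of
    tensors printed per log (hypothetical helper)."""
    flagged = []
    for i in range(1, len(counter_at_logs)):
        growth = counter_at_logs[i] - counter_at_logs[i - 1]
        if growth > tensors_per_log:  # more transfers than printed tensors
            flagged.append(i)
    return flagged


readings = [10, 14, 18, 40, 44]  # made-up counter values at five log steps
print(unexpected_transfer_growth(readings, tensors_per_log=4))  # [3]
```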

For example, imagine you want to print the median value of a tensor over a list of observed values kept in a queue:

Creating a torch tensor from the list of tensors results in as many local_scalar_dense calls (device-to-host transfers) as there are elements in the list.
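The toy below (plain Python, hypothetical class and function names) contrasts the two patterns by counting simulated transfers. Converting each queued device scalar to the host costs one transfer per element, while reducing on the device first costs a single transfer. On a real XLA device, the second pattern corresponds roughly to stacking the tensors (for example with `torch.stack`) and reducing them before reading the result.

```python
import statistics
from collections import deque


class ToyDeviceScalar:
    """Toy stand-in for a 0-d device tensor; float() simulates a transfer."""

    transfers = 0

    def __init__(self, value):
        self._value = value

    def __float__(self):
        ToyDeviceScalar.transfers += 1  # one device-to-host copy per scalar
        return self._value


def median_elementwise(queue):
    # Like building a host tensor from a list of device tensors:
    # every element is copied individually, so transfers grow with len(queue).
    return statistics.median(float(x) for x in queue)


def median_on_device(queue):
    # The reduction happens "on the device" (direct access to _value stands in
    # for device-side compute); only the single result is copied back.
    result = ToyDeviceScalar(statistics.median(x._value for x in queue))
    return float(result)


window = deque((ToyDeviceScalar(v) for v in [3.0, 1.0, 2.0, 5.0, 4.0]), maxlen=100)
print(median_elementwise(window), ToyDeviceScalar.transfers)  # 3.0 5
print(median_on_device(window), ToyDeviceScalar.transfers)    # 3.0 6
```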

As a general practice, we recommend calling xm.mark_step() before accessing tensor values. This evaluates the graph on the device (tensor materialization). A mark_step() followed by explicit device-to-host transfers (if the tensors can be collected in a simple or nested dictionary or tree structure, this function can be used) will eliminate the local_scalar_dense growth. However, the overall performance benefit of this strategy may be limited, depending on specifics such as the number of tensors and the associated graph sizes.

What if I have other aten counters in my metrics report?

If you observe another counter with the aten namespace (prefix aten::) in the metrics report, as described here, you can either attempt to create and contribute the lowering following the official instructions, or simply request it (similar to this) by opening a GitHub issue. Note that aten::nonzero is an exception to this category. This counter indicates the use of a tensor as an index into another tensor. Such a pattern generally introduces dynamic structure into the model graph, which can lead to slowdowns. Alternative matrix-based implementations of the equivalent logic should be considered.
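As a minimal illustration of such an alternative, the plain-Python sketch below contrasts boolean indexing, whose output length depends on the data, with masking, which keeps a fixed length. The function names are hypothetical; the torch analogues named in the comments are the usual `values[mask]` and `torch.where(mask, values, fill)` patterns.

```python
def select_boolean_index(values, keep):
    # Like values[keep] with a boolean mask: the output length depends on
    # the data (on XLA this is the aten::nonzero, dynamic-shape path).
    return [v for v, k in zip(values, keep) if k]


def select_by_masking(values, keep, fill=0.0):
    # Like torch.where(keep, values, fill): the output length is fixed,
    # so the graph shape stays the same from step to step.
    return [v if k else fill for v, k in zip(values, keep)]


values = [1.0, 2.0, 3.0, 4.0]
for keep in ([True, False, True, False], [True, True, True, False]):
    picked = select_boolean_index(values, keep)
    masked = select_by_masking(values, keep)
    # A sum over the masked values (fill 0) equals the sum over the selection.
    print(len(picked), len(masked), sum(picked) == sum(masked))
# 2 4 True
# 3 4 True
```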

Next Steps

We walked through some of the basic concepts of PyTorch/XLA performance in this case study. In most cases, these concepts suffice to get the most out of PyTorch training on TPUs. At the time of writing, the input-batch-shapes functionality in the fairseq master branch has only been implemented for the language pair dataset; for models using other datasets, the simplest way to experiment with fixed shapes is to use get_batch_shapes.

The reader is encouraged to experiment with the presented model or their own models (fairseq or otherwise) using the concepts presented here. Other areas of performance study, such as the data pipeline and profiling, will be explored in future articles.


A big thank you to the usual suspects: my outstanding colleagues Daniel Sohn, Isaack Karanja, Jack Cao and Taylan Bilal for their generous help with the content and review, and to the open source community of PyTorch/XLA for making this amazing project possible.

Machine Learning Engineer @Google, Hardware Design Staff Engineer in past life.
