Introduction

You are familiar with PyTorch/XLA. You have tested some example code, it works, and encouraged by the quick win you set out to train your own model. Your model trains, but it is slower than you expected, and you want to find out why. This article shares some of the tools that can help you diagnose the slowdown and derive actionable steps to improve training performance.

Pre-Reading

In the remainder of this article we will walk through a performance-debugging case study. A sound understanding of the inner workings of the XLA tensor will make the following content more accessible. We…


PyTorch/XLA on TPU

You have read that PyTorch/XLA on TPU is now generally available. Perhaps you have run through some of the example Colabs as well. Now you want to understand how to run your own PyTorch model on TPUs. You are in the right place. This article deconstructs a simple example in terms of the key concepts, with the goal of equipping you with the fundamentals needed to update your own training code to train on TPUs.

Basic Concepts

PyTorch/XLA introduces a new tensor abstraction called the XLA tensor. Any operation performed on an XLA tensor is not executed eagerly; instead it is traced into an IR (Intermediate Representation) graph. In the simplest scenario the tracing continues…
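To make the idea of tracing concrete, here is a toy sketch in plain Python. This is not the real torch_xla implementation; it is a hypothetical, minimal illustration of the general lazy-evaluation pattern: operations on a lazy tensor record IR nodes instead of computing values, and actual computation happens only when a value is materialized (in PyTorch/XLA, roughly the point where the graph is compiled and dispatched to the device).

```python
# Toy illustration of lazy tracing (NOT torch_xla internals).
# Each arithmetic op records an IR node rather than computing immediately.
class LazyTensor:
    def __init__(self, op, inputs=(), value=None):
        self.op = op          # name of the operation that produced this node
        self.inputs = inputs  # upstream LazyTensor nodes
        self.value = value    # concrete data, only for leaf "const" nodes

    @staticmethod
    def constant(v):
        return LazyTensor("const", value=v)

    def __add__(self, other):
        return LazyTensor("add", (self, other))

    def __mul__(self, other):
        return LazyTensor("mul", (self, other))

    def materialize(self):
        # Walk the recorded graph and execute it. In PyTorch/XLA this is
        # conceptually where the traced IR graph gets compiled and run.
        if self.op == "const":
            return self.value
        lhs, rhs = (t.materialize() for t in self.inputs)
        return lhs + rhs if self.op == "add" else lhs * rhs


a = LazyTensor.constant(2)
b = LazyTensor.constant(3)
c = a + b * a           # no arithmetic happens here; a graph is built
print(c.op)             # add
print(c.materialize())  # 8
```

The key takeaway is that the Python expression `a + b * a` builds a graph rather than producing numbers, which is why inspecting or printing an XLA tensor's value forces computation.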

Vaibhav Singh

Machine Learning Engineer @Google, Hardware Design Staff Engineer in past life.
