You are familiar with PyTorch/XLA. You have tested some example code and it works. Encouraged by the quick win, you set out to train your own model. The model trains, but more slowly than you expected, and you want to find out why. This article shares some tools to help you diagnose the slowdown and derive actionable steps to improve training performance.
You have read that PyTorch/XLA on TPU is now generally available (GA). Perhaps you have run through some example Colabs as well. But you want to understand how to run your own PyTorch model on TPUs. You are in the right place. This article deconstructs a simple example in terms of its key concepts. The goal is to equip you with the fundamentals you need to update your own training code to train on TPUs.
PyTorch/XLA introduces a new tensor abstraction called the XLA tensor. Any operation performed on an XLA tensor is traced into an IR (Intermediate Representation) graph. In the simplest scenario the tracing continues…
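The idea of tracing operations into a graph instead of running them right away can be sketched in plain Python. This is a toy illustration only, not PyTorch/XLA's implementation or API: the `LazyTensor` class and its `const`, `ir`, and `materialize` methods are hypothetical names invented for this example, and it supports just scalar addition and multiplication.

```python
# Toy illustration of lazy tracing; NOT the PyTorch/XLA implementation.
# LazyTensor, const, ir, and materialize are hypothetical names.
class LazyTensor:
    """Records operations in a small IR graph instead of computing eagerly."""

    def __init__(self, op, inputs=(), value=None):
        self.op = op          # node kind: "const", "add", or "mul"
        self.inputs = inputs  # upstream LazyTensor nodes
        self.value = value    # set for constants, or cached after evaluation

    @staticmethod
    def const(value):
        return LazyTensor("const", value=value)

    def __add__(self, other):
        # Record the addition as a graph node; no arithmetic happens here.
        return LazyTensor("add", (self, other))

    def __mul__(self, other):
        # Record the multiplication as a graph node.
        return LazyTensor("mul", (self, other))

    def ir(self):
        """Render the traced graph as a nested expression string."""
        if self.op == "const":
            return f"const({self.value})"
        return f"{self.op}({', '.join(node.ir() for node in self.inputs)})"

    def materialize(self):
        """Evaluate the recorded graph; only now does computation happen."""
        if self.value is None:
            left, right = (node.materialize() for node in self.inputs)
            self.value = left + right if self.op == "add" else left * right
        return self.value


a = LazyTensor.const(2.0)
b = LazyTensor.const(3.0)
c = a * b + a             # traced only; nothing is computed yet
traced_ir = c.ir()        # textual view of the recorded graph
result = c.materialize()  # forces evaluation of the whole graph
```

In this sketch `c` holds no value until `materialize()` is called; before that point it is only a description of the computation, which is the property that lets a compiler see and optimize the whole graph at once.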
Machine Learning Engineer @Google; Hardware Design Staff Engineer in a past life.