Understanding the Performance: PyTorch on Cloud TPUs

Introduction

Pre-Reading

Case Study

Model

Environment Setup

export PROJECT_ID=<YOUR_GCP_PROJECT_ID>
# Create a VM with PyTorch/XLA image
gcloud compute --project=${PROJECT_ID} \
instances create pytorch-exp \
--zone=us-central1-a \
--machine-type=n1-standard-32 \
--image-family=torch-xla \
--image-project=ml-images \
--boot-disk-size=200GB \
--scopes=https://www.googleapis.com/auth/cloud-platform
# Create a TPU instance with pytorch-1.7 runtime
gcloud compute tpus create transformer-tutorial \
--zone=us-central1-a \
--network=default \
--version=pytorch-1.7 \
--accelerator-type=v3-8
gcloud compute ssh pytorch-exp --zone=us-central1-a \
--ssh-flag=-L6006:localhost:6006
conda activate torch-xla-1.7
# Identify the TPU IP address
# [ Reported under NETWORK_ENDPOINTS when you run the following command ]
gcloud compute tpus list --zone=us-central1-a
export TPU_IP_ADDRESS=<TPU_IP_ADDRESS>
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
# Download and Install fairseq
git clone https://github.com/pytorch/fairseq
cd fairseq
git checkout fc1c38aa1c70e1d1ef45a6af335e3c6571ba436d
pip install -e .

Dataset

# Download and prepare the data
cd examples/translation/
# WMT'17 data:
bash prepare-wmt14en2de.sh
# or to use WMT'14 data:
# bash prepare-wmt14en2de.sh --icml17
cd ../..
# Binarize the dataset
TEXT=examples/translation/wmt17_en_de
fairseq-preprocess \
--source-lang en --target-lang de \
--trainpref $TEXT/train --validpref $TEXT/valid \
--testpref $TEXT/test \
--destdir data-bin/wmt17_en_de --thresholdtgt 0 \
--thresholdsrc 0 \
--workers 20

Training

# Baseline
fairseq-train \
data-bin/wmt17_en_de \
--max-update=6000 \
--save-interval=5 \
--arch=transformer_vaswani_wmt_en_de_big \
--max-target-positions=64 \
--attention-dropout=0.1 \
--no-progress-bar \
--criterion=label_smoothed_cross_entropy \
--source-lang=en \
--lr-scheduler=inverse_sqrt \
--min-lr 1e-09 \
--skip-invalid-size-inputs-valid-test \
--target-lang=de \
--label-smoothing=0.1 \
--update-freq=1 \
--optimizer adam \
--adam-betas '(0.9, 0.98)' \
--warmup-init-lr 1e-07 \
--lr 0.0005 \
--warmup-updates 4000 \
--dropout 0.3 \
--weight-decay 0.0 \
--valid-subset=valid \
--max-epoch=25 \
--disable-validation \
--log-interval=20 \
--log-format=simple \
--tpu \
--distributed-world-size 1 \
--max-tokens 4096 \
--no-save
2020-12-12 00:13:07 | INFO | train_inner | epoch 001:     20 / 28331 loss=14.82, nll_loss=14.675, ppl=26165.6, wps=0, ups=0, wpb=3960, bsz=88, num_updates=20, lr=2.5995e-06, gnorm=5.319, clip=0, train_wall=781, gb_free=3.2, gb_total=16, wall=793
2020-12-12 00:13:46 | WARNING | fairseq.trainer | XLA compilation detected on device #0; too many of these can lead to slow training, but we expect a few in the beginning
2020-12-12 00:14:25 | WARNING | fairseq.trainer | XLA compilation detected on device #0; too many of these can lead to slow training, but we expect a few in the beginning
2020-12-12 00:15:05 | WARNING | fairseq.trainer | XLA compilation detected on device #0; too many of these can lead to slow training, but we expect a few in the beginning
2020-12-12 00:15:45 | WARNING | fairseq.trainer | XLA compilation detected on device #0; too many of these can lead to slow training, but we expect a few in the beginning
2020-12-12 00:16:25 | WARNING | fairseq.trainer | XLA compilation detected on device #0; too many of these can lead to slow training, but we expect a few in the beginning
...
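The repeated warnings above point at the root cause: XLA compiles one executable per unique computation graph, and the input shapes are baked into the graph, so every previously unseen batch shape pays a fresh (slow) compilation. A minimal, framework-free sketch of this caching behavior (the shapes are illustrative, not taken from the run above):

```python
# Simulate XLA's per-shape compilation cache: each previously unseen
# input shape triggers a (slow) compile; repeated shapes hit the cache.
def run_steps(batch_shapes):
    cache = set()   # compiled executables, keyed by input shape
    compiles = 0
    for shape in batch_shapes:
        if shape not in cache:
            cache.add(shape)  # expensive compilation happens here
            compiles += 1
    return compiles

# Dynamic shapes: almost every step sees a new shape and compiles.
dynamic = [(72, 43), (104, 29), (88, 35), (72, 43), (91, 40)]
# Bucketed shapes: only a handful of shapes ever appear.
bucketed = [(128, 32), (128, 64), (128, 32), (128, 32), (128, 64)]

print(run_steps(dynamic))   # 4 unique shapes -> 4 compiles
print(run_steps(bucketed))  # 2 unique shapes -> 2 compiles
```

This is why the compilation warnings keep firing long past the first few steps: with fully dynamic batch shapes, the cache almost never hits.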

Quick background on XLA Compilation

Debugging

Understanding Debug Metrics

CompileTime Metric

Metric: CompileTime
TotalSamples: 823
Accumulator: 07h49m49s571ms672.950us
ValueRate: 861ms439.718us / second
Rate: 0.0289036 / second
Percentiles: 1%=016ms614.840us; 5%=016ms346.698us; 10%=017ms844.917us; 20%=018ms793.410us; 50%=38s342ms703.034us; 80%=40s311ms149.629us; 90%=42s818ms220.966us; 95%=43s310ms543.072us; 99%=46s236ms783.553us
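The report is plain text, so tracking it across runs is easiest with a small parser. The helper below (`parse_compile_time` is a hypothetical name of my own, not part of torch_xla) pulls the sample count and median compile time out of the textual format shown above:

```python
import re

def parse_compile_time(report):
    """Extract TotalSamples and the 50th-percentile compile time
    from a torch_xla metrics report in the textual format above."""
    samples = int(re.search(r'TotalSamples:\s*(\d+)', report).group(1))
    p50 = re.search(r'50%=([^;]+)', report).group(1)
    return samples, p50

report = """Metric: CompileTime
  TotalSamples: 823
  Accumulator: 07h49m49s571ms672.950us
  Percentiles: 1%=016ms614.840us; 50%=38s342ms703.034us; 99%=46s236ms783.553us"""

print(parse_compile_time(report))  # (823, '38s342ms703.034us')
```

With 823 samples and a median compile time of roughly 38 seconds, compilation alone accounts for the nearly 8-hour accumulator above.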

aten::_local_scalar_dense Counter

Counter: aten::_local_scalar_dense
Value: 904

aten::<op_name> Counter

Counter: aten::adaptive_max_pool2d
Value: 12
Counter: aten::adaptive_max_pool2d_backward
Value: 1

Utility Function for Line Debug

def metsumm(stepno=''):
    """Print CompileTime and aten:: counters from the torch_xla
    metrics report, tagged with a step number for line debugging."""
    import torch_xla.debug.metrics as met
    x = met.metrics_report().split('\n')
    for i, line in enumerate(x):
        if 'CompileTime' in line or 'aten::' in line:
            key = line.split()[-1]        # e.g. CompileTime, aten::_local_scalar_dense
            value = x[i + 1].split()[-1]  # TotalSamples / Value on the following line
            print('step {}, key {}, value {}'.format(stepno, key, value))

Generating Detailed Diagnostics

/usr/share/torch-xla-nightly/pytorch/xla/scripts/debug_run.py \
-- \
fairseq-train \
data-bin/wmt17_en_de \
--max-update=100 \
--save-interval=5 \
... (rest of the training command line used earlier)
# Diagnostics are written under a timestamped directory:
/tmp/debug_run-<..>/
├── graphdir
├── graph_report
├── graphs
├── logs
├── metrics
├── metrics_imgdir
├── metrics_report
└── tensorboard

Baseline (Default) Performance

/usr/share/torch-xla-nightly/pytorch/xla/scripts/debug_run.py \
-- \
fairseq-train \
data-bin/wmt17_en_de \
--max-update=6000 \
--arch=transformer_vaswani_wmt_en_de_big \
--max-target-positions=64 \
--attention-dropout=0.1 \
--no-progress-bar \
--criterion=label_smoothed_cross_entropy \
--source-lang=en \
--lr-scheduler=inverse_sqrt \
--min-lr 1e-09 \
--skip-invalid-size-inputs-valid-test \
--target-lang=de \
--label-smoothing=0.1 \
--update-freq=1 \
--optimizer adam \
--adam-betas '(0.9, 0.98)' \
--warmup-init-lr 1e-07 \
--lr 0.0005 \
--warmup-updates 4000 \
--dropout 0.3 \
--weight-decay 0.0 \
--valid-subset=valid \
--max-epoch=25 \
--disable-validation \
--log-interval=20 \
--log-format=simple \
--tpu \
--distributed-world-size 1 \
--max-tokens 4096 \
--no-save
tensorboard --logdir=/tmp
--- frame-191
+++ frame-192
@@ -1,7 +1,7 @@
IR {
s64[] prim::Constant(), location=extract_features_scriptable@transformer.py:770, value=1
- s64[72,43]{0,1} xla::device_data(), location=make_positions@utils.py:243, device=TPU:0
- pred[72,43]{0,1} aten::eq(?, ?), location=extract_features_scriptable@transformer.py:770
+ s64[104,29]{0,1} xla::device_data(), location=make_positions@utils.py:243, device=TPU:0
+ pred[104,29]{0,1} aten::eq(?, ?), location=extract_features_scriptable@transformer.py:770
pred[] aten::any(?), location=extract_features_scriptable@transformer.py:770, dimensions=(0, 1), keep_reduced_dimensions=0, ROOT=0
}

Insight and Action

# Introduce Input Shape Bucketing
/usr/share/torch-xla-nightly/pytorch/xla/scripts/debug_run.py \
-- \
fairseq-train \
data-bin/wmt17_en_de \
--max-update=100 \
--arch=transformer_vaswani_wmt_en_de_big \
--max-target-positions=64 \
--attention-dropout=0.1 \
--no-progress-bar \
--criterion=label_smoothed_cross_entropy \
--source-lang=en \
--lr-scheduler=inverse_sqrt \
--min-lr 1e-09 \
--skip-invalid-size-inputs-valid-test \
--target-lang=de \
--label-smoothing=0.1 \
--update-freq=1 \
--optimizer adam \
--adam-betas '(0.9, 0.98)' \
--warmup-init-lr 1e-07 \
--lr 0.0005 \
--warmup-updates 4000 \
--dropout 0.3 \
--weight-decay 0.0 \
--valid-subset=valid \
--max-epoch=25 \
--disable-validation \
--log-interval=20 \
--log-format=simple \
--tpu \
--distributed-world-size 1 \
--max-tokens 4096 \
--no-save \
--num-batch-buckets 3
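`--num-batch-buckets 3` makes fairseq pad every batch up to one of three fixed sizes instead of the exact longest sequence, so only a handful of distinct shapes ever reach the compiler. The idea can be sketched independently of fairseq (the bucket boundaries below are illustrative, not fairseq's actual values):

```python
import bisect

def bucket_length(seq_len, buckets=(16, 32, 64)):
    """Pad a sequence length up to the smallest bucket that fits it,
    so at most len(buckets) distinct shapes ever trigger compilation."""
    i = bisect.bisect_left(buckets, seq_len)
    if i == len(buckets):
        raise ValueError('sequence longer than largest bucket')
    return buckets[i]

lengths = [9, 21, 34, 29, 43, 61, 15]
padded = [bucket_length(n) for n in lengths]
print(sorted(set(padded)))  # unique shapes collapse to [16, 32, 64]
```

The trade-off is wasted compute on padding tokens; with a small number of buckets that waste is usually far cheaper than recompiling on every step.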

Outcome of Input Shape Bucketing

Insight and Action

--- frame-2
+++ frame-3
@@ -1,7 +1,7 @@
IR {
s64[] prim::Constant(), location=extract_features_scriptable@transformer.py:770, value=1
- s64[16,34]{1,0} xla::device_data(), location=make_positions@utils.py:243, device=TPU:0
- pred[16,34]{1,0} aten::eq(?, ?), location=extract_features_scriptable@transformer.py:770
+ s64[120,21]{0,1} xla::device_data(), location=make_positions@utils.py:243, device=TPU:0
+ pred[120,21]{0,1} aten::eq(?, ?), location=extract_features_scriptable@transformer.py:770
pred[] aten::any(?), location=extract_features_scriptable@transformer.py:770, dimensions=(0, 1), keep_reduced_dimensions=0, ROOT=0
}
# transformer.py:770 (before): the .any() truthiness check forces the
# lazy tensor to be evaluated on the host every step
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
    self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

# Fix: compute the padding mask unconditionally, avoiding the .any() evaluation
-        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
-            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
+        self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

Outcome of if <tensor>.any() resolution

def median(self):
    # Building a tensor out of a queue of device tensors pulls every
    # element back to the host (one aten::_local_scalar_dense each)
    d = torch.tensor(list(my_tensor_queue))
    return d.median()
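One workaround is a sketch along the following lines (assuming the queued values can be converted to Python floats at the point they are appended): keep host-side scalars and take the median on the host, so the median call itself never touches the device. `LossTracker` is an illustrative name, not part of fairseq or torch_xla.

```python
import statistics

class LossTracker:
    """Keep host-side floats instead of device tensors, so computing
    the median triggers no per-element device-to-host transfers."""
    def __init__(self):
        self.values = []

    def append(self, loss):
        # One explicit transfer at append time, instead of re-fetching
        # the whole queue every time the median is requested.
        self.values.append(float(loss))

    def median(self):
        return statistics.median(self.values)

t = LossTracker()
for v in [3.0, 1.0, 2.0]:
    t.append(v)
print(t.median())  # 2.0
```

The transfer cost does not disappear, but it is paid once per value rather than once per value per median computation, and it stops showing up as an aten::_local_scalar_dense spike inside the training step.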

What if I have other aten:: counters in my metrics report?

Next Steps

Acknowledgments

Machine Learning Engineer @Google, Hardware Design Staff Engineer in past life.