

source link: https://cloud.google.com/blog/topics/developers-practitioners/pytorchxla-performance-debugging-tpu-vm-part-1

Developers & Practitioners

PyTorch/XLA: Performance debugging on Cloud TPU VM: Part I

Vaibhav Singh
Product Manager, Google Cloud
January 5, 2022

In this three-part series we explore the performance debugging ecosystem of PyTorch/XLA on Google Cloud TPU VM. Cloud TPU VMs were introduced last year (2021). The TPU VM architecture allows ML practitioners to work directly on the host where the TPU hardware is attached. With the TPU profiler, debugging your PyTorch training on a TPU VM is simpler than ever before. While the process for analyzing performance has changed, the fundamentals of PyTorch/XLA that you acquired with the network-attached TPU architecture (aka TPU Node architecture) still apply.

In this first part we briefly lay out the conceptual framework for PyTorch/XLA in the context of training performance. Please note that training performance in the current scope refers to training throughput, i.e. samples/sec, images/sec, or equivalent. We use a case study to make sense of preliminary profiler logs and to identify corrective actions. Solving the performance bottleneck is left as an exercise for the reader.

In part II of this series we will discuss the solution to the exercise left in part I and introduce further performance analysis to identify additional improvement opportunities.

Finally, in part III, we introduce user-defined code annotations. We will see how to visualize these annotations in the form of a trace and introduce some basic concepts for understanding the trace.

By the end of this series, we aim to give you a better understanding of how to analyze the performance of your PyTorch code on Cloud TPUs and what to consider when working with them.

Pre-Reading

An understanding of the inner workings of XLA tensors will make the following content more accessible and useful. We encourage you to review this talk from PyTorch Developer Day 2020 and this talk from Google Cloud Next for a quick primer on XLA tensors. You may also find this article helpful if you are new to PyTorch/XLA. This article also assumes that the reader is familiar with the Google Cloud SDK and has access to a Google Cloud project with permissions to create resources such as virtual machines and Cloud TPU instances. Most of the profiler concepts are explained here; however, an introductory reading of the TPU VM profiler guide is also recommended.

Client-Server Terminology for PyTorch/XLA 

As in the TPU Node architecture (before TPU VM), PyTorch/XLA still uses the lazy tensor paradigm: when you use XLA tensors, any operations performed on them are simply recorded in an intermediate representation (IR) graph. When a step is marked (an xm.mark_step() call), this graph is converted to XLA's HLO (High Level Operations) format and dispatched for execution to the TPU runtime (the server).
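To make the lazy execution model concrete, here is a minimal sketch, assuming a TPU VM with torch_xla installed (the tensor shapes and operations are purely illustrative):

    # Minimal sketch of lazy tensor execution with PyTorch/XLA.
    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()  # acquire the XLA (TPU) device

    x = torch.randn(4, 4, device=device)
    y = torch.randn(4, 4, device=device)

    # These operations are only recorded in the IR graph; nothing has
    # executed on the TPU yet.
    z = torch.relu(torch.matmul(x, y))

    # mark_step() cuts the recorded graph, lowers it to HLO, and dispatches
    # it to the TPU runtime (the server side) for compilation and execution.
    xm.mark_step()

In a typical training loop, the device loader (or an explicit xm.mark_step() call at the end of each step) triggers this cut, so roughly one graph is built and executed per training step.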

Note that the TPU runtime is part of the server-side functionality, and all the work done up to the generation of the HLO graph is part of (and henceforth referred to as) the client-side functionality. Unlike the previous generation, where the TPU runtime (server) was automatically started when you created a TPU instance, in the case of TPU VMs the PyTorch/XLA library takes care of starting the server when you submit a training job. You can also start the XRT (XLA Runtime) server manually on a port of your choice; the XRT_TPU_CONFIG variable set in the code snippets later in this post refers to the default port on which PyTorch/XLA starts the XRT server. Also unlike the previous generation, the client and server run on the same host; however, the abstractions still hold and are helpful for understanding performance (more details here).
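For reference, a minimal sketch of that configuration, assuming the default local port documented in the Cloud TPU PyTorch quickstart (adjust the value if you start the XRT server manually on a different port):

    # Point the PyTorch/XLA client at the XRT server running on the same TPU VM.
    # The port below is the documented default; change it if you started the
    # server elsewhere.
    export XRT_TPU_CONFIG="localservice;0;localhost:51011"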

Case Study

Context 

We will examine UniT (Unified Transformer) training on the GLUE/QNLI task using the MMF framework for multi-modal learning from Facebook Research. We will discover an interesting aspect of the multi-head attention implementation (observed in PyTorch 1.8) that incidentally results in sub-optimal training performance with PyTorch/XLA, and discuss a potential corrective action.

Environment Setup

The case study uses a TPU VM. In the following steps we create one. The commands can be run from Google Cloud Shell or from any machine with the Google Cloud SDK installed and the correct credentials provisioned. (For more details, please refer to the TPU VM user guide.)
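As a rough sketch, the creation and login flow looks like the following; the TPU name, zone, accelerator type, and runtime version are placeholders, so check the TPU VM user guide for the PyTorch runtime image that matches your PyTorch/XLA release:

    # Create a TPU VM (name, zone, accelerator type, and runtime version are illustrative).
    gcloud alpha compute tpus tpu-vm create my-tpu-vm \
      --zone=us-central1-a \
      --accelerator-type=v3-8 \
      --version=tpu-vm-pt-1.10

    # SSH into the TPU VM once it has been created.
    gcloud alpha compute tpus tpu-vm ssh my-tpu-vm --zone=us-central1-a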

