Over the past two weeks, which were also the last two weeks of the curriculum, I have focused on:
- Manually implementing mixed-precision training on a toy MLP problem, and running experiments to understand when to expect memory savings and speedups.
- Reviewing the literature and drafting my project proposal on variable binding. I will post an update here once the material is packaged up.
The premise of mixed-precision training is that if all data and parameters are stored in 16-bit instead of 32-bit, every tensor takes up half as much GPU memory. We can therefore 1) double the batch size, and 2) increase arithmetic bandwidth within the GPU and reduce network bandwidth between processes in distributed settings. NVIDIA’s Volta-generation GPUs (e.g. the V100) have hardware (Tensor Cores) that supports these premises, and practitioners often observe speedups of 3x or more in their experiments.
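To make the "double the batch size" arithmetic concrete, here is a small sketch that assumes activations are the only thing competing for a fixed slice of GPU memory (weights, gradients, and framework workspace are ignored). The 8 GiB budget, the per-example activation count, and the `max_batch_size` helper are all made up for illustration.

```python
import struct

# Bytes per element for IEEE 754 single and half precision ("f" and "e" struct formats).
BYTES_FP32 = struct.calcsize("f")  # 4
BYTES_FP16 = struct.calcsize("e")  # 2


def max_batch_size(budget_bytes: int, floats_per_example: int, bytes_per_float: int) -> int:
    """Largest batch whose activations fit in `budget_bytes` (hypothetical helper)."""
    return budget_bytes // (floats_per_example * bytes_per_float)


budget = 8 * 2**30         # assumed: 8 GiB left over for activations
per_example = 1_000_000    # assumed: activation values stored per example

fp32_batch = max_batch_size(budget, per_example, BYTES_FP32)
fp16_batch = max_batch_size(budget, per_example, BYTES_FP16)
print(fp32_batch, fp16_batch)  # 2147 4294
```

The doubling only holds to the extent that activations dominate memory; the takeaways further down show where that assumption breaks.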
Mixed-precision training is becoming standard practice in research engineering, so I did a short exercise to understand exactly how and when tensors are cast to 16-bit, and what new class of bugs this can introduce. Besides the PyTorch amp library and a few YouTube videos that explain 16-bit versus 32-bit numeric representation, I found this notebook from fast.ai quite useful. It explained several bugs I ran into when I naively cast every tensor to 16-bit in a large model such as a Transformer, and it matched my expectations about where and when pure 16-bit training fails due to numerical underflow and overflow. The notebook suggests three techniques:
1) Avoid imprecise weight updates: maintain a 32-bit master copy of the weights and apply the updates to it using 32-bit gradients, since an update much smaller than a weight gets rounded away in 16-bit.
2) Avoid numerical underflow in the 16-bit gradients computed during backprop, especially when the model is very deep: scale the loss up as much as possible without overflowing (adjusting the scale dynamically), backprop the scaled loss to get 16-bit gradients that are no longer too close to zero, cast them to 32-bit, divide by the same scale factor, and use the result to update the 32-bit master weights.
3) Avoid numerical overflow in activations or the loss (the opposite of the gradient problem): compute the loss in 32-bit by casting its 16-bit inputs (e.g. the logits) to 32-bit first. Do the same for batchnorm layers.
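The failure modes behind these three techniques, and the effect of each fix, can be reproduced without a GPU. Python's standard `struct` module can round a float to IEEE 754 half precision (format `"e"`) or single precision (format `"f"`); the helpers `to_fp16` and `to_fp32` below are illustrative names, not library functions, and they map out-of-range values to infinity the way a GPU cast would. The step size, gradient value, loss scale of 2**16, and logits are arbitrary illustrative numbers, and the softmax is deliberately naive (no max-subtraction) to isolate the range problem.

```python
import math
import struct


def _round_to(fmt: str, x: float) -> float:
    """Round x to a struct float format ("e" = fp16, "f" = fp32); overflow becomes +/-inf."""
    try:
        return struct.unpack("<" + fmt, struct.pack("<" + fmt, x))[0]
    except OverflowError:  # struct raises instead of producing inf
        return math.copysign(math.inf, x)


def to_fp16(x: float) -> float:
    return _round_to("e", x)


def to_fp32(x: float) -> float:
    return _round_to("f", x)


# The fp16 limits behind all three problems.
print(to_fp16(65504.0), to_fp16(70000.0))  # 65504.0 inf  (largest finite fp16, then overflow)
print(to_fp16(1e-8), to_fp16(1.0 + 1e-4))  # 0.0 1.0      (underflow; spacing above 1.0 is 2**-10)

# 1) Master weights: 1000 updates of 1e-4 are rounded away in fp16 but accumulate in fp32.
w16, w32 = 1.0, 1.0
for _ in range(1000):
    w16 = to_fp16(w16 - 1e-4)
    w32 = to_fp32(w32 - 1e-4)
print(w16, round(w32, 3))  # 1.0 0.9

# 2) Loss scaling: a 1e-8 gradient underflows in fp16 unless the loss (hence the gradient) is scaled first.
scale = 2.0 ** 16
grad16_unscaled = to_fp16(1e-8)
grad16_scaled = to_fp16(1e-8 * scale)    # what backprop on the scaled loss would produce
grad32 = to_fp32(grad16_scaled) / scale  # cast to fp32, then undo the scaling
print(grad16_unscaled, grad32)           # 0.0 and a value within 0.1% of 1e-08


# 3) Loss in fp32: exp() of a large logit overflows fp16, and inf / inf gives NaN.
def naive_softmax(logits, cast):
    exps = [cast(math.exp(cast(z))) for z in logits]
    total = cast(sum(exps))
    return [cast(e / total) for e in exps]


print(naive_softmax([12.0, 1.0], to_fp16))  # [nan, 0.0]
print(naive_softmax([12.0, 1.0], to_fp32))  # finite probabilities summing to ~1
```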
To understand this workflow in detail, I implemented the first two suggestions on a toy MLP problem using the recommended apex library utilities, without the fast.ai wrappers. The implementation is here; look for `def main_train`, which breaks the workflow down into sixteen small steps.
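The first two suggestions can also be sketched end to end on a one-parameter model `y_hat = w * x`. This is a hypothetical, stdlib-only simulation, not the apex-based `main_train`: `to_fp16` rounds through Python's `struct` half-precision format, Python floats (64-bit) stand in for the 32-bit master weight, and `DynamicLossScaler` is a made-up class. Its rule (halve the scale on overflow and skip the step, double it after a run of clean steps) follows the same idea as the dynamic loss scaling in apex and `torch.cuda.amp.GradScaler`, though both of those default to a much longer growth interval of 2000 steps.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest float16 value; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


class DynamicLossScaler:
    """Hypothetical scaler: halve on overflow, double after `growth_interval` clean steps."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_interval: int = 100):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.clean_steps = 0

    def update(self, found_overflow: bool) -> None:
        if found_overflow:
            self.scale /= 2.0
            self.clean_steps = 0
        else:
            self.clean_steps += 1
            if self.clean_steps % self.growth_interval == 0:
                self.scale *= 2.0


def train_step(master_w: float, x: float, y: float, lr: float, scaler: DynamicLossScaler) -> float:
    """One manual mixed-precision step for loss = 0.5 * (w * x - y) ** 2."""
    w16 = to_fp16(master_w)                   # cast the master weight to an fp16 working copy
    x16, y16 = to_fp16(x), to_fp16(y)
    y_hat16 = to_fp16(w16 * x16)              # forward pass in fp16
    err16 = to_fp16(y_hat16 - y16)
    # Backprop the *scaled* loss in fp16: d(scale * loss)/dw = scale * err * x.
    grad16 = to_fp16(to_fp16(scaler.scale * err16) * x16)
    found_overflow = math.isinf(grad16) or math.isnan(grad16)
    if not found_overflow:
        grad = grad16 / scaler.scale          # cast up and unscale
        master_w -= lr * grad                 # update the high-precision master weight
    scaler.update(found_overflow)             # on overflow the update above was skipped
    return master_w


scaler = DynamicLossScaler()
w = 0.0
for _ in range(200):
    w = train_step(w, x=1.0, y=2.0, lr=0.1, scaler=scaler)
print(round(w, 2), scaler.scale)  # 2.0 32768.0
```

With an initial scale of 2**16, the first two steps overflow (the scaled gradients of magnitude 131072 and 65536 exceed the fp16 maximum of 65504), so the scaler backs off to 2**14 before any update is applied, then grows back to 2**15 after 100 clean steps.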
I ran a few experiments to measure the gains in memory and wall-time efficiency, comparing models with 160M, 2.5M, and 116K parameters. Here are the two main takeaways:
1) A wall-time speedup only appears when the model is large (e.g. an MLP with many wide layers). However, when the model is very large, the 32-bit master copies of the weights and gradients can take up so much memory that we cannot scale up the batch size.
2) Memory savings (measured by regularly recording the maximum memory allocated) only appear when the model is small. However, a small model performs so little matrix multiplication work that it cannot take advantage of the increased arithmetic bandwidth, so there is no wall-time speedup.
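One way to see why the two takeaways pull in opposite directions is to count bytes per parameter for weights and gradients alone. The accounting below is an assumption about the manual scheme described above (fp16 working weights and gradients plus 32-bit master weights and 32-bit gradient copies); it ignores optimizer state and activations, and `param_memory_bytes` is a hypothetical helper. Only the three parameter counts come from the experiments.

```python
# Bytes per parameter for weights + gradients only (optimizer state and activations excluded).
BYTES_PER_PARAM = {
    "fp32": 4 + 4,                  # fp32 weights + fp32 gradients
    "mixed_master": 2 + 2 + 4 + 4,  # fp16 weights/grads + fp32 master weights/grads
    "pure_fp16": 2 + 2,             # fp16 weights + fp16 gradients, no master copy
}


def param_memory_bytes(num_params: int, scheme: str) -> int:
    """Memory held by weights and gradients under a given precision scheme (hypothetical helper)."""
    return num_params * BYTES_PER_PARAM[scheme]


for n in (160_000_000, 2_500_000, 116_000):
    mb = {s: param_memory_bytes(n, s) / 1e6 for s in BYTES_PER_PARAM}
    print(f"{n:>11,} params: " + ", ".join(f"{s}={v:,.1f} MB" for s, v in mb.items()))
```

Under these assumptions the 160M-parameter model needs 1,280 MB in fp32 but 1,920 MB with master copies, so the extra 640 MB has to be won back from halved activations before the batch size can grow. For the 116K-parameter model both figures are under 1.5 MB, so activation memory, which does shrink by half, likely dominates.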
In practice, several libraries implement mixed-precision training. They have been optimized for speed and memory efficiency, and they are straightforward to incorporate into existing PyTorch module code.
- In `torch.cuda.amp`, certain operations are autocast to float32. These ops are more sensitive to numerical underflow or overflow, so the list includes functions that involve logarithms, exponentiation, and normalization.
- Certain operations are autocast to float16. These ops typically involve standard linear algebra, such as matrix multiplications.
- `GradScaler` handles dynamic loss scaling.
- My implementation of the MLP toy problem using `torch.cuda.amp`. Compared to the manual implementation, it has less educational value but is much simpler.
- The trainer class accepts arguments that support full precision, conservative mixed precision (only ops robust to numerical issues are cast to 16-bit), mixed precision (maintains a master copy of 32-bit weights), and true 16-bit precision.
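For comparison with the manual version, a minimal `torch.cuda.amp` training step might look like the sketch below. The model, data, and hyperparameters are placeholders rather than the author's MLP setup, and autocast and loss scaling are only enabled when a CUDA GPU is available, so the same code falls back to plain fp32 on CPU.

```python
import torch
from torch import nn

# Autocast and loss scaling only help on a CUDA GPU; on CPU this runs in plain fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # dynamic loss scaling; a no-op when disabled

# Placeholder data: one fixed random batch.
x = torch.randn(128, 64, device=device)
y = torch.randint(0, 10, (128,), device=device)


def train_step() -> float:
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(enabled=use_amp):
        logits = model(x)          # nn.Linear runs in float16 under autocast
        loss = loss_fn(logits, y)  # cross-entropy is computed in float32
    scaler.scale(loss).backward()  # backprop the scaled loss
    scaler.step(optimizer)         # unscales gradients; skips the step if they contain inf/NaN
    scaler.update()                # grows or shrinks the scale for the next iteration
    return loss.item()


losses = [train_step() for _ in range(50)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Unlike the manual version, autocast leaves the model's parameters in float32 and casts them to float16 on the fly for eligible ops, so no separate master copy has to be managed by hand. Newer PyTorch releases deprecate the `torch.cuda.amp` spellings in favor of `torch.amp.autocast("cuda")` and `torch.amp.GradScaler("cuda")`.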