Compile with mixed precision (Beta)
Mixed precision combines the use of different numerical formats (such as FP32 and BF16) to reduce memory footprint and speed up large neural network workloads.
SambaFlow provides several convenience methods for enabling mixed precision in the model.
Graph Automatic Mixed Precision (GraphAMP)
The GraphAMP feature is available starting with SambaFlow 1.18. It is a convenient way to choose how aggressively the compiler will downcast operations in the model.
You use the --graphamp-preset compiler argument to switch between presets. The available options, shown in the following table, offer a tradeoff between accuracy and performance:
GraphAMP options | Description |
---|---|
mp0 | GraphAMP is disabled and no optimization is applied to the model. |
mp1 | Downcasts the inputs of GEMM-like operators to BF16. |
mp4 | Downcasts all inputs and operators of the model to BF16. |
Default | Forces all inputs and operators of the model to BF16 until it encounters a floating point conversion. |
Example
Below is a simple example that we will use to explain what happens when you compile the model with different --graphamp-preset values.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        matmul_1 = torch.matmul(a, b)
        act_1 = F.relu(matmul_1)
        matmul_2 = torch.matmul(act_1, c)
        out = F.relu(matmul_2)
        return out

mymodel = Model()
GraphAMP mp0 mode
With mp0, GraphAMP is disabled. The model that the compiler lowers is the same as the model the user provides through the model code.
$ python mymodel.py compile --graphamp-preset mp0
GraphAMP mp1 mode
With mp1
, GraphAMP downcasts the inputs of GEMM-like operators, in this example, 2 Matmul
operators. In the resulting mixed-precision model the Matmul
operators are also in mixed precision.
$ python mymodel.py compile --graphamp-preset mp1
With only GEMM operators in mixed precision, mp1 offers a conservative tradeoff between accuracy and performance. This mode is the recommended preset.
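Conceptually, the graph the compiler lowers under mp1 is equivalent to the hand-written sketch below. This is illustrative only, not compiler output: GraphAMP inserts the casts automatically, and the .float() calls here merely emulate the FP32 output that the mixed-precision Matmul produces natively on the RDU.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelMp1Sketch(nn.Module):
    # Hand-written approximation of what GraphAMP mp1 does to the example model.
    def forward(self, a, b, c):
        # Inputs of GEMM-like operators are downcast to BF16; the Matmul output stays FP32.
        matmul_1 = torch.matmul(a.bfloat16().float(), b.bfloat16().float())
        act_1 = F.relu(matmul_1)    # non-GEMM operators keep FP32
        matmul_2 = torch.matmul(act_1.bfloat16().float(), c.bfloat16().float())
        out = F.relu(matmul_2)
        return out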
GraphAMP mp4 mode
With mp4, GraphAMP aggressively downcasts all inputs and operators to BF16. The resulting model uses BF16 throughout.
$ python mymodel.py compile --graphamp-preset mp4
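Similarly, mp4 lowers the example to the rough equivalent below. Again a sketch only: the compiler performs the downcasts, so no explicit .bfloat16() calls are needed in your model code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelMp4Sketch(nn.Module):
    # Hand-written approximation: under mp4 every input and operator runs in BF16.
    def forward(self, a, b, c):
        matmul_1 = torch.matmul(a.bfloat16(), b.bfloat16())   # BF16 in, BF16 out
        act_1 = F.relu(matmul_1)                               # BF16
        matmul_2 = torch.matmul(act_1, c.bfloat16())           # BF16
        return F.relu(matmul_2)                                # BF16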
Overrides with disable_graphamp()
SambaFlow provides context managers or decorators that allow you to disable mixed precision in regions of your model. This provides the flexibility to improve performance of your model using GraphAMP while preserving the floating-point precision for a specific operator instance.
In these regions, operators run in an operator-specific datatype chosen by the user to maintain accuracy.
Use disable_graphamp() to wrap only the forward pass(es) of your network, including the loss computation(s). Backward operations run with the same datatype as forward operations.
Example:
import sambaflow.samba as samba
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        matmul_1 = torch.matmul(a, b)
        act_1 = F.relu(matmul_1)
        with samba.session.disable_graphamp():
            matmul_2 = torch.matmul(act_1.bfloat16(), c.bfloat16())
            out = F.relu(matmul_2)
        return out

mymodel = Model()
$ python mymodel.py compile --graphamp-preset mp1
In the above example, matmul_2 and its relu are within the disable_graphamp context manager, while the remaining graph uses mp1 mode. The resulting graph is shown below.
- With full mp1 mode, both matmul_1 and matmul_2 are in mixed precision.
- In this example, the resulting graph has matmul_1 in mixed precision and matmul_2 in BF16 precision.
Resulting model:
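Conceptually, the precision of the resulting model corresponds to the sketch below. This is illustrative only, not actual compiler output; the .float() casts for matmul_1 emulate the FP32 output of the mixed-precision GEMM on the RDU.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResultingModelSketch(nn.Module):
    def forward(self, a, b, c):
        # Handled by GraphAMP mp1: BF16 inputs, mixed-precision Matmul, FP32 output
        matmul_1 = torch.matmul(a.bfloat16().float(), b.bfloat16().float())
        act_1 = F.relu(matmul_1)
        # Inside disable_graphamp(): runs in the user-chosen datatype, here plain BF16
        matmul_2 = torch.matmul(act_1.bfloat16(), c.bfloat16())
        out = F.relu(matmul_2)   # BF16
        return out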
Operator-specific behavior
- Matmul and Linear: these GEMM-like operators support internal mixed-precision compute and can take BF16 inputs and output FP32 tensors directly (see the sketch after this list).
- See here for a reference to supported PyTorch operators.
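The sketch below illustrates, in plain PyTorch on the host, why keeping the GEMM output in FP32 helps. The .float() casts emulate the BF16-input/FP32-output behavior that the RDU operator provides natively; exact error values will vary by platform.

import torch

torch.manual_seed(0)
a = torch.randn(64, 1024)
b = torch.randn(1024, 64)

ref = a @ b                                             # full FP32 reference
bf16_out = (a.bfloat16() @ b.bfloat16()).float()        # BF16 in, BF16 out, upcast afterwards
mixed = a.bfloat16().float() @ b.bfloat16().float()     # BF16-rounded inputs, FP32 accumulate/output

print("max error, BF16 output:", (bf16_out - ref).abs().max().item())
print("max error, FP32 output:", (mixed - ref).abs().max().item())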
Additional precision arguments
You can customize precision behavior with the following compiler arguments.
--tiling-accum
The accumulation operation associated with tiling can be sensitive to the precision setting. By default, accumulation happens in FP32. Use --tiling-accum to change accumulation to BF16 with stochastic rounding, which enables better performance with some accuracy degradation.
$ compile --tiling-accum ["fp32" | "bf16sr"]
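To see why the accumulation precision matters, the sketch below compares FP32 and plain BF16 accumulation of many partial values in ordinary PyTorch. It does not emulate the stochastic rounding that the bf16sr setting applies on the RDU; it only illustrates the kind of error that accumulating in reduced precision introduces.

import torch

torch.manual_seed(0)
partials = torch.randn(1000, 4096)                 # stand-ins for per-tile partial results

fp32_acc = torch.zeros(4096)
bf16_acc = torch.zeros(4096, dtype=torch.bfloat16)
for p in partials:
    fp32_acc += p                                  # default: accumulate in FP32
    bf16_acc += p.bfloat16()                       # accumulate in BF16

ref = partials.sum(dim=0)
print("FP32 accumulation max error:", (fp32_acc - ref).abs().max().item())
print("BF16 accumulation max error:", (bf16_acc.float() - ref).abs().max().item())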
--weight-grad-reduce
Weight gradient reduce means reduction across weight gradient values computed in parallel. By default weight gradient reduction happens in FP32. Use the --weight-grad-reduce
argument to change reduction to BF16 with stochastic rounding to enable better performance with some accuracy degradation.
$ compile --weight-grad-reduce ["fp32" | "bf16sr"]
The --weight-grad-reduce argument is applied only to linear weight gradient updates.
--fp32-params
By default, if the forward pass for a particular operator has BF16 inputs, the backward pass for that operator produces BF16 gradients. Gradient values with small magnitudes may not be representable in BF16, causing underflow, and the update for the corresponding parameters is lost.
The --fp32-params argument addresses this issue by having the optimizer output two copies of the weight, in BF16 and FP32 precision. The BF16 copy is sent to the next operator, while the FP32 copy is used to update the trainable parameter.
When using full precision weight update mode (--fp32-params), the model parameters are expected to be initialized in BF16.
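The sketch below is a generic PyTorch illustration, not the SambaFlow optimizer itself, of why the FP32 copy matters: a small update that always rounds away in a BF16-only parameter is accumulated correctly in an FP32 master copy, whose BF16 cast is what gets sent to the next operator.

import torch

lr = 1e-3
grad = 0.004                                       # a small, constant gradient value
steps = 2000

w_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)   # BF16-only parameter
w_master = torch.tensor(1.0)                       # FP32 master copy

for _ in range(steps):
    update = torch.tensor(lr * grad)               # ~4e-6, far below BF16 spacing near 1.0
    w_bf16 -= update.bfloat16()                    # rounds back to 1.0 every step
    w_master -= update                             # FP32 keeps the accumulated updates

print("BF16-only parameter  :", w_bf16.item())               # still 1.0
print("FP32 master copy     :", w_master.item())              # ~0.992
print("BF16 copy for next op:", w_master.bfloat16().item())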
--enable-mixed-precision-ops
By default, an operator computes and returns in the same precision as its input. However, you might want certain operators to run in mixed precision without relying on GraphAMP. The --enable-mixed-precision-ops argument allows you to specify operators that run in mixed precision across the whole graph. Two types of operations are supported: GEMM-like operations (matmul and linear) and softmax.
$ compile --enable-mixed-precision-ops gemm softmax
For GEMM-like operations, --enable-mixed-precision-ops enables BF16 input and FP32 output. For softmax operations, it enables FP32 input and BF16 output.
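The sketch below shows the input/output datatype convention this selects, emulated in plain PyTorch; on the RDU the operators apply these precisions internally without explicit casts.

import torch
import torch.nn.functional as F

x = torch.randn(8, 128)                                    # FP32 activations
w = torch.randn(128, 128)

# GEMM-like op: BF16 inputs, FP32 output
gemm_out = torch.matmul(x.bfloat16().float(), w.bfloat16().float())
print(gemm_out.dtype)                                      # torch.float32

# Softmax: FP32 input (numerically stable reduction), BF16 output
sm_out = F.softmax(gemm_out, dim=-1).bfloat16()
print(sm_out.dtype)                                        # torch.bfloat16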