Compile with mixed precision (Beta)

Mixed precision combines the use of different numerical formats (such as FP32 and BF16) to reduce memory footprint and speed up large neural network workloads.
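To make the FP32/BF16 difference concrete, here is a small pure-Python sketch (not SambaFlow code) that converts a value to BF16. It relies on the standard BF16 layout: the same sign bit and 8-bit exponent as FP32, but only 7 explicit mantissa bits. A BF16 value is therefore the upper half of an FP32 bit pattern. The helper names are illustrative, and NaN/Inf handling is deliberately left out.

```python
import struct

def fp32_bits(x: float) -> int:
    """Return the 32-bit IEEE 754 pattern of x after rounding it to FP32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_fp32(bits: int) -> float:
    """Interpret a 32-bit pattern as an FP32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def to_bf16(x: float) -> float:
    """Round x to the nearest BF16 value (ties to even) and return it as a float.

    BF16 keeps the FP32 sign bit and 8-bit exponent but only the top 7
    mantissa bits, so it is the upper 16 bits of an FP32 pattern.
    NaN and Inf inputs are not handled in this sketch.
    """
    bits = fp32_bits(x)
    keep_lsb = (bits >> 16) & 1                # lowest mantissa bit that BF16 keeps
    rounded = bits + 0x7FFF + keep_lsb         # round-to-nearest-even on the dropped 16 bits
    return bits_to_fp32(rounded & 0xFFFF0000)  # zero the dropped lower half

for value in (1.0, 3.14159265, 1.00390625, 1e-3):
    print(f"{value!r:>12} -> {to_bf16(value)!r}")
```

BF16 needs 2 bytes per value instead of 4, which halves memory and bandwidth. The cost is precision: BF16 keeps only about 2 to 3 significant decimal digits, compared with about 7 for FP32. For example, pi becomes 3.140625.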

SambaFlow provides several convenience methods for enabling mixed precision in the model.

Graph Automatic Mixed Precision (GraphAMP)

The GraphAMP feature is available starting with SambaFlow 1.18. It gives you a convenient way to choose how aggressively the compiler downcasts operations in the model.

You use the --graphamp-preset compiler argument to toggle between different presets. The available options, shown in the following table, offer a tradeoff between accuracy and performance:

Table 1. GraphAMP options

| Option | Description |
|---|---|
| mp0 | GraphAMP is disabled, and no mixed-precision downcasting is applied to the model. NOTE: Unless the model has been hand-tuned for mixed precision, expect a performance hit, because with mp0 the model runs in FP32 precision. |
| mp1 (Recommended) | Downcasts the inputs of GEMM-like operators to BF16. |
| mp4 | Downcasts all inputs and operators of the model to BF16. |
| Default | Forces all inputs and operators of the model to BF16 until it encounters a floating-point conversion. |
NOTE: The mp2 and mp3 options will become available in a future SambaFlow release.
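The following pure-Python sketch models the mp0, mp1, and mp4 rows of Table 1 as a dtype-assignment pass over a tiny operator graph. It is a hypothetical illustration, not how the SambaFlow compiler is implemented. Op, apply_preset, and GEMM_LIKE are made-up names. The graph mirrors the matmul, relu, matmul, relu model used in the example below. The Default preset is not modeled.

```python
from dataclasses import dataclass

GEMM_LIKE = {"matmul", "linear"}  # the text names Matmul and Linear as GEMM-like

@dataclass
class Op:
    name: str
    kind: str
    inputs: list[str]  # names of producer ops or graph inputs

def apply_preset(ops: list[Op], graph_inputs: list[str], preset: str) -> dict:
    """Toy model of a GraphAMP preset: return {(src, dst): dtype} for every edge."""
    if preset not in ("mp0", "mp1", "mp4"):
        raise ValueError(f"unsupported preset: {preset}")
    produced = {name: "FP32" for name in graph_inputs}  # user code starts in FP32
    edges = {}
    for op in ops:  # ops are listed in topological order
        for src in op.inputs:
            if preset == "mp4" or (preset == "mp1" and op.kind in GEMM_LIKE):
                edges[(src, op.name)] = "BF16"  # downcast feeding this operator
            else:
                edges[(src, op.name)] = produced[src]
        # mp1 GEMMs take BF16 inputs but emit FP32; mp4 keeps everything in BF16
        produced[op.name] = "BF16" if preset == "mp4" else "FP32"
    return edges

# Same shape as the example model: matmul -> relu -> matmul -> relu
model = [
    Op("matmul_1", "matmul", ["a", "b"]),
    Op("act_1", "relu", ["matmul_1"]),
    Op("matmul_2", "matmul", ["act_1", "c"]),
    Op("out", "relu", ["matmul_2"]),
]
for preset in ("mp0", "mp1", "mp4"):
    print(preset, apply_preset(model, ["a", "b", "c"], preset))
```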

Example

Below is a simple example that we will use to explain what happens when you compile the model with different --graphamp-preset values.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        matmul_1 = torch.matmul(a, b)
        act_1 = F.relu(matmul_1)
        matmul_2 = torch.matmul(act_1, c)
        out = F.relu(matmul_2)
        return out

mymodel = Model()

GraphAMP mp0 mode

With mp0, GraphAMP is disabled. The compiler lowers the model exactly as it is written in the model code, so every tensor and operator in this example stays in FP32.

$ python mymodel.py compile --graphamp-preset mp0
flowchart LR
    a((a)) -->|FP32| matmul(Matmul)
    b((b)) -->|FP32| matmul
    matmul -->|FP32| relu(ReLU)
    relu -->|FP32| matmul2(Matmul)
    c((c)) -->|FP32| matmul2
    matmul2 -->|FP32| relu2(ReLU)
    relu2 -->|FP32| out((out))

GraphAMP mp1 mode

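With mp1, GraphAMP downcasts the inputs of GEMM-like operators, which in this example are the two Matmul operators. GraphAMP inserts a conversion (To) to BF16 on each Matmul input. Each Matmul then computes in mixed precision: it takes BF16 inputs and produces an FP32 output, so both ReLU operators still run in FP32.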

$ python mymodel.py compile --graphamp-preset mp1
Because only GEMM-like operators run in mixed precision, mp1 offers a conservative tradeoff between accuracy and performance. This is the recommended preset.
flowchart LR
    a((a)) -->|FP32| conversion(To)
    conversion:::changed -->|BF16| matmul(Matmul)
    b((b)) -->|FP32| conversion2(To)
    conversion2:::changed -->|BF16| matmul
    matmul -->|FP32| relu(ReLU)
    relu -->|FP32| conversion3(To)
    conversion3:::changed -->|BF16| matmul2(Matmul)
    c((c)) -->|FP32| conversion4(To)
    conversion4:::changed -->|BF16| matmul2
    matmul2 -->|FP32| relu2(ReLU)
    relu2 -->|FP32| out((out))
    classDef changed stroke:#f66,stroke-width:2px,stroke-dasharray: 5 5

GraphAMP mp4 mode

With mp4, GraphAMP aggressively downcasts all inputs and operators to BF16. The resulting model uses BF16 throughout.

$ python mymodel.py compile --graphamp-preset mp4
flowchart LR
    a((a)) -->|BF16| matmul(Matmul)
    b((b)) -->|BF16| matmul
    matmul -->|BF16| relu(ReLU)
    relu -->|BF16| matmul2(Matmul)
    c((c)) -->|BF16| matmul2
    matmul2 -->|BF16| relu2(ReLU)
    relu2 -->|BF16| out((out))

Default mode

If you don't provide a --graphamp-preset value, as in the following command, the compiler uses the default mode. In SambaFlow 1.18, the default mode is mp4 (discussed above).

$ python mymodel.py compile

In the SambaFlow 1.19 release, the default mode will become mp0.

Overrides with disable_graphamp()

SambaFlow provides context managers or decorators that disable mixed precision in specific regions of your model. You can then use GraphAMP to improve the performance of your model while still preserving floating-point precision for specific operator instances.

Inside these regions, GraphAMP does not downcast anything. Each operator runs in the datatype that you choose for it to maintain accuracy.

Use disable_graphamp() to wrap only the forward pass(es) of your network, including the loss computation(s). Backward operations run with the same datatype as forward operations.

Example:

import sambaflow.samba as samba
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        matmul_1 = torch.matmul(a, b)  # compiled with the global preset (mp1 here)
        act_1 = F.relu(matmul_1)
        with samba.session.disable_graphamp():
            # GraphAMP leaves this region alone; the explicit casts run matmul_2 in BF16
            matmul_2 = torch.matmul(act_1.bfloat16(), c.bfloat16())
            out = F.relu(matmul_2)
        return out

mymodel = Model()

$ python mymodel.py compile --graphamp-preset mp1

In this example, matmul_2 and the ReLU that follows it are inside the disable_graphamp() context manager, and the rest of the graph is compiled in mp1 mode. The resulting graph is shown below.

  • With mp1 applied to the whole model, both matmul_1 and matmul_2 would run in mixed precision.

  • In this example, the resulting graph has matmul_1 in mixed precision and matmul_2 in BF16 precision.

Resulting model:

flowchart LR
    a((a)) -->|FP32| conversion(To)
    conversion:::changed -->|BF16| matmul(Matmul)
    b((b)) -->|FP32| conversion2(To)
    conversion2:::changed -->|BF16| matmul
    matmul -->|FP32| relu(ReLU)
    relu -->|FP32| conversion3(To)
    c((c)) -->|FP32| conversion4(To)
    subgraph "samba.session.disable_graphamp()"
        conversion3 -->|BF16| matmul2(Matmul)
        conversion4 -->|BF16| matmul2
        matmul2 -->|BF16| relu2(ReLU)
    end
    relu2 -->|BF16| out((out))
    classDef changed stroke:#f66,stroke-width:2px,stroke-dasharray: 5 5

Operator-specific behavior

  • Matmul and Linear. These are GEMM-like operators that support internal mixed precision compute and can take BF16 inputs and output FP32 tensors directly.

  • See the SambaFlow reference for supported PyTorch operators.
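To see why it matters that GEMM-like operators accumulate in a wider type, the sketch below compares two pure-Python matrix multiplications on BF16-rounded inputs. One keeps a wide accumulator and returns the full-precision result. The other rounds every product and running sum back to BF16. This is a numerical illustration only: the helper names are hypothetical, and Python's 64-bit float stands in for the FP32 accumulator.

```python
import math

def bf16(x: float) -> float:
    """Round x to BF16 precision (8 significant bits, ties to even).

    Normal-range values only; subnormals, overflow, and NaN are ignored.
    """
    m, e = math.frexp(x)  # x == m * 2**e with 0.5 <= |m| < 1
    return math.ldexp(round(m * 256) / 256, e)

def matmul_mixed(a, b):
    """BF16 inputs with a wide accumulator, like a Matmul that takes BF16
    inputs and outputs FP32 (Python floats stand in for FP32)."""
    a16 = [[bf16(v) for v in row] for row in a]
    cols = [[bf16(v) for v in col] for col in zip(*b)]
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a16]

def matmul_all_bf16(a, b):
    """Every product and running sum is rounded back to BF16."""
    a16 = [[bf16(v) for v in row] for row in a]
    cols = [[bf16(v) for v in col] for col in zip(*b)]
    out = []
    for row in a16:
        out_row = []
        for col in cols:
            acc = 0.0
            for x, y in zip(row, col):
                acc = bf16(acc + bf16(x * y))
            out_row.append(acc)
        out.append(out_row)
    return out

# Hypothetical 1x512 @ 512x1 product of ones; the exact result is 512.
a = [[1.0] * 512]
b = [[1.0] for _ in range(512)]
print(matmul_mixed(a, b), matmul_all_bf16(a, b))
```

Once the running sum reaches 256, adjacent BF16 values are 2 apart. Adding 1 gives 257, which is a tie and rounds back to 256 (ties go to even), so the all-BF16 result stays at 256. The wide accumulator reaches the exact 512.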

Additional precision arguments

You can customize compiler behavior with the following compiler arguments.

--tiling-accum

The accumulation operation associated with tiling can be sensitive to the precision setting. By default, accumulation happens in FP32. Use --tiling-accum bf16sr to switch accumulation to BF16 with stochastic rounding, which can improve performance at the cost of some accuracy.
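The following pure-Python sketch shows why stochastic rounding matters when an accumulator is kept in BF16. It sums hypothetical tile partial results three ways: in FP32, in BF16 with round-to-nearest (shown only for comparison, not a --tiling-accum option), and in BF16 with stochastic rounding. The partial values and helper names are assumptions for illustration; this is not SambaFlow's implementation.

```python
import math
import random
import struct

def to_fp32(x: float) -> float:
    """Round a Python float to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def bf16_rne(x: float) -> float:
    """Round to the nearest BF16 value (8 significant bits, ties to even)."""
    m, e = math.frexp(x)
    return math.ldexp(round(m * 256) / 256, e)

def bf16_sr(x: float, rng: random.Random) -> float:
    """Stochastically round to BF16: round up with probability equal to the
    fractional distance to the next BF16 value, so the expected result is x."""
    m, e = math.frexp(x)
    scaled = m * 256
    low = math.floor(scaled)
    round_up = rng.random() < scaled - low
    return math.ldexp((low + round_up) / 256, e)

def accumulate(partials: list[float], mode: str, seed: int = 0) -> float:
    """Sum tile partial results, rounding the accumulator after every add."""
    rng = random.Random(seed)
    acc = 0.0
    for p in partials:
        if mode == "fp32":
            acc = to_fp32(acc + p)
        elif mode == "bf16":  # round-to-nearest, for comparison only
            acc = bf16_rne(acc + p)
        elif mode == "bf16sr":
            acc = bf16_sr(acc + p, rng)
        else:
            raise ValueError(f"unknown mode: {mode}")
    return acc

partials = [0.1] * 1000  # hypothetical per-tile partial sums; exact total is 100
for mode in ("fp32", "bf16", "bf16sr"):
    print(mode, accumulate(partials, mode))
```

With round-to-nearest, the BF16 accumulator stops growing at 32. At that point the spacing between BF16 values (0.25) is more than twice the 0.1 increment, so every add rounds back down. Stochastic rounding also rounds at every step, but each rounding is unbiased, so the sum stays close to 100 on average.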

$ compile --tiling-accum ["fp32" | "bf16sr"]

--weight-grad-reduce

Weight gradient reduction combines the weight gradient values that are computed in parallel. By default, weight gradient reduction happens in FP32. Use --weight-grad-reduce bf16sr to switch the reduction to BF16 with stochastic rounding, which can improve performance at the cost of some accuracy.

$ compile --weight-grad-reduce ["fp32" | "bf16sr"]
The --weight-grad-reduce argument is applied only to linear weight gradient updates.
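A minimal pure-Python sketch of such a reduction follows. It assumes a simple pairwise (tree) reduction across workers; SambaFlow's actual reduction topology is not described here, and reduce_grads, bf16_sr, and to_fp32 are hypothetical helpers. The sketch averages many stochastically rounded reductions to show the property that makes bf16sr usable: each rounding is unbiased, so errors do not pile up in one direction. A real reduction runs only once.

```python
import math
import random
import struct

def to_fp32(x: float) -> float:
    """Round a Python float to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def bf16_sr(x: float, rng: random.Random) -> float:
    """Stochastically round x to BF16 (8 significant bits); unbiased in expectation."""
    m, e = math.frexp(x)
    scaled = m * 256
    low = math.floor(scaled)
    return math.ldexp((low + (rng.random() < scaled - low)) / 256, e)

def reduce_grads(worker_grads, mode, rng):
    """Pairwise (tree) reduction of per-worker gradient vectors.

    Each partial sum is rounded to FP32 or stochastically rounded to BF16,
    mirroring the two values listed for --weight-grad-reduce.
    """
    if mode not in ("fp32", "bf16sr"):
        raise ValueError(f"unknown mode: {mode}")

    def rnd(v: float) -> float:
        return to_fp32(v) if mode == "fp32" else bf16_sr(v, rng)

    level = [list(g) for g in worker_grads]
    while len(level) > 1:
        paired = [[rnd(x + y) for x, y in zip(level[i], level[i + 1])]
                  for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:  # an unpaired vector moves up unchanged
            paired.append(level[-1])
        level = paired
    return level[0]

# Eight hypothetical workers that computed the same local gradient.
grads = [[0.01, -0.003, 1e-4]] * 8
exact = [8 * g for g in grads[0]]
fp32_sum = reduce_grads(grads, "fp32", random.Random(0))
trials = [reduce_grads(grads, "bf16sr", random.Random(s)) for s in range(2000)]
sr_mean = [sum(t[i] for t in trials) / len(trials) for i in range(len(exact))]
print(exact, fp32_sum, sr_mean)
```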

--fp32-params

By default, if the forward pass for a particular operator has BF16 inputs, the backward pass for that operator produces BF16 gradients. Gradient values with small magnitudes may not be representable in BF16. They underflow, and the updates to the corresponding parameters are lost.

The --fp32-params argument addresses this issue by having the optimizer output two copies of each weight, one in BF16 and one in FP32. The BF16 copy is sent to the next operator, while the FP32 copy is used to update the trainable parameter.

When you use full-precision weight update mode (--fp32-params), the model parameters are expected to be initialized in BF16.
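The following pure-Python sketch shows the kind of loss that an FP32 weight copy prevents: a small update that is rounded away when the weight itself is stored in BF16. It is a hypothetical illustration. sgd, bf16, and to_fp32 are made-up helpers, and plain SGD stands in for whatever optimizer you use. It models only the rounding of the updated weight back to BF16, not the compiler's gradient path.

```python
import math
import struct

def bf16(x: float) -> float:
    """Round x to the nearest BF16 value (8 significant bits, ties to even)."""
    m, e = math.frexp(x)
    return math.ldexp(round(m * 256) / 256, e)

def to_fp32(x: float) -> float:
    """Round a Python float to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def sgd(steps: int, lr: float, grad: float, fp32_params: bool):
    """Run plain SGD (w -= lr * grad) on one weight initialized in BF16.

    Returns (bf16_weight, fp32_master). Without an FP32 master copy, the
    updated weight is rounded back to BF16 after every step.
    """
    w_bf16 = bf16(1.0)  # parameter initialized in BF16
    w_fp32 = w_bf16     # FP32 master copy (only updated when fp32_params is True)
    for _ in range(steps):
        if fp32_params:
            w_fp32 = to_fp32(w_fp32 - lr * grad)  # update the FP32 copy
            w_bf16 = bf16(w_fp32)                 # BF16 copy for the next operator
        else:
            w_bf16 = bf16(w_bf16 - lr * grad)     # lost if below half a BF16 step
    return w_bf16, w_fp32

print(sgd(100, 1e-3, 1.0, fp32_params=False))
print(sgd(100, 1e-3, 1.0, fp32_params=True))
```

Just below 1.0, BF16 values are 2^-8 (about 0.0039) apart, so an update of 0.001 rounds back to 1.0 every time. The FP32 master copy accumulates all 100 updates. Its BF16 copy tracks it at 0.8984375, the BF16 value nearest 0.9.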