Use multigraph to partition models

Multigraph is a feature that lets you partition a model into individual graphs so you can run each graph separately. This flexibility is useful when the model is composed of separate PyTorch modules and you want fine-grained control over when to run them. For example, in a generative adversarial network, you can define the generator and the discriminator as their own graphs and run them separately. Currently, multigraph supports only inference.

How to use Multigraph

To compile a multigraph PEF:

  1. Trace each graph with samba.trace_multigraph() to retrieve handles to the graphs’ output tensors.

  2. Use the input tensors, output tensors, and optimizers to construct the FwdGraph objects that partition the model. Create each FwdGraph object with a unique name that you will use to refer to that graph in the future.

  3. Register the graphs with SambaSession via samba.session.add_graph().

  4. Compile the registered graphs into a single PEF with samba.session.compile_multigraph().

To run a multigraph PEF:

  1. Use the same trace_multigraph() and add_graph() calls that you used at compile time (steps 1-3) to define the same graphs.

  2. Call samba.session.init_multigraph_runtime(args.pef) to load the PEF, initialize the SambaRuntime backend, and transfer the tensors to the device.

  3. Use the normal samba.session.run() API to run the PEF, and use the graph_name parameter to specify which graph to run.

Basic Example

Below is a simple example of compiling and running an inference app with multigraph.

import argparse

import torch.nn as nn

import sambaflow.samba as samba
from sambaflow.samba.utils import trace_multigraph
from sambaflow.samba.graph import FwdGraph

# minimal argument handling for this example; a real app may define these
# arguments differently
parser = argparse.ArgumentParser()
parser.add_argument('command', choices=['compile', 'test'])
parser.add_argument('--batch-size', dest='batch_size', type=int, default=8)
parser.add_argument('--pef', type=str, default='')
args = parser.parse_args()

# define the PyTorch modules and convert them to Samba models
model0 = nn.Linear(10, 10)
model1 = nn.Linear(10, 10)
samba.from_torch_model(model0)
samba.from_torch_model(model1)

# create inputs to each graph
inputs0 = samba.randn(args.batch_size, 10, name="ipt0")
inputs1 = samba.randn(args.batch_size, 10, name="ipt1")

# trace each graph and get handles to its output tensors
outputs0 = trace_multigraph(model0, inputs0)
outputs1 = trace_multigraph(model1, inputs1)

# wrap the traced modules in FwdGraph objects with unique names
graph0 = FwdGraph(inputs0, outputs0, name="linear0")
graph1 = FwdGraph(inputs1, outputs1, name="linear1")

# register the SambaGraphs with SambaSession
samba.session.add_graph(graph0)
samba.session.add_graph(graph1)

if args.command == 'compile':
    # compile the registered graphs
    samba.session.compile_multigraph(name='multigraph_inference')
elif args.command == 'test':
    # initialize the Runtime context
    samba.session.init_multigraph_runtime(args.pef)

    # run graph0 by its name "linear0"
    samba_out0 = samba.session.run(inputs0, output_tensors=model0.output_tensors, graph_name="linear0")[0]

    # run graph1 by its name "linear1"
    samba_out1 = samba.session.run(inputs1, output_tensors=model1.output_tensors, graph_name="linear1")[0]

Weight Tying

If you want multiple graphs to share the same weights, you can tie the weights together by assigning one weight to the other. For example,

model1.linear.weight = model0.linear.weight

will tie model0’s linear weight and model1’s linear weight together. Both parameters will share the same location in device memory, so updates to model0’s weights will also be reflected in model1’s weights and vice versa.
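
For instance, with the two Linear modules from the basic example above, the parameters could be tied before the graphs are converted and traced. The snippet below is a minimal sketch that assumes the tie is made before tracing, so both graphs point at the same parameter tensors:

# tie model1's parameters to model0's before conversion and tracing,
# so both graphs share a single copy of the weights in device memory
model1.weight = model0.weight
model1.bias = model0.bias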

Cached Inference

During inference in a model with multi-head attention, such as Llama 2, the attention layer computes key and value (K and V) tensors for each token it generates. On a GPU, the newly generated key and value tensors are appended to the keys and values computed from the prompt. The RDU, however, does not support the dynamic tensor shapes that this concatenation requires, so we use multigraph instead.

How SambaNova does Cached Inference with Multigraph

SambaNova implements caching by keeping two copies of the model: a cache graph and a no-cache graph. The two copies share the same weights and biases via weight tying. At the beginning of inference, the no-cache graph computes the K and V values from the prompt text to build the cache and generate the first token. For each subsequent token, the cache graph generates the K and V values for that token while reusing the cached prompt K and V values. In SambaNova’s implementation, the K and V tensors are padded to the maximum sequence length. When the model generates a new token, it writes that token’s K and V values to the appropriate index in the K and V tensors, replacing the padding.
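
The cache update described above can be pictured with plain tensor operations: the K and V caches are preallocated to the maximum sequence length, and each new token's values are written into a specific index instead of being concatenated. The snippet below is an illustrative sketch only; the shapes and the update_cache helper are hypothetical, not SambaNova's implementation:

import torch

max_seq_len, num_heads, head_dim = 2048, 32, 128

# K and V caches are padded (preallocated) to the maximum sequence length,
# so their shapes stay static for the whole generation loop
k_cache = torch.zeros(max_seq_len, num_heads, head_dim)
v_cache = torch.zeros(max_seq_len, num_heads, head_dim)

def update_cache(step, new_k, new_v):
    # write the K/V values of the token generated at position `step`
    # into the fixed-size caches, overwriting the padding at that index
    k_cache[step] = new_k
    v_cache[step] = new_v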

Models such as Llama 2, GPT 13B, and Bloom implement cached inference with multigraph on the RDU. However, not all models with multi-head attention use cached inference with multigraph on the RDU.