Model functions and changes

As this document shows, only minor modifications to your PyTorch code are necessary. This conversion example includes the required changes and also some changes that improve the robustness of the code, such as the addition of a main() function.

Overview of required changes

When you review the SambaFlow versions of the code for this model, you will notice that they look different from the original. This difference is purely aesthetic; it is meant to point out the SambaFlow additions more clearly.

The required changes are:

  • The SambaFlow Python imports.

  • The shape of the input tensors.

  • The SambaFlow tensor conversion methods.

If you’re converting your own model, there’s no need to refactor and reorganize your code. Make the changes in the right spots and your model works in SambaFlow.

At the code level, additions include:

  • common_app_driver() (deprecated utility function for compile)

  • utils.trace_graph()

  • samba.from_torch_model_()

  • samba.from_torch_tensor() and samba.to_torch()

  • samba.optim.AdamW()

  • utils.argparser.parse_app_args()

Most of the differences between the original code and the SambaFlow code deal with code readability and modularity.


As a first step, we import SambaFlow libraries so the code can run on a SambaNova system. See the SambaFlow API Reference for details.

Imports required by SambaFlow

import sambaflow
import sambaflow.samba as samba
import sambaflow.samba.optim as optim
import sambaflow.samba.utils as utils

from sambaflow.samba.utils.argparser import parse_app_args
from sambaflow.samba.utils.common import common_app_driver
from sambaflow.samba.sambaloader import SambaLoader

  • sambaflow is the base package.

  • samba corresponds to PyTorch torch.

  • samba.optim is similar to PyTorch optim and contains the optimizers available in SambaFlow.

  • samba.utils contains various SambaFlow-specific utilities for graph tracing, compiling and measuring performance, and so on.

  • samba.utils.argparser is similar to Python’s argparse library, but intended for requirements particular to SambaFlow.

    parse_app_args enables argument parsing supporting the SambaFlow execution modes (compile, run, test, measure performance). Users can define their own arguments and pass those into SambaFlow.

  • sambaflow.samba.utils.common

    common_app_driver is a deprecated utility function. It provides a single interface for compiling a model, and several means of measuring a model’s performance, such as measure-cpu, measure-gpu, and measure-performance. Replacement utilities are in progress.

  • sambaflow.samba.sambaloader

    SambaLoader is a wrapper around the PyTorch DataLoader and is built to take advantage of the SambaNova architecture to more efficiently parallelize load operations with graph/compute operations. It also automatically converts torch tensors into SambaTensors.

Additional imports

In this conversion example, there are several Python and PyTorch imports that are typically used for building a CNN. These imports can be left as-is. In fact, you’ll likely want the same imports for any model you bring in: SambaFlow is additive.

The following imports are required for this example:

import sys
import argparse
from typing import Tuple

SambaFlow can transparently handle many native PyTorch methods. During compilation, those methods are optimized for the SambaNova RDU to take advantage of the SambaNova Reconfigurable Dataflow Architecture.

Define the model

Here’s the modified ConvNet class, with comments that explain changes.

ConvNet class
class ConvNet(nn.Module):
    """
    Instantiate a 4-layer CNN for MNIST Image Classification.

    In SambaNova, we can define the loss function as a part of the model
    and include it in the forward method to be computed.

    Typical SambaFlow usage example:

    model = ConvNet()
    optimizer = ...
    inputs = ...
    if args.command == "run":
        utils.trace_graph(model, inputs, optimizer, pef=args.pef, mapping=args.mapping)
        train(args, model)
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)
        self.criterion = nn.CrossEntropyLoss()  # Add loss function to model

    def forward(self, x: torch.Tensor, labels: torch.Tensor):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)     # Flatten to (batch, 7 * 7 * 64)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        loss = self.criterion(out, labels)     # Compute loss
        return loss, out                       # Return loss and output

The model definition differs only slightly from the original version: we've added a loss function to the model definition. Everything else remains as pure PyTorch.

  • Layers. Functions are run layer by layer. Our code example doesn’t change the original PyTorch nn methods when defining the layers (Sequential(), Conv2d(), ReLU(), MaxPool2d(), Dropout(), and Linear()).

  • Loss function. The only difference between this code and the original code is the loss function:

    • We define the loss function, nn.CrossEntropyLoss(), directly as part of the model.

    • We include the loss function in the __init__() and forward() methods.

      When our example model is trained, the forward() method computes and returns both the output tensors and the loss.

      With this change, loss is computed directly on the RDU. That gives us a performance boost. If you don’t include a loss function, the output tensors of the forward function can be passed out and the loss must be computed externally on the host CPU. This results in lower performance, but it does allow a user to leverage custom loss functions.

  • forward() method. The forward() method is custom. In addition to computing the output, the method also computes and returns the loss.
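As a rough, framework-free illustration of this pattern, the sketch below shows a forward() that returns both the loss and the output. TinyModel, its doubling "layer", and the cross_entropy helper are hypothetical stand-ins written in plain Python; they are not SambaFlow or PyTorch code. The helper averages over the batch, which matches the default reduction of nn.CrossEntropyLoss.

```python
import math
from typing import List, Tuple


def cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean cross-entropy over a batch, computed from raw class scores."""
    total = 0.0
    for row, label in zip(logits, labels):
        # log-sum-exp with the maximum subtracted for numerical stability
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_norm - row[label]
    return total / len(labels)


class TinyModel:
    """Hypothetical stand-in for a model whose forward() also returns the loss."""

    def forward(self, x: List[List[float]], labels: List[int]) -> Tuple[float, List[List[float]]]:
        out = [[2.0 * v for v in row] for row in x]  # placeholder "layer"
        loss = cross_entropy(out, labels)            # loss computed inside forward()
        return loss, out                             # same return order as ConvNet.forward()


loss, out = TinyModel().forward([[0.0, 0.0]], [0])
print(round(loss, 4))  # two equally likely classes: ln(2), printed as 0.6931
```

Because the caller receives the loss and the output together, it does not need a separate loss call after the forward pass.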

If you use a PyTorch function that the SambaFlow API does not yet support, the function is automatically computed on CPU instead of RDU, resulting in slower performance.

Capture user arguments with add_user_args

The add_user_args function captures and encapsulates user-defined command-line arguments so that they can be more easily passed to SambaFlow via the samba.utils.argparser.parse_app_args() method.

The arguments to add_user_args() are the model’s hyperparameters and the two path variables for storing the data and the model.

The SambaFlow compiler has defaults for most of the arguments. If you don’t provide a value for an argument, the SambaFlow default is used.
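As a generic illustration of this pattern, the sketch below registers hyperparameters and paths on a standard argparse parser. It uses only Python's argparse, not the SambaFlow parser. The function name add_example_args, the flag names, and the "data" default are assumptions chosen to match the attribute names used in this example (args.batch_size, args.num_epochs, and so on); the other defaults come from the help strings in the example.

```python
import argparse


def add_example_args(parser: argparse.ArgumentParser) -> None:
    """Register hyperparameters and paths on an existing parser (hypothetical flag names)."""
    parser.add_argument("--batch-size", type=int, default=100,
                        help="input batch size for training (default: 100)")
    parser.add_argument("--num-epochs", type=int, default=6,
                        help="number of epochs to train (default: 6)")
    parser.add_argument("--num-classes", type=int, default=10,
                        help="number of classes in dataset (default: 10)")
    parser.add_argument("--learning-rate", type=float, default=0.001,
                        help="learning rate (default: 0.001)")
    # The "data" default is an assumption made for this sketch.
    parser.add_argument("--data-path", type=str, default="data",
                        help="Download location for MNIST data")
    parser.add_argument("--model-path", type=str, default="model",
                        help="Save location for model")


parser = argparse.ArgumentParser()
add_example_args(parser)
args = parser.parse_args(["--num-epochs", "2"])
# argparse turns "--num-epochs" into the attribute args.num_epochs
print(args.num_epochs, args.batch_size, args.learning_rate)
```

Passing a function like this to a parser, rather than building the parser inside it, lets user-defined flags and framework-defined flags share one parser.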

add_user_args() function
def add_user_args(parser: argparse.ArgumentParser) -> None:
       help="input batch size for training (default: 100)",
       help="number of epochs to train (default: 6)",
       help="number of classes in dataset (default: 10)",
       help="learning rate (default: 0.001)",
       help="Download location for MNIST data",
   )  # From DATA_PATH
       "--model-path", type=str, default="model", help="Save location for model"

Generate SambaTensors with get_inputs()

To properly trace the PyTorch model graph and map it onto an RDU, the compiler requires tensors of the same shape as those that are passed to the forward() method during training. The data in these tensors isn’t important to the mapping. However, the compiler must be able to determine how the tensors change shape as they flow from the input to the output of the graph. This helps the compiler generate a PEF file that optimally lays out your model on the RDU.
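To make the idea of shape propagation concrete, here is a hand-written sketch that follows one batch through the layers of the example ConvNet. It is plain Python arithmetic based on the standard output-size formula for convolution and pooling windows; it is not how the SambaFlow compiler traces the graph, and conv_out and trace_convnet_shapes are illustrative names.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size of a convolution or pooling window along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_convnet_shapes(batch_size: int, side: int = 28) -> List[Tuple[str, tuple]]:
    """Return (stage, shape) pairs for a batch flowing through the example ConvNet."""
    shapes = [("input", (batch_size, 1, side, side))]
    channels = 1
    for name, out_channels in (("layer1", 32), ("layer2", 64)):
        side = conv_out(side, kernel=5, stride=1, padding=2)  # Conv2d keeps the size
        side = conv_out(side, kernel=2, stride=2)             # MaxPool2d halves it
        channels = out_channels
        shapes.append((name, (batch_size, channels, side, side)))
    shapes.append(("flatten", (batch_size, channels * side * side)))
    shapes.append(("fc1", (batch_size, 1000)))
    shapes.append(("fc2", (batch_size, 10)))
    return shapes


for stage, shape in trace_convnet_shapes(batch_size=100):
    print(stage, shape)
```

The flattened width, 64 * 7 * 7 = 3136, is the input size of fc1 in the model definition. Only the shapes matter for this calculation, which is why random placeholder values are sufficient for tracing.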

In the case of the MNIST data input, which we use in this example, two tensors are needed:

  • One tensor that matches the shape of an MNIST image. We use samba.randn below to represent that.

  • One tensor that matches the shape of an MNIST label. We use samba.randint below because the label is an integer from 0 to 9.

The get_inputs() function returns a tuple of the tensors.

get_inputs() function
def get_inputs(args: argparse.Namespace) -> Tuple[samba.SambaTensor, samba.SambaTensor]:
    """
    Generates random SambaTensors in the same shape as MNIST image tensors and labels.

    In order to properly compile a PEF and trace the model graph, SambaFlow requires a SambaTensor
    that is the same shape as the input torch tensors, allowing the graph to be optimally
    mapped onto an RDU.

    Args:
        args: User- and system-defined command line arguments

    Returns:
        A tuple of SambaTensors with random values in the same shape as MNIST image tensors
        and labels.
    """
    dummy_image = (
        samba.randn(args.batch_size, 1, 28, 28, name="image", batch_dim=0),
        samba.randint(args.num_classes, (args.batch_size,), name="label", batch_dim=0),
    )

    return dummy_image

The function includes the samba.randn() and samba.randint() methods.

These methods are functionally identical to their PyTorch counterparts, but return SambaTensors rather than torch tensors. SambaTensors are wrappers for torch tensors and have additional data members and methods to support the SambaNova RDU architecture.

In order to place the model graph onto the RDU, we have to know how many PCUs (compute units) and PMUs (memory units) are needed to optimally run the model. For that, we have to tell the compiler the shape of the input tensors; the actual values in the tensors aren’t important. This is why we create dummy_image. A user can determine the shape of the input tensor(s) from analysis of the dataset or from the model’s hyperparameters (e.g., the height and width of an input image).

  • The name and batch_dim data members don’t exist in the original PyTorch implementation.

    • name is the name of the tensor.

    • batch_dim is the batch dimension.

Methods for getting and setting RDU device memory and for syncing data between an RDU and host CPU are not used in this example.

See the SambaFlow API Reference for details.

Load data with prepare_dataloader()

The prepare_dataloader() function loads and transforms the input data and returns SambaLoaders that wrap your original PyTorch DataLoaders. SambaLoader lets the loading process better leverage the SambaNova parallel architecture, and it converts the torch tensors that the DataLoader produces into SambaTensors. The function is almost purely PyTorch, but it requires minor changes because the RDU works with SambaTensors.

At minimum, a SambaLoader needs:

  • A DataLoader passed in via the dataloader parameter.

  • A list of names to give to the tensors via the names parameter (all SambaTensors are named, either by the user or automatically by SambaFlow).

A less efficient way of achieving this conversion is to explicitly call samba.from_torch_tensor() on the torch tensors returned from the PyTorch DataLoaders. Using SambaLoader instead is recommended.
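The wrapper idea can be sketched in plain Python. The NamedLoader and NamedItem classes below are hypothetical and are not SambaLoader's implementation; they only illustrate a loader that wraps another iterable and attaches one name to each element of every batch, in the same (loader, names) argument order used in this example.

```python
from typing import Any, Iterator, List, NamedTuple, Sequence, Tuple


class NamedItem(NamedTuple):
    """Hypothetical stand-in for a named tensor: a name plus the wrapped data."""
    name: str
    data: Any


class NamedLoader:
    """Wraps a sequence of batches and attaches one name per batch element."""

    def __init__(self, loader: Sequence[Tuple[Any, ...]], names: List[str]):
        self.loader = loader
        self.names = names

    def __len__(self) -> int:
        # Delegate to the wrapped loader so len() still reports steps per epoch
        return len(self.loader)

    def __iter__(self) -> Iterator[Tuple[NamedItem, ...]]:
        for batch in self.loader:
            if len(batch) != len(self.names):
                raise ValueError("expected one name per batch element")
            yield tuple(NamedItem(n, d) for n, d in zip(self.names, batch))


batches = [([0.1, 0.2], [7, 3]), ([0.3, 0.4], [1, 9])]  # (images, labels) stand-ins
for image, label in NamedLoader(batches, ["image", "label"]):
    print(image.name, label.name, label.data)
```

Doing the conversion inside the loader keeps the training loop free of per-batch conversion calls, which is the convenience the text describes.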

prepare_dataloader() function
def prepare_dataloader(args: argparse.Namespace) -> Tuple[SambaLoader, SambaLoader]:
    """
    Transforms MNIST input to tensors and creates training/test dataloaders.

    Downloads the MNIST dataset (if necessary); splits the data into training and test sets;
    transforms the data to tensors; then creates torch DataLoaders over those sets.
    Torch DataLoaders are wrapped in SambaLoaders.

    Args:
        args (argparse.Namespace): User- and system-defined command line arguments

    Returns:
        A tuple of SambaLoaders over the training and test sets.
    """

    # Transform the raw MNIST data into PyTorch Tensors, which will be converted to SambaTensors
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),  # normalize the MNIST data
        ]
    )

    # Get the train & test data (images and labels) from the MNIST dataset
    train_dataset = datasets.MNIST(
        root=args.data_path, train=True, transform=transform, download=True
    )
    test_dataset = datasets.MNIST(root=args.data_path, train=False, transform=transform)

    # Set up the train & test data loaders (input pipeline)
    train_loader = DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, shuffle=True
    )
    test_loader = DataLoader(
        dataset=test_dataset, batch_size=args.batch_size, shuffle=False
    )

    # Create SambaLoaders
    sn_train_loader = SambaLoader(train_loader, ["image", "label"])
    sn_test_loader = SambaLoader(test_loader, ["image", "label"])

    return sn_train_loader, sn_test_loader

Train the model with train()

The train() method contains the training loop for the model. The code is similar to PyTorch.

train() method
def train(args: argparse.Namespace, model: nn.Module) -> None:
    """
    Trains the model.

    Prepares and loads the data, then runs the training loop with the hyperparameters specified
    by the input arguments.  Calculates loss and accuracy over the course of training.

    Args:
        args (argparse.Namespace): User- and system-defined command line arguments
        model (nn.Module): ConvNet model
    """

    sn_train_loader, _ = prepare_dataloader(args)
    hyperparam_dict = {"lr": args.learning_rate}

    total_step = len(sn_train_loader)
    loss_list = []
    acc_list = []

    for epoch in range(args.num_epochs):
        avg_loss = 0
        for i, (images, labels) in enumerate(sn_train_loader):
            # Run the full training step: forward pass, backward pass, and optimization.
            # The SambaLoader already yields SambaTensors, so no explicit conversion
            # from torch tensors is needed here.
            loss, outputs = samba.session.run(
                input_tensors=(images, labels),
                output_tensors=model.output_tensors,
                hyperparam_dict=hyperparam_dict,
            )

            # Convert SambaTensors back to torch tensors to calculate accuracy
            loss, outputs = samba.to_torch(loss), samba.to_torch(outputs)

            # Track the accuracy
            total = labels.size(0)
            _, predicted = torch.max(outputs, 1)
            correct = (predicted == labels).sum().item()
            acc_list.append(correct / total)

            if (i + 1) % 100 == 0:
                print(
                    "Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%".format(
                        epoch + 1,
                        args.num_epochs,
                        i + 1,
                        total_step,
                        torch.mean(loss).item(),
                        (correct / total) * 100,
                    )
                )

Here’s how the function works:

  1. The inner training loop runs over the enumerated samples that are generated by the SambaLoader that is created by the prepare_dataloader() function. In this example, we are only using the SambaLoader that generates training samples.

  2. The SambaTensors are then passed as input to samba.session.run(). This method performs the entire training pass, from the forward pass all the way to the backward pass and optimization.

    A Session is a SambaFlow object that contains the variables and methods that are needed to compile and run a model on an RDU. In SambaFlow parlance, to run a model means to train it. The model object is created during compilation. samba.session.run() takes in two key arguments: input_tensors and output_tensors.

    • input_tensors are the data on which the model is to be trained (what we get from prepare_dataloader()).

    • output_tensors capture the output shape that is generated by model compilation. Here’s how it works:

      When SambaFlow compiles a model, it generates a dataflow graph, which is similar to a PyTorch computational graph. To run the model, that graph must be traced before it can be placed onto the RDU. SambaFlow must know about the output shape that is generated by model compilation so that it can terminate the trace and map the graph onto the RDU. This output shape is captured in the output_tensors argument.

      What is actually contained in model.output_tensors is the output of forward(). Thus, the output of samba.session.run() will also be that of forward().

  3. To track the progress of model training, we output the loss and accuracy per epoch. This is standard practice and SambaFlow doesn’t change that. However, progress tracking should be run on a CPU, not an RDU, so we use the method samba.to_torch() to convert loss and outputs to torch tensors. The standard torch functions can then be applied to loss and outputs on the CPU.
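The accuracy bookkeeping in step 3 amounts to taking the highest-scoring class in each output row and comparing it with the label, which is what the indices returned by torch.max(..., 1) provide. A minimal sketch in plain Python, using nested lists in place of tensors; argmax and batch_accuracy are illustrative helper names, not part of SambaFlow or PyTorch:

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score (the first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def batch_accuracy(outputs: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of rows whose highest-scoring class equals the label."""
    predicted = [argmax(row) for row in outputs]
    correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return correct / len(labels)


outputs = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9]]
labels = [1, 0, 1]
# Predictions are [1, 0, 2], so two of the three rows are correct
print("Accuracy: {:.2f}%".format(batch_accuracy(outputs, labels) * 100))  # Accuracy: 66.67%
```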

Tie the pieces together with main()

In contrast to the original code, our code includes a main() function for more flexibility. The main() function is called to initialize the model, the data, optimizers, arguments, etc. and then kick off compilation and training.

main() function
def main(argv):

    args = parse_app_args(argv=argv, common_parser_fn=add_user_args)

    # Create the CNN model
    model = ConvNet()

    # Convert model to SambaFlow (SambaTensors)
    samba.from_torch_model_(model)

    # Create optimizer
    # Note that SambaFlow currently supports AdamW, not Adam, as an optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # Normally, we'd define a loss function here, but with SambaFlow, it can be defined
    # as part of the model, which we have done in this case

    # Dummy SambaTensor
    inputs = get_inputs(args)

    # The common_app_driver() handles model compilation and various other tasks, e.g.,
    # measure-performance.  Running (training) a model must be done explicitly
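    if args.command == "run":
        utils.trace_graph(model, inputs, optimizer, pef=args.pef, mapping=args.mapping)
        train(args, model)
    else:
        common_app_driver(
            args=args,
            model=model,
            inputs=inputs,
            optim=optimizer,
            name=model.__class__.__name__,
            init_output_grads=not args.inference,
            app_dir=utils.get_file_dir(__file__),
        )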

Components of model training

The main() function goes through these steps to get the model trained with SambaFlow:

  1. The process begins by processing command line arguments. We use samba.utils.argparser.parse_app_args() to capture the necessary arguments from the command line.

    The common_parser_fn variable is used to pass user-defined arguments to the SambaFlow backend. We created the add_user_args() function to make that possible.

  2. Next, we create the PyTorch model in the typical way.

    We then call the samba.from_torch_model_() method, which is part of the Session library, to convert our PyTorch model into a SambaFlow model. The method goes through the computational graph recursively and in place, converting all the torch tensors into SambaTensor instances.

    While samba.from_torch_model_() and samba.from_torch_tensor() look similar, they are very different and are not interchangeable.

    • samba.from_torch_model_() is a Session method and converts models.

    • samba.from_torch_tensor() is a SambaTensor method and converts only a single torch Tensor into a SambaTensor.

  3. Next, we define the optimizer. Currently, SambaFlow supports the AdamW and SGD optimizers. We use AdamW here. The optimizer must be defined externally as in this example. You cannot add an optimizer directly to the model definition in SambaFlow.

  4. We then create the “dummy” inputs to allow the SambaFlow compiler to trace the computational graph and map the resulting Dataflow graph to the RDU. For compilation, this example uses common_app_driver(). See Use common_app_driver in main() for compilation.

    The output of compilation is a PEF file, a binary file that contains the full details of the model. The PEF file can be deployed onto an RDU.

  5. To run (that is, train) the model, we use two methods:

    • The utils.trace_graph() method traces over the graph in a PEF file, initializing the weights and input/output tensors on the RDU. It takes as input the model, inputs, optimizer, the PEF file and a mapping.

      • The PEF is passed in as args.pef. This argument is part of the SambaFlow ArgParser, so you need not define it. You specify name and location of the PEF at the command line during compile with the --pef argument. See Compile and run the model for an example.

      • The mapping argument tells SambaFlow how to place the model onto the RDU. There are two options: spatial and section. A spatial mapping places the entire model onto the RDU at once (or up to a defined batch size). A section mapping breaks the model into several sections to be deployed on-chip one at a time. The default is section mapping.

    • The train() method indicates that we want to do a training run.
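The control flow of these steps can be sketched with the standard library alone. In the sketch below, compile_model and run_model are hypothetical placeholders for the compile path and for the trace-then-train path; the subcommand parser, the --pef default, and the returned strings are assumptions made for illustration and do not reproduce the SambaFlow argument parser.

```python
import argparse
from typing import List


def compile_model(args: argparse.Namespace) -> str:
    """Placeholder for the compile step, which would produce a PEF file."""
    return "compiled to {}".format(args.pef)


def run_model(args: argparse.Namespace) -> str:
    """Placeholder for tracing the graph from the PEF and then training."""
    return "traced {} and trained for {} epochs".format(args.pef, args.num_epochs)


def main(argv: List[str]) -> str:
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command", required=True)
    for command in ("compile", "run"):
        sub = subparsers.add_parser(command)
        sub.add_argument("--pef", default="out/model.pef")  # hypothetical default
        sub.add_argument("--num-epochs", type=int, default=6)
    args = parser.parse_args(argv)

    # Only "run" trains; any other command is handed to the compile path.
    if args.command == "run":
        return run_model(args)
    return compile_model(args)


print(main(["compile", "--pef", "convnet.pef"]))
print(main(["run", "--pef", "convnet.pef", "--num-epochs", "2"]))
```

Keeping the dispatch in main() means the same script serves both phases: one invocation compiles, and a later invocation runs training against the compiled artifact.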

Use common_app_driver in main() for compilation

Compilation can be initiated in one of two ways:

  • With common_app_driver(). We will soon have a replacement for this deprecated utility function.

  • With the samba.session.compile() method, an earlier approach to compilation.

Using common_app_driver() enables certain capabilities (compile, measure-cpu, measure-gpu, and measure-performance). To use common_app_driver() you import it from sambaflow.samba.utils.common (see Imports).