Convert existing models to SambaFlow

Many SambaNova customers have converted an existing model that they built in PyTorch to work with SambaFlow. This doc page uses a simple example to illustrate what is essential for the conversion and discusses some best practices. You’ll see that much of your code remains unchanged and that SambaFlow doesn’t require you to reformat your data.

This doc page is about model conversion. We have a public GitHub repository with two scripts for pretraining data creation, pipeline.py and data_prep.py.

Get started with model conversion

In this document, we will:

The example model

The example model is a Convolutional Neural Network that performs image classification on the MNIST dataset. It consists of four layers:

  • 2 Convolutional layers, each containing a:

    • Conv2D

    • ReLU

    • MaxPool2D

  • 2 Fully-connected linear layers

Included or external loss function

You can use existing PyTorch loss functions, loss functions that SambaFlow doesn’t yet support, or fully custom loss functions.

  • It’s possible to include the model’s loss function as part of the model definition. This approach improves performance because the loss is then computed directly on the RDU.

  • It’s also possible to use a loss function outside of the model definition. You might want to do this if you are using a loss function that isn’t currently supported by SambaFlow or if you are using a custom loss function.

So, we’ll discuss two versions of this model:

  • A version where the loss function is included in the forward() function. This version fully benefits from the unique SambaNova architecture. See Model functions and changes.

  • A version where the loss function is external to the model. An external loss function requires us to use a host CPU to compute the loss and gradients for backpropagation. An external loss function doesn’t require massive changes to your code. You only need to make a few changes to transport tensors between RDU and CPU. See Model with an external loss function.

Original and converted model

This tutorial explains code modifications using a simple example. The model is a Convolutional Neural Network with two convolutional layers and two fully-connected linear layers.

SambaNova workflow


When you want to run your model on SambaNova hardware, the typical workflow is the following:

  1. Start with the Planning questions to get the most out of your model. You might find some of our background materials interesting, for example the white paper Accelerated computing with a Reconfigurable Dataflow Architecture.

  2. Modify your code following the guidance in this tutorial.

  3. Compile your model. The output of the compilation is a PEF file, a binary file containing the full details of the model that can be deployed onto an RDU. See Compile and run your first model.

  4. Prepare the data you want to feed to your compiled model. See the data preparation scripts in our public GitHub repository.

  5. Run the model, passing in the PEF file. See Compile and run your first model.

Planning questions

You can ask yourself some questions to make the conversion process more straightforward. These questions will help you identify where you need to add methods from SambaFlow to your code.

  1. Where are my dataloaders?

    All models need data and one of the easiest ways to feed in that data is with a PyTorch DataLoader. The output tensors that come from the DataLoader need to be converted into SambaTensors. See Load data with prepare_dataloader().

  2. What shape are my input tensors?

    The SambaNova architecture is a reconfigurable one: the compute graph of your model is physically mapped onto an RDU. To perform this mapping, SambaFlow needs to know the shape of the input tensors. See Generate SambaTensors with get_inputs().

  3. Where is my model defined?

    A useful feature of SambaFlow is that a loss function can be included in the definition and forward section of a model. A loss function can be mapped directly onto an RDU, greatly enhancing performance. See Define the model.

  4. Where is my model instantiated?

    The model must be explicitly converted to SambaFlow. Fortunately, only a single SambaFlow method needs to be used to do that. See Tie the pieces together with main().

  5. Where is my loss function defined and what is it?

    As mentioned previously, a loss function can be a part of a model’s definition. So, if your model uses a loss function that SambaFlow supports, the function can be moved into the model, as in Define the model. If your model uses a loss function that SambaFlow doesn’t support, the loss function can be used externally. See Model with an external loss function.

  6. Where is my optimizer defined and what is it?

    Unlike loss functions, optimizers can’t be added directly to a model’s definition in SambaFlow. Optimizers are passed into SambaFlow during compilation and training. See Tie the pieces together with main().

Model functions and changes

Throughout this document, you will see that extreme modification of your PyTorch code isn’t necessary. SambaFlow integrates tightly with PyTorch and only minimal changes and additions are needed.

Overview of required changes

When you review the SambaFlow versions of the code for this model, you will notice that they look different from the original. This difference is, however, purely aesthetic - it’s meant to more clearly point out the SambaFlow additions.

The required changes are:

  • The SambaFlow Python imports.

  • The shape of the input tensors.

  • The SambaFlow tensor conversion methods.

If you’re converting your own model, there’s no need to refactor and reorganize your code. Make the changes in the right spots and your model works in SambaFlow.

At the code level, additions include:

  • common_app_driver()

  • utils.trace_graph()

  • from_torch_model_()

  • from_torch_tensor() and samba.to_torch()

  • samba.optim.AdamW()

  • utils.argparser.parse_app_args()

Most of the differences between the original code and the SambaFlow code deal with code readability and modularity.

Imports

As a first step, we import SambaFlow libraries so the code can run on a SambaNova system. See the API Reference for details.

Imports required by SambaFlow

import sambaflow
import sambaflow.samba as samba
import sambaflow.samba.optim as optim
import sambaflow.samba.utils as utils
from sambaflow.samba.utils.argparser import parse_app_args
from sambaflow.samba.utils.common import common_app_driver
from sambaflow.samba.sambaloader import SambaLoader
  • sambaflow is the base package.

  • samba corresponds to PyTorch torch.

  • samba.optim is similar to PyTorch optim and contains all the optimizers available in SambaFlow.

  • samba.utils contains various SambaFlow-specific utilities for graph tracing, compiling and measuring performance, and so on.

  • samba.utils.argparser is similar to Python’s argparse library, but intended for requirements particular to SambaFlow.

    parse_app_args enables argument parsing supporting the SambaFlow execution modes (compile, run, test, measure performance). Users can define their own arguments and pass those into SambaFlow.

  • sambaflow.samba.utils.common

    common_app_driver is a tool to make using a compiled model easier. It provides a single interface for compiling a model, and several means of measuring a model’s performance, such as measure-cpu, measure-gpu, measure-performance, and measure-sections.

  • sambaflow.samba.sambaloader

    SambaLoader is a wrapper around the PyTorch DataLoader and is built to take advantage of the SambaNova architecture to more efficiently parallelize load operations with graph/compute operations. It also automatically converts Torch tensors into SambaTensors.

Additional imports

In this conversion example, there are several Python and PyTorch imports that are typically used for building a CNN. These imports can be left as-is. In fact, you’ll likely want the same imports for any model you bring in: SambaFlow is additive.

SambaFlow can transparently handle several native PyTorch methods. During compilation, those methods are optimized for the SambaNova RDU to take advantage of the SambaNova Reconfigurable Dataflow Architecture.

Define the model

Here’s the modified ConvNet class, with comments that explain changes.

ConvNet class
class ConvNet(nn.Module):
    """
    Instantiate a 4-layer CNN for MNIST Image Classification.

    In SambaNova, we can define the loss function as a part of the model
    and include it in the forward method to be computed.

    Typical SambaFlow usage example:

    model = ConvNet()
    samba.from_torch_model_(model)
    optimizer = ...
    inputs = ...
    if args.command == "run":
        utils.trace_graph(model, inputs, optimizer, pef=args.pef, mapping=args.mapping)
        train(args, model)
    """
    def __init__(self):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)
        self.criterion = nn.CrossEntropyLoss() # Add loss function to model

    def forward(self, x: torch.Tensor, labels: torch.Tensor):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        loss = self.criterion(out, labels)     # Compute loss
        return loss, out

The model definition differs only slightly from the original version: we’ve added a loss function to the model definition. Everything else remains as pure PyTorch.

  • Layers. When defining the layers, we are using only PyTorch nn methods: Sequential(), Conv2d(), ReLU(), MaxPool2d(), Dropout(), and Linear().

  • Loss function. The only difference between this code and the original code is that we define the loss function directly as part of the model. We are using the nn.CrossEntropyLoss() function and we include it in the __init__() and forward() methods. This change allows us to compute loss directly on the RDU and gives us a performance boost. If you don’t include a loss function, the output tensors of the forward function can be passed out and the loss is computed externally on the host CPU. This results in lower performance, but it does allow a user to leverage custom loss functions.

    When the model is trained, the forward() method computes and returns both the loss and the output tensors.
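The input size of fc1 in the ConvNet class, 7 * 7 * 64, follows from the layer settings. As a quick check, the standard output-size formula for convolution and pooling layers, floor((size + 2 * padding - kernel) / stride) + 1, can be applied by hand. The sketch below is plain Python and uses neither SambaFlow nor PyTorch; the helper name out_size is made up for this illustration.

```python
def out_size(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a Conv2d or MaxPool2d layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28  # MNIST images are 28 x 28 pixels
for _ in range(2):  # layer1 and layer2 share kernel, stride, and padding
    size = out_size(size, kernel=5, stride=1, padding=2)  # Conv2d keeps the size
    size = out_size(size, kernel=2, stride=2)             # MaxPool2d halves it

channels = 64  # output channels of the second Conv2d
flattened = size * size * channels  # number of features fed to fc1
print(size, flattened)
```

If you change the kernel size, padding, or input resolution in your own model, the same arithmetic gives the new input size for the first linear layer.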

Capture user arguments with add_user_args

The add_user_args function captures and encapsulates user-defined command-line arguments so that they can be more easily passed to SambaFlow via the samba.utils.argparser.parse_app_args() method.

The arguments to add_user_args() are the model’s hyperparameters and the two path variables for storing the data and the model.

The SambaFlow compiler has defaults for most of the arguments. If you don’t provide a value for an argument, the SambaFlow default is used.

add_user_args() function
def add_user_args(parser: argparse.ArgumentParser) -> None:
   parser.add_argument(
       "--batch-size",
       type=int,
       default=100,
       metavar="N",
       help="input batch size for training (default: 100)",
   )
   parser.add_argument(
       "--num-epochs",
       type=int,
       default=6,
       metavar="N",
       help="number of epochs to train (default: 6)",
   )
   parser.add_argument(
       "--num-classes",
       type=int,
       default=10,
       metavar="N",
       help="number of classes in dataset (default: 10)",
   )
   parser.add_argument(
       "--learning-rate",
       type=float,
       default=0.001,
       metavar="LR",
       help="learning rate (default: 0.001)",
   )
   parser.add_argument(
       "--data-path",
       type=str,
       default="data",
       help="Download location for MNIST data",
   )  # From DATA_PATH
   parser.add_argument(
       "--model-path", type=str, default="model", help="Save location for model"
   )  # From MODEL_STORE_PATH
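To see how these arguments surface in the rest of the code, the sketch below runs an abbreviated add_user_args() against a plain argparse.ArgumentParser. The plain parser is only a stand-in, assumed here for illustration; in the real script the parser is created by parse_app_args(). Note that argparse turns the dashes in an option name into underscores, which is why --batch-size is read as args.batch_size.

```python
import argparse


def add_user_args(parser: argparse.ArgumentParser) -> None:
    # Abbreviated copy of the function above: two of the six arguments
    parser.add_argument("--batch-size", type=int, default=100, metavar="N")
    parser.add_argument("--learning-rate", type=float, default=0.001, metavar="LR")


# Stand-in for the parser that SambaFlow builds (assumption for this sketch)
parser = argparse.ArgumentParser()
add_user_args(parser)

args = parser.parse_args(["--batch-size", "32"])
print(args.batch_size)     # value given on the command line
print(args.learning_rate)  # not given, so the default applies
```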

Generate SambaTensors with get_inputs()

To properly trace the PyTorch model graph and map it onto an RDU, the compiler requires tensors of the same shape as those that are passed to the forward() method during training. The data in these tensors isn’t important to the mapping. However, the compiler must be able to determine how the tensors change shape as they flow through the graph. This helps the compiler to generate a file that optimally lays out your model on the RDU.

In the case of MNIST data, two tensors are needed:

  • One tensor that matches the shape of an MNIST image.

  • One tensor that matches the shape of an MNIST label.

The get_inputs() function returns a tuple of the tensors.

get_inputs() function
def get_inputs(args: argparse.Namespace) -> Tuple[samba.SambaTensor, samba.SambaTensor]:
    """
    Generates random SambaTensors in the same shape as MNIST image tensors and labels.

    In order to properly compile a PEF and trace the model graph, SambaFlow requires a SambaTensor that is the same shape as the input Torch Tensors, allowing the graph to be optimally mapped onto an RDU.

    Input:
        args: User- and system-defined command line arguments

    Returns:
        A tuple of SambaTensors with random values in the same shape as MNIST image tensors and labels.
    """
    # The values are placeholders; only the shapes matter for tracing
    dummy_inputs = (
        samba.randn(args.batch_size, 1, 28, 28, name="image", batch_dim=0),
        samba.randint(args.num_classes, (args.batch_size,), name="label", batch_dim=0),
    )
    return dummy_inputs

The function includes the samba.randn() and samba.randint() methods.

These methods are functionally identical to their PyTorch counterparts, but return SambaTensors rather than Torch tensors. SambaTensors are wrappers for Torch tensors and have additional data members and methods to support the SambaNova RDU architecture.

  • The name and batch_dim data members don’t exist in the original PyTorch implementation.

  • There are methods for getting and setting RDU device memory and for syncing data between an RDU and host CPU (not used here).

See the API Reference for details.
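Because the graph is traced for one fixed input shape, it can be worth checking that every batch produced during training has that shape. The sketch below is plain Python and only does the arithmetic; the helper name split_into_batches is made up. How SambaFlow treats a smaller final batch isn't covered on this page, so treat the concern as an assumption to verify for your own dataset.

```python
def split_into_batches(num_samples: int, batch_size: int) -> tuple:
    """Return (number of full batches, size of the leftover partial batch)."""
    return divmod(num_samples, batch_size)


MNIST_TRAIN_SAMPLES = 60_000  # size of the MNIST training split

# Default batch size from add_user_args()
full, leftover = split_into_batches(MNIST_TRAIN_SAMPLES, 100)
print(full, leftover)

# A batch size that doesn't divide the dataset leaves a partial batch
print(split_into_batches(MNIST_TRAIN_SAMPLES, 128))
```

With the default batch size of 100 there is no leftover batch. For other batch sizes, the PyTorch DataLoader has a drop_last option that discards the final partial batch.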

Load data with prepare_dataloader()

The prepare_dataloader() function is used to load and transform input data. It is almost purely PyTorch, but requires minor changes because the RDU works with SambaTensors. We use SambaLoader to wrap the PyTorch DataLoaders. SambaLoader enables the loading process to better leverage the SambaNova parallel architecture while converting to and returning SambaTensors.

At minimum, a SambaLoader needs:

  • A DataLoader passed in via the dataloader parameter.

  • A list of names to give to the tensors via the names parameter (all SambaTensors are named, either by the user or automatically by SambaFlow).

This conversion could also be achieved by explicitly calling samba.from_torch_tensor() on the Torch tensors returned from the PyTorch DataLoaders.

prepare_dataloader() function
def prepare_dataloader(args: argparse.Namespace) -> Tuple[sambaflow.samba.sambaloader.SambaLoader, sambaflow.samba.sambaloader.SambaLoader]:
    """
    Transforms MNIST input to tensors and creates training/test dataloaders.

    Downloads the MNIST dataset (if necessary); splits the data into training and test sets; transforms the
    data to tensors; then creates Torch DataLoaders over those sets.  Torch DataLoaders are wrapped in
    SambaLoaders.

    Args:
        args (argparse.Namespace): User- and system-defined command line arguments

    Returns:
        A tuple of SambaLoaders over the training and test sets.
    """

    # Transform the raw MNIST data into PyTorch Tensors, which will be converted to SambaTensors
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    # Get the train & test data (images and labels) from the MNIST dataset
    train_dataset = datasets.MNIST(
        root=args.data_path,
        train=True,
        transform=transform,
        download=True,
    )
    test_dataset = datasets.MNIST(root=args.data_path, train=False, transform=transform)

    # Set up the train & test data loaders (input pipeline)
    # args.batch_size comes from the --batch-size argument in add_user_args()
    train_loader = DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, shuffle=True
    )
    test_loader = DataLoader(
        dataset=test_dataset, batch_size=args.batch_size, shuffle=False
    )

    # Create SambaLoaders
    sn_train_loader = SambaLoader(train_loader, ["image", "label"])
    sn_test_loader = SambaLoader(test_loader, ["image", "label"])

    return sn_train_loader, sn_test_loader

Train the model with train()

The train() method contains the training loop for the model. The code is similar to PyTorch.

train() method
def train(args: argparse.Namespace, model: nn.Module) -> None:
    """
    Trains the model.

    Prepares and loads the data, then runs the training loop with the hyperparameters specified
    by the input arguments.  Calculates loss and accuracy over the course of training.

    Args:
        args (argparse.Namespace): User- and system-defined command line arguments
        model (nn.Module): ConvNet model
    """

    sn_train_loader, _ = prepare_dataloader(args)
    hyperparam_dict = {"lr": args.learning_rate}

    total_step = len(sn_train_loader)
    loss_list = []
    acc_list = []

    for epoch in range(args.num_epochs):
        for i, (images, labels) in enumerate(sn_train_loader):

            # Run the model on RDU: forward -> loss/gradients -> backward/optimizer
            loss, outputs = samba.session.run(
                input_tensors=(images, labels),
                output_tensors=model.output_tensors,
                hyperparam_dict=hyperparam_dict
            )

            # Convert SambaTensors back to Torch Tensors to calculate accuracy
            loss, outputs = samba.to_torch(loss), samba.to_torch(outputs)
            loss_list.append(loss.tolist())

            # Track the accuracy
            total = labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            correct = (predicted == labels).sum().item()
            acc_list.append(correct / total)

            if (i + 1) % 100 == 0:
                print(
                    "Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%".format(
                        epoch + 1,
                        args.num_epochs,
                        i + 1,
                        total_step,
                        torch.mean(loss),
                        (correct / total) * 100,
                    )
                )

Here’s how the function works:

  1. The inner training loop runs over the enumerated samples that are generated by the SambaLoader that is created by the prepare_dataloader() function. In this example, we are only using the SambaLoader that generates training samples.

  2. The SambaTensors are then passed as input to samba.session.run(). This method performs the entire training pass, from the forward pass all the way to the backward pass and optimization.

    A Session is a SambaFlow object that contains the variables and methods that are needed to compile and run a model on an RDU. In SambaFlow parlance, to run a model means to train it. The model object is created during compilation.

    samba.session.run() takes in two key arguments: input_tensors and output_tensors.

    • input_tensors are the data on which the model is to be trained (what we get from prepare_dataloader()).

    • output_tensors capture the output shape that is generated by model compilation. Here’s how it works:

      When SambaFlow compiles a model, it generates a dataflow graph, which is similar to a PyTorch computational graph. To run the model, that graph must be traced before it can be placed onto the RDU. SambaFlow must know about the output shape that is generated by model compilation so that it can terminate the trace and map the graph onto the RDU. This output shape is captured in the output_tensors argument.

      What model.output_tensors actually contains is the output of forward(). Thus, the output of samba.session.run() is also that of forward().

  3. To track the progress of model training, we print the loss and accuracy every 100 steps. This is standard practice and SambaFlow doesn’t change that. However, progress tracking should be run on a CPU, not an RDU, so we use the method samba.to_torch() to convert loss and outputs to Torch tensors. The standard Torch functions can then be applied to loss and outputs on the CPU.
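The accuracy bookkeeping in step 3 is an argmax over the class scores followed by a comparison with the labels. The sketch below shows the same calculation in plain Python, with nested lists standing in for the Torch tensors; the function name batch_accuracy and the numbers are made up for illustration.

```python
def batch_accuracy(logits, labels):
    """Fraction of rows whose largest score sits at the labelled class index."""
    predicted = [row.index(max(row)) for row in logits]
    correct = sum(p == y for p, y in zip(predicted, labels))
    return correct / len(labels)


# Three samples, three classes (made-up numbers)
logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.9, 1.1, 0.0]]
labels = [0, 2, 0]
print(batch_accuracy(logits, labels))
```

In train(), torch.max(outputs.data, 1) plays the role of the argmax and (predicted == labels).sum().item() counts the matches.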

Tie the pieces together with main()

The main() function is called to initialize the model, the data, optimizers, arguments, etc. and then kick off compilation and training.

main() function
def main(argv):

    args = parse_app_args(argv=argv, common_parser_fn=add_user_args)

    # Create the CNN model
    model = ConvNet()

    # Convert model to SambaFlow (SambaTensors)
    samba.from_torch_model_(model)

    # Create optimizer
    # Note that SambaFlow currently supports AdamW, not Adam, as an optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # Normally, we'd define a loss function here, but with SambaFlow, it can be defined
    # as part of the model, which we have done in this case

    # Dummy SambaTensor
    inputs = get_inputs(args)

    # The common_app_driver() handles model compilation and various other tasks, e.g.,
    # measure-performance.  Running, or training, a model must be explicitly carried out
    if args.command == "run":
        utils.trace_graph(model, inputs, optimizer, pef=args.pef, mapping=args.mapping)
        train(args, model)
    else:
        common_app_driver(
            args=args,
            model=model,
            inputs=inputs,
            optim=optimizer,
            name=model.__class__.__name__,
            init_output_grads=not args.inference,
            app_dir=utils.get_file_dir(__file__),
        )

Components of model training

When using SambaFlow, these are the steps to get a model trained:

  1. The process begins by processing command line arguments, and we use samba.utils.argparser.parse_app_args() to capture the necessary arguments from the command line.

    The common_parser_fn variable is used to pass user-defined arguments to the SambaFlow backend. We created the add_user_args() function to make that possible.

  2. Next, we create the PyTorch model in the typical way.

    We call the samba.from_torch_model_() method, which is part of the Session library, to convert our PyTorch model into a SambaFlow model. The method recursively, and in-place, goes through a computational graph and converts all the Torch tensors into SambaTensor instances.

    While from_torch_model_() and samba.from_torch_tensor() look similar, they are very different and are not interchangeable.

    • from_torch_model_() is a Session method and converts models.

    • samba.from_torch_tensor() is a SambaTensor method and converts only a single Torch Tensor into a SambaTensor.

  3. Next, we define the optimizer. Currently, SambaFlow supports the AdamW and SGD optimizers. We use AdamW here. The optimizer must be defined externally as in this example. You cannot add an optimizer directly to the model definition in SambaFlow.

  4. We then create the “dummy” inputs to allow the SambaFlow compiler to trace the computational graph and map the resulting Dataflow graph to the RDU. For compilation, it is a best practice to use the common_app_driver(). See Use common_app_driver in main() for compilation.

    The output of compilation is a PEF file, a binary file that contains the full details of the model. The PEF file can be deployed onto an RDU.

  5. To run, i.e., train, the model we have to use two methods:

    • The utils.trace_graph() method traces over the graph in a PEF file, initializing the weights and input/output tensors on the RDU. It takes as input the model, inputs, optimizer, the PEF file and a mapping.

      • The PEF is passed in as args.pef. This argument is part of the SambaFlow ArgParser, so you need not define it. You name the PEF at compile time with the --pef-name argument and pass its location at run time with the --pef argument. See Compile and run the model for an example.

      • The mapping argument tells SambaFlow how to place the model onto the RDU. There are two options: spatial and section. A spatial mapping places the entire model onto the RDU at once (or up to a defined batch size). A section mapping breaks the model into several sections to be deployed on-chip one at a time. The default is section mapping.

    • The train() method indicates that we want to do a training run.

Use common_app_driver in main() for compilation

Compilation can be initiated in one of two ways:

  • With the common_app_driver() (this is a best practice).

  • With the samba.session.compile() method, an earlier approach to compilation.

Using common_app_driver() enables several capabilities (compile, dump, measure-cpu, measure-gpu, measure-performance, and measure-sections) with just one argument on the command line. SambaNova might add capabilities to this method over time. To use common_app_driver() you import it from sambaflow.samba.utils.common (see Imports).

Compile and run the model

To compile:

$ python <model.py> compile --pef-name <pef_name>

To run:

$ python <model.py> run --pef </path/to/pef_name>
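The two commands above select which branch of main() executes. The sketch below imitates that dispatch with a plain argparse parser. It is a hypothetical stand-in, not the parser that parse_app_args() builds: the real one accepts more commands and options. The function names build_parser and dispatch and the PEF path are made up for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for parse_app_args(); the real parser has more options
    parser = argparse.ArgumentParser(prog="model.py")
    commands = parser.add_subparsers(dest="command", required=True)
    compile_cmd = commands.add_parser("compile")
    compile_cmd.add_argument("--pef-name", type=str)
    run_cmd = commands.add_parser("run")
    run_cmd.add_argument("--pef", type=str)
    return parser


def dispatch(args: argparse.Namespace) -> str:
    # main() branches the same way: "run" trains, anything else goes to the driver
    if args.command == "run":
        return f"trace graph from {args.pef}, then train"
    return f"hand '{args.command}' to common_app_driver"


parser = build_parser()
print(dispatch(parser.parse_args(["compile", "--pef-name", "convnet"])))
print(dispatch(parser.parse_args(["run", "--pef", "out/convnet/convnet.pef"])))
```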

Model with an external loss function

Model functions and changes discusses how to convert a PyTorch model that contains a loss function in its definition (as part of its forward() method). It is also possible to use a loss function outside the model definition. You might want to do this if your loss function isn’t currently supported by SambaFlow or if you are using a custom loss function.

With an external loss function we use a host CPU to compute the loss and gradient for backpropagation. We’ll make changes to transport tensors between RDU and CPU.

The following sections show the CNN model with an external loss function. The functions below are the updated functions. Everything else remains unchanged.

forward() used with external loss function

The first change is made to the model’s forward() method. Because loss is no longer computed on the RDU, you don’t need to pass the label tensors to this function and can remove that method parameter and the loss function itself:

forward()
   def forward(self, x: torch.Tensor):
       # Since loss isn't part of the model, we don't pass a label to forward()
       out = self.layer1(x)
       out = self.layer2(out)
       out = out.reshape(out.size(0), -1)
       out = self.drop_out(out)
       out = self.fc1(out)
       out = self.fc2(out)
       return out

Compare this to the original forward() method in Define the model.

prepare_dataloader() used with external loss function

The changes in prepare_dataloader() involve the SambaLoaders. By default, a SambaLoader converts each tensor given to it by a DataLoader and passes all of them along. In the case of MNIST, this means both the image and label tensors.

However, our model no longer takes in label tensors, so we filter those out. Labels are used during training: the loss function compares them with the model’s output, and the gradients of that loss are then used to adjust the model’s weights and biases during the backward pass.

  • When we include the loss function in the model, we have to pass the labels to the model’s forward() method so that the loss, the gradients, and the backward pass can be computed on the RDU.

  • When we use an external loss function, we no longer pass the labels to the forward method because the loss isn’t computed on RDU.

We provide an anonymous function to the SambaLoader via the function_hook parameter. The function acts as a filter, removing the tensors that you don’t want. The function must return a list and it must return the same number of tensors as named in the names parameter.

It is possible to retain the original tensors from the DataLoader. If you set the return_original_batch parameter to True, the SambaLoader returns a list that contains the tensors you filtered for and the original tensors, in that order. This allows us to preserve the MNIST labels for use in the loss calculation.

Compare this to the original prepare_dataloader() function in Load data with prepare_dataloader().

prepare_dataloader()
def prepare_dataloader(args: argparse.Namespace) -> Tuple[sambaflow.samba.sambaloader.SambaLoader, ...]:
   """
   Transforms MNIST input to tensors and creates training/test dataloaders.

   Downloads the MNIST dataset (if necessary); splits the data into training and test sets; transforms the
   data to tensors; then creates Torch DataLoaders over those sets.  Torch DataLoaders are wrapped in
   SambaLoaders.

   Input:
       args: User- and system-defined command line arguments

   Returns:
       A tuple of SambaLoaders over the training and test sets.
   """

   # Transform the raw MNIST data into PyTorch Tensors, which will be converted to SambaTensors
   transform = transforms.Compose(
       [
           transforms.ToTensor(),
           transforms.Normalize((0.1307,), (0.3081,)),
       ]
   )

   # Get the train & test data (images and labels) from the MNIST dataset
   train_dataset = datasets.MNIST(
       root=args.data_path,
       train=True,
       transform=transform,
       download=True,
   )
   test_dataset = datasets.MNIST(root=args.data_path, train=False, transform=transform)

   # Set up the train & test data loaders (input pipeline)
   # args.batch_size comes from the --batch-size argument in add_user_args()
   train_loader = DataLoader(
       dataset=train_dataset, batch_size=args.batch_size, shuffle=True
   )
   test_loader = DataLoader(
       dataset=test_dataset, batch_size=args.batch_size, shuffle=False
   )

   # Create SambaLoaders
   sn_train_loader = SambaLoader(dataloader=train_loader, names=["image"], function_hook=lambda t: [t[0]], return_original_batch=True)
   sn_test_loader = SambaLoader(dataloader=test_loader, names=["image"], function_hook=lambda t: [t[0]], return_original_batch=True)

   return sn_train_loader, sn_test_loader
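The filtering behavior described for function_hook and return_original_batch can be mimicked in plain Python. The class below is a hypothetical stand-in written for this illustration, not the SambaLoader implementation: it does no tensor conversion and no overlap of loading with compute. It only shows the shape of what the training loop receives.

```python
from typing import Callable, Iterable, Iterator, List, Optional


class FilteringLoader:
    """Hypothetical stand-in that mimics the SambaLoader options described above."""

    def __init__(self, dataloader: Iterable, names: List[str],
                 function_hook: Optional[Callable] = None,
                 return_original_batch: bool = False):
        self.dataloader = dataloader
        self.names = names
        self.function_hook = function_hook
        self.return_original_batch = return_original_batch

    def __iter__(self) -> Iterator:
        for batch in self.dataloader:
            kept = self.function_hook(batch) if self.function_hook else list(batch)
            # The hook must return one tensor per entry in names
            if len(kept) != len(self.names):
                raise ValueError("function_hook must return one item per name")
            if self.return_original_batch:
                yield (*kept, batch)  # filtered tensors first, original batch last
            else:
                yield tuple(kept)


# Strings stand in for the (image, label) tensors of two batches
batches = [("img0", "lbl0"), ("img1", "lbl1")]
loader = FilteringLoader(batches, names=["image"],
                         function_hook=lambda t: [t[0]],
                         return_original_batch=True)
for image, original_batch in loader:
    print(image, original_batch[1])
```

The loop at the end has the same form as the one in train(): the first element is the filtered image tensor, and the label is recovered from the second element of the original batch.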

train() used with external loss function

We change the train() method to accommodate an external loss function.

  1. Add a new parameter to the function, allowing a loss function to be passed into it (we will do this in the main() function).

  2. Change the inner training loop: we still loop over the enumerated output from a SambaLoader, but we take an extra step to extract the labels from the original batch.

  3. Change the computation of the model’s forward and backward sections.

    • Modify the samba.session.run() method to only work with image tensors (via the input_tensors parameter) and to only compute the forward section (via setting the section_types parameter to "FWD"). The raw output of the model’s forward() method is captured in the first element of the tuple returned by samba.session.run().

    • We use this output to compute the loss and gradients on the CPU. We pass the output to the CPU via samba.to_torch().

  4. The next few operations are pure PyTorch: set requires_grad to True, call the loss function on the output and labels, and then compute the backward pass.

  5. To finish the computation, we pass the output back from the CPU to the RDU via another call to samba.session.run(). We use the grad_of_outputs parameter, which takes in a list of gradients to be applied in the model’s backward pass on RDU. We set this parameter by calling samba.from_torch_tensor() to convert the output gradients to SambaTensors.

  6. We set the section_types parameter to a list containing "BCKWD" and "OPT" to run only those model sections on the RDU, thus completing one iteration of the training loop.
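To make steps 4 and 5 concrete, the sketch below computes by hand what the CPU side produces: the loss and the gradient of the loss with respect to the model outputs. It is plain Python, with nested lists standing in for tensors and made-up numbers. It assumes cross-entropy loss averaged over the batch, which is the default behavior of nn.CrossEntropyLoss. In the real training loop, PyTorch autograd does this work when loss.backward() is called, and the result is read from outputs.grad.

```python
import math


def cross_entropy_with_grad(logits, labels):
    """Mean cross-entropy loss and its gradient with respect to the logits."""
    batch = len(logits)
    loss = 0.0
    grads = []
    for row, label in zip(logits, labels):
        shift = max(row)  # subtract the max for numerical stability
        exps = [math.exp(v - shift) for v in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        loss -= math.log(probs[label])
        # d(loss)/d(logit_k) = (softmax_k - onehot_k) / batch
        grads.append([(p - (1.0 if k == label else 0.0)) / batch
                      for k, p in enumerate(probs)])
    return loss / batch, grads


# Stand-in for the forward output that comes back from the RDU
outputs = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
labels = [1, 0]
loss, grad_of_outputs = cross_entropy_with_grad(outputs, labels)
print(round(loss, 4))
print(grad_of_outputs[0])
```

The gradient has the same shape as the outputs, one row per sample and one column per class. That is the quantity the second samba.session.run() call receives through grad_of_outputs.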

Compare this to the original train() function in Train the model with train().
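
The division of labor in steps 3 to 6 can be illustrated without any SambaFlow calls. The following is a minimal pure-Python sketch, not SambaFlow code: a hypothetical one-weight model y = w * x stands in for the RDU sections, a mean squared error stands in for the external loss, and the helper names forward_section, host_loss_and_grad, and backward_and_opt_section are invented for illustration. It shows why the backward section only needs the gradient of the loss with respect to the model output.

```python
# Toy illustration of a split training step (no SambaFlow or PyTorch required).
# All function names are hypothetical stand-ins, not SambaFlow APIs.

def forward_section(w, xs):
    """Stand-in for the "FWD" section: compute the model outputs only."""
    return [w * x for x in xs]

def host_loss_and_grad(outputs, targets):
    """Stand-in for the host CPU step: mean squared error and dLoss/dOutput."""
    n = len(outputs)
    loss = sum((y - t) ** 2 for y, t in zip(outputs, targets)) / n
    grad_of_outputs = [2.0 * (y - t) / n for y, t in zip(outputs, targets)]
    return loss, grad_of_outputs

def backward_and_opt_section(w, xs, grad_of_outputs, lr):
    """Stand-in for the "BCKWD" and "OPT" sections.

    By the chain rule, dLoss/dw = sum(dLoss/dy_i * dy_i/dw) with dy_i/dw = x_i,
    so the output gradients are all the backward pass needs from the host.
    """
    grad_w = sum(g * x for g, x in zip(grad_of_outputs, xs))
    return w - lr * grad_w  # plain SGD update, chosen for simplicity

xs = [1.0, 2.0, 3.0]
targets = [2.0, 4.0, 6.0]  # generated by the weight w = 2
w = 0.0
losses = []
for _ in range(50):
    outputs = forward_section(w, xs)                     # forward only
    loss, grads = host_loss_and_grad(outputs, targets)   # loss and gradients on the host
    losses.append(loss)
    w = backward_and_opt_section(w, xs, grads, lr=0.05)  # backward and optimizer step
```

In this toy setup the weight converges toward 2, even though the section that updates it never sees the loss function itself, only the gradients handed over from the host.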

train()
def train(args: argparse.Namespace, model: nn.Module, criterion: Callable) -> None:
   """
   Trains the model.

   Prepares and loads the data, then runs the training loop with the hyperparameters specified
   by the input arguments with a given loss function.  Calculates loss and accuracy over the course of training.

   Inputs:
       args: User- and system-defined command line arguments
       model: ConvNet model
       criterion: Loss function

   Returns:
       None
   """

   sn_train_loader, sn_test_loader = prepare_dataloader(args)
   hyperparam_dict = {"lr": args.learning_rate}

   total_step = len(sn_train_loader)
   loss_list = []
   acc_list = []

   for epoch in range(args.num_epochs):
       for i, (images, original_batch) in enumerate(sn_train_loader):

           # The label tensor is the second element of the original batch
           labels = original_batch[1]

           # Run only the forward pass on RDU and note the section_types argument
           # The first element of the returned tuple contains the raw outputs of forward()
           outputs = samba.session.run(
               input_tensors=(images,),
               output_tensors=model.output_tensors,
               hyperparam_dict=hyperparam_dict,
               section_types=["FWD"]
           )[0]

           # Convert SambaTensors back to Torch Tensors to carry out loss calculation
           # on the host CPU.  Be sure to set the requires_grad attribute for Torch.
           outputs = samba.to_torch(outputs)
           outputs.requires_grad = True

           # Compute loss on host CPU and store it for later tracking
           loss = criterion(outputs, labels)

           # Compute gradients on CPU
           loss.backward()
           loss_list.append(loss.tolist())

           # Run the backward pass and optimizer step on RDU and note the grad_of_outputs
           # and section_types arguments
           samba.session.run(
               input_tensors=(images,),
               output_tensors=model.output_tensors,
               hyperparam_dict=hyperparam_dict,
               grad_of_outputs=[samba.from_torch_tensor(outputs.grad)], # Bring the grads back from CPU to RDU
               section_types=["BCKWD", "OPT"])

           # Compute and track the accuracy
           total = labels.size(0)
           _, predicted = torch.max(outputs.data, 1)
           correct = (predicted == labels).sum().item()
           acc_list.append(correct / total)

           if (i + 1) % 100 == 0:
               print(
                   "Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%".format(
                       epoch + 1,
                       args.num_epochs,
                       i + 1,
                       total_step,
                       loss.item(),
                       (correct / total) * 100,
                   )
               )
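
The host-side quantities in this loop can also be reproduced by hand. The sketch below is pure Python and assumes the default settings of nn.CrossEntropyLoss (mean reduction over the batch, no class weights, no label smoothing); the helper names are hypothetical. Under those assumptions, the gradient that loss.backward() leaves in outputs.grad equals (softmax(outputs) - one_hot(labels)) / batch_size, and the predicted class is the index of the largest logit.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_and_grad(logits, labels):
    """Mean cross-entropy loss and its gradient with respect to the logits."""
    n = len(logits)
    loss = 0.0
    grads = []
    for row, label in zip(logits, labels):
        probs = softmax(row)
        loss -= math.log(probs[label]) / n
        # d(loss)/d(logit_j) = (softmax_j - 1[j == label]) / n
        grads.append([(p - (1.0 if j == label else 0.0)) / n
                      for j, p in enumerate(probs)])
    return loss, grads

def accuracy(logits, labels):
    """Fraction of rows whose largest logit sits at the labeled class."""
    predicted = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == l) for p, l in zip(predicted, labels))
    return correct / len(labels)

logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # made-up values for two samples
labels = [0, 1]
loss, grad_of_outputs = cross_entropy_and_grad(logits, labels)
acc = accuracy(logits, labels)  # 0.5: the first sample is correct, the second is not
```

Each row of grad_of_outputs sums to zero, and the entry at the labeled class is negative, which is the signal the backward section on the RDU uses to raise that logit.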

main() used with external loss function

We only have to make small changes to the main() function.

  • Define the loss function. This could be a built-in PyTorch loss function or a user-defined function. In this example, we call it criterion.

  • Pass this loss function to the training function.

Compare this to the original main() function in Tie the pieces together with main().

main()
def main(argv):

   args = parse_app_args(argv=argv, common_parser_fn=add_user_args)

   # Create the CNN model
   model = ConvNet()

   # Convert model to SambaFlow (SambaTensors)
   samba.from_torch_model_(model)

   # Create optimizer
   # Note that SambaFlow currently supports AdamW, not Adam, as an optimizer
   optimizer = samba.optim.AdamW(model.parameters(), lr=args.learning_rate)

   ###################################################################
   # Define loss function here to be used in the forward pass on CPU #
   ###################################################################
   criterion = nn.CrossEntropyLoss()

   # Create dummy SambaTensor for graph tracing
   inputs = get_inputs(args)

   # The common_app_driver() handles model compilation and various other tasks, e.g.,
   # measure-performance.  Running, or training, a model must be explicitly carried out
   if args.command == "run":
       utils.trace_graph(model, inputs, optimizer, init_output_grads=not args.inference, pef=args.pef, mapping=args.mapping)
       train(args, model, criterion)
   else:
       common_app_driver(args=args,
                       model=model,
                       inputs=inputs,
                       optim=optimizer,
                       name=model.__class__.__name__,
                       init_output_grads=not args.inference,
                       app_dir=utils.get_file_dir(__file__))

How to compile and run the model with the loss function

The commands for compiling and running a model are the same whether the loss function is external or included in the model. The two versions are functionally the same, so the commands don’t change. See Compile and run the model.

Model conversion tips and tricks

This section, to be expanded, offers some tips and tricks for model conversion.

  • Torch DataLoaders. If the number of samples in your dataset is not exactly divisible by your batch size, the last batch is smaller than the others. For example, if the size of the last batch is 28 and your PEF batch size is 32, compilation fails with a PEF mismatch error. Set the DataLoader parameter drop_last=True to avoid that problem.

  • Data Visualization. SambaNova recommends that you don’t do data visualization directly on a SambaNova system.
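
To check ahead of time whether a dataset leaves a short final batch, the batch sizes can be computed directly. The helper below is a hypothetical pure-Python sketch that mirrors how PyTorch's DataLoader splits a dataset with and without drop_last; it does not call PyTorch or SambaFlow. The sample count of 92 is chosen only to reproduce the 28-sample final batch from the example above.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Return the size of every batch produced for a dataset of num_samples."""
    full_batches, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full_batches
    if remainder and not drop_last:
        sizes.append(remainder)  # short final batch that can mismatch the PEF
    return sizes

print(batch_sizes(92, 32))                  # [32, 32, 28]
print(batch_sizes(92, 32, drop_last=True))  # [32, 32]
```

Dropping the last batch discards a few samples per epoch. With shuffling enabled, different samples are dropped each epoch, so the effect on training is usually small.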

Full model code

The full code of the converted model is available: