Examine LeNet model code

This tutorial is an implementation of LeNet, a convolutional neural network created by Yann LeCun. It was used by the U.S. Postal Service to automatically recognize handwritten ZIP code digits. See this Medium article.

What you’ll learn

In the Hello SambaFlow! tutorial (logreg), you learned how to compile and run a program. You also looked at the code for this simple model, which included defining input arguments, data preparation, and code to compile and run the model. The model downloaded the training data from the internet.

This tutorial builds on logreg. We’ll look at the code for each part of the model, but here’s what’s new in this model:

  • Download data explicitly.

  • Perform a validation (test) run, passing in a test dataset.

  • Use an optimizer for training.

  • Save a checkpoint.

  • Perform an inference run, passing in a validation dataset.

Files

See the complete set of files in our sambanova/tutorials/lenet tutorial. This doc page includes collapsible code snippets for each code component we discuss.

Data

The tutorial uses the Fashion MNIST dataset. In contrast to the classic MNIST dataset, which consists of handwritten digits, Fashion MNIST uses images from Zalando, a fast-fashion online retailer.

  • 60,000 images in the training set.

  • 10,000 images in the test set.

Because the data and the labels are in separate files, you can use the test set for both validation and inference.
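
The tutorial reads its dataset files from --data-dir (see get_dataset() and the CustomMNIST helper imported from mnist_utils). If you want to fetch Fashion MNIST yourself, one common option is torchvision; the snippet below is only an illustration and is not part of the tutorial code.

Download Fashion MNIST with torchvision (illustration only)
from torchvision import datasets, transforms

# Downloads the raw Fashion MNIST files into ./data if they aren't already there.
train_set = datasets.FashionMNIST(root='./data', train=True, download=True,
                                  transform=transforms.ToTensor())
test_set = datasets.FashionMNIST(root='./data', train=False, download=True,
                                 transform=transforms.ToTensor())
print(len(train_set), len(test_set))  # 60000 10000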

Code files

See Download the model code from GitHub in the code discussion doc page for a list of files and what they do.

Larger tutorials use separate .py files for training and inference. Those programs include a separate compilation-for-inference step. For this model, compilation for inference doesn’t make sense, so we’ve simplified the code for ease of use.

It all starts with main()

A SambaNova model must go through both compilation and training. A real-world model also includes validation (testing) and inference, and this tutorial includes them. The main() function calls the functions that perform these tasks, and also does some preparation.

The following list describes these functions; the cross-references point to the sections on this page that discuss each one.

  • parse_app_args(): Collects the arguments coming from add_common_args() and add_run_args(). When users run the model, they can specify predefined arguments that come from the compiler (for example, --o0) and the SambaFlow framework, as well as application-specific arguments. See Parse input arguments.

  • utils.set_seed(): Sets a random seed for reproducibility while we’re in the development phase of this tutorial.

  • samba.from_torch_model_(): Creates a model that uses the SambaFlow framework. The function, which also converts a PyTorch model to a SambaFlow model, performs some initialization and related tasks. We pass in model, a class we create to represent the model. See Create a model that uses the SambaFlow framework.

  • samba.optim.SGD(): Defines the optimizer we’ll use for training the model. The SambaFlow framework supports AdamW and SGD out of the box. You can also specify a different optimizer. See the API Reference.

  • compile(): If the user specifies compile on the command line, main() calls samba.session.compile(). Some of the inputs are specified in this code file, others directly in main(). See Compile the model.

  • run(): If the user specifies run on the command line, main() performs training, testing, or inference, based on other arguments that are passed in. See Run training, test, or inference.

Code for main()
def main():
    args = parse_app_args(dev_mode=True,
                          common_parser_fn=add_common_args,
                          test_parser_fn=add_run_args,
                          run_parser_fn=add_run_args)
    utils.set_seed(42)
    params = vars(args)
    if args.print_params:
        print_params(params)

    model = LeNet(args.num_classes)
    samba.from_torch_model_(model)

    inputs = get_inputs(params)

    optimizer = samba.optim.SGD(model.parameters(),
                                lr=0.0) if not args.inference else None
    if args.command == "compile":
        samba.session.compile(model,
                              inputs,
                              optimizer,
                              name='lenet',
                              app_dir=utils.get_file_dir(__file__),
                              squeeze_bs_dim=True,
                              config_dict=vars(args),
                              pef_metadata=get_pefmeta(args, model))

    elif args.command == "test":
        print("Test is not implemented in this version.")
    elif args.command == "run":
        if args.inference:
            prepare(model, optimizer, params)
            batch_predict(model, params['dataset_name'], params)
        elif args.test:
            prepare(model, optimizer, params)
            test(model, params['dataset_name'], params)
        else:
            prepare(model, optimizer, params)
            train(model, optimizer, params)

Parse input arguments

Users can call lenet_tutorial.py with input arguments to affect its behavior. Some arguments are available out of the box, for example, compiler arguments and arguments predefined by the SambaFlow framework.

  • The first argument is usually compile or run. Based on that argument, users can specify additional options that are predefined by the compiler or the SambaFlow layer, or by the application itself.

  • Some arguments are available for any model:

    • Compiler arguments are supported in conjunction with compile, for example, mymodel.py compile --o0.

    • Runtime arguments are supported in conjunction with run, for example, mymodel.py run --data-dir=/tmp/data.

  • Each model can specify additional arguments, usually by using the following functions:

    • add_common_args() specifies arguments for use with either compile or run.

    • add_run_args() specifies arguments for use with run, that is, for training or inference.

    • add_compile_args() is not defined in this example. Because SambaNova supports a rich set of arguments out of the box, not all applications specify additional compiler arguments.

The code to support additional input arguments includes arguments that typically support an AI model, such as learning rate, and also --print-params, which shows only the model-specific arguments. In combination with run, noteworthy supported arguments are --ckpt-dir, which sets where checkpoints are saved, and --test, which is used for evaluating the model.

Here’s the code that supports model-specific arguments:

Adding common arguments and arguments for run
def add_common_args(parser: argparse.ArgumentParser):
    """
    Adds common arguments to the given ArgumentParser object.

    Args:
        parser (argparse.ArgumentParser): The ArgumentParser object to add the arguments to.

    Returns:
        None
    """
    parser.add_argument('--num-classes',
                        type=int,
                        default=10,
                        help="Number of output classes (default=10)")
    parser.add_argument('--num-features',
                        type=int,
                        default=784,
                        help="Number of input features (default=784)")
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        help="Learning rate (default=0.1)")
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=32,
                        help="Batch size (default=32)")
    parser.add_argument('--momentum',
                        type=float,
                        default=0.0,
                        help="Momentum (default=0.0)")
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.01,
                        help="Weight decay (default=0.01)")
    parser.add_argument('--print-params',
                        action='store_true',
                        default=False,
                        help="Print the model parameters (default=False)")


def add_run_args(parser: argparse.ArgumentParser):
    """
    Add runtime arguments to the parser.

    Args:
        parser (argparse.ArgumentParser): The parser to which the arguments will be added.

    Returns:
        None
    """
    parser.add_argument('-e', '--num-epochs', type=int, default=1)
    parser.add_argument('--log-path', type=str, default='checkpoints')
    parser.add_argument('--test',
                        action="store_true",
                        help="Test the trained model")
    parser.add_argument('--init-ckpt-path',
                        type=str,
                        default='',
                        help='Path to load checkpoint')
    parser.add_argument('--ckpt-dir',
                        type=str,
                        default=os.getcwd(),
                        help='Path to save checkpoint')
    parser.add_argument('--data-dir',
                        type=str,
                        default='./data',
                        help="Directory containing datasets")
    parser.add_argument('--dataset-name',
                        type=str,
                        help="Dataset name: train, t10k, inference, etc.")
    parser.add_argument('--results-dir',
                        type=str,
                        default='./results',
                        help="Directory to store inference results")

Create a model that uses the SambaFlow framework

To set up the model to use the SambaFlow framework, we call samba.from_torch_model_(). You can find the documentation in the API Reference. There’s also some discussion in the conversion example.

Use an optimizer

We’re setting the optimizer variable to an instance of the SGD optimizer with a learning rate of 0.0. In inference mode, we’ll set optimizer to None because it doesn’t make sense to perform optimization during inference.

An optimizer is an algorithm that adjusts the parameters (weights and biases) of a machine learning model during the training process. The goal of an optimizer is to find the optimal set of parameters that minimize the model’s loss. When you train a SambaFlow model, information about the loss is sent to stdout.
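
In main(), this looks as follows: the optimizer is created only when we’re not running inference, and the hyperparameters that training actually uses (learning rate, momentum, weight decay) are supplied per step through hyperparam_dict in samba.session.run(), as you’ll see in train() below. The snippet is condensed from the tutorial code.

Optimizer setup (condensed from main() and train())
# Created only for training; set to None for inference.
optimizer = samba.optim.SGD(model.parameters(),
                            lr=0.0) if not args.inference else None

# During training, the hyperparameters are supplied with each samba.session.run() call:
hyperparam_dict = {
    "lr": params['lr'],
    "momentum": params['momentum'],
    "weight_decay": params['weight_decay']
}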

Compile the model

Compilation generates a PEF file that encapsulates the model’s dataflow graph for execution on the RDU. You have to compile a model before you perform training or inference:

  • For lenet, you compile the model once, and then pass the generated PEF to the training run.

  • For large models, compilation for inference is a separate step.

    • First you compile for training, and perform one or more training runs with that PEF.

    • Then you compile for inference, and perform inference runs with new data.

For compilation with this tutorial, we pass in the model, the inputs, and the optimizer.

Compile the model
if args.command == "compile":
        samba.session.compile(model,
                              inputs,
                              optimizer,
                              name='lenet',
                              app_dir=utils.get_file_dir(__file__),
                              squeeze_bs_dim=True,
                              config_dict=vars(args),
                              pef_metadata=get_pefmeta(args, model))

Run training, test, or inference

The run command offers users different options, based on what was passed in on the command line:

  • If the user calls run --inference, the code first calls prepare() and then batch_predict().

  • If the user calls run --test, the code first calls prepare() and then test().

  • If the user calls run with neither --inference nor --test, the code first calls prepare() and then train().

Here’s how this looks inside main(). We’ll discuss the different functions next.

Run the model
elif args.command == "run":
        if args.inference:
            prepare(model, optimizer, params)
            batch_predict(model, params['dataset_name'], params)
        elif args.test:
            prepare(model, optimizer, params)
            test(model, params['dataset_name'], params)
        else:
            prepare(model, optimizer, params)
            train(model, optimizer, params)

Let’s look at each function in more detail.

prepare() function

The prepare() function loads the model and a checkpoint if one was passed in on the command line, retrieves other inputs, and calls trace_graph().

The function will print a warning if no valid initial checkpoint is provided.

prepare() function
def prepare(model: nn.Module, optimizer, params):
    """
    Prepares the model by loading a checkpoint and tracing the graph.

    Args:
        model (nn.Module): The model to prepare.
        optimizer: The optimizer for the model.
        params: A dictionary of parameters.

    Returns:
        None
    """

    # We need to load the checkpoint first and then trace the graph to sync the weights from CPU to RDU
    if params['init_ckpt_path']:
        load_checkpoint(model, optimizer, params['init_ckpt_path'])
    else:
        print('[WARNING] No valid initial checkpoint has been provided')

    inputs = get_inputs(params)
    utils.trace_graph(model,
                      inputs,
                      optimizer,
                      pef=params['pef'],
                      mapping=params['mapping'])

train() function (includes checkpointing)

The train() function for lenet includes checkpointing. Here’s an overview of what happens:

  1. The user passes in the model to be trained, optimizer, and other parameters.

  2. train() gets the dataset.

  3. To train the model, we look at the model state (completed steps and completed epochs) and print the information to stdout.

  4. The actual training proceeds one step at a time.

  5. As a final step, train() calls save_checkpoint(). It’s a best practice to do training in batches and pass a previously trained checkpoint to the train() function to incrementally improve the results.

train() function
def train(model: LeNet, optimizer, params) -> None:
    """
    Trains the given model using the specified optimizer and parameters.

    Args:
        model (LeNet): The model to be trained.
        optimizer: The optimizer to be used during training.
        params: A dictionary containing the parameters for training.

    Returns:
        None
    """
    if params['dataset_name'] is None:
        dataset_name = "train"
    else:
        dataset_name = params['dataset_name']
    data_dir = Path(params['data_dir'])
    print(f"Using dataset: {data_dir / dataset_name}")
    train_dataset = get_dataset(dataset_name, params)
    train_loader = DataLoader(train_dataset,
                              batch_size=params['batch_size'],
                              drop_last=True,
                              shuffle=True)

    # Train the model
    current_step = model.state['completed_steps']
    current_epoch = model.state['completed_epochs']
    total_steps = len(train_loader) * params['num_epochs']
    if current_epoch == params['num_epochs']:
        print(
            f"Epochs trained: {current_epoch} is equal to epochs requested: {params['num_epochs']}. Exiting..."
        )
        return
    print("=" * 30)
    print(f"Initial epoch: {current_epoch:3n}, initial step: {current_step:6n}")
    print(
        f"Target epoch:  {params['num_epochs']:3n}, target step:  {total_steps:6n}"
    )
    hyperparam_dict = {
        "lr": params['lr'],
        "momentum": params['momentum'],
        "weight_decay": params['weight_decay']
    }
    for epoch in range(current_epoch + 1, params['num_epochs'] + 1):
        avg_loss = 0
        for i, (images, labels) in enumerate(train_loader):
            sn_images = samba.from_torch_tensor(images,
                                                name='image',
                                                batch_dim=0)
            sn_labels = samba.from_torch_tensor(labels,
                                                name='label',
                                                batch_dim=0)

            loss, outputs = samba.session.run(
                input_tensors=[sn_images, sn_labels],
                output_tensors=model.output_tensors,
                hyperparam_dict=hyperparam_dict,
                data_parallel=params['data_parallel'],
                reduce_on_rdu=params['reduce_on_rdu'])
            loss, outputs = samba.to_torch(loss), samba.to_torch(outputs)
            avg_loss += loss.mean()
            current_step += 1

            if (i + 1) % 100 == 0:
                log_step(epoch, params['num_epochs'], current_step, total_steps,
                         avg_loss / (i + 1))

    current_epoch = epoch

    samba.session.to_cpu(model)
    save_checkpoint(model, optimizer, current_step, current_epoch,
                    params['ckpt_dir'])
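
The save_checkpoint() and load_checkpoint() helpers that train() and prepare() call are defined elsewhere in the tutorial files and aren’t shown on this page. As a rough sketch of what such helpers typically do, here is a hypothetical implementation based on standard PyTorch checkpointing; the tutorial’s actual helpers may differ in names and fields.

Checkpoint helpers (hypothetical sketch, not the tutorial’s actual code)
import os
import torch

def save_checkpoint(model, optimizer, step, epoch, ckpt_dir):
    # Persist the weights, the optimizer state, and the training progress in one file.
    ckpt = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'completed_steps': step,
        'completed_epochs': epoch,
    }
    torch.save(ckpt, os.path.join(ckpt_dir, f'checkpoint_{epoch}.pt'))

def load_checkpoint(model, optimizer, ckpt_path):
    # Restore weights and optimizer state on the CPU; prepare() then calls
    # utils.trace_graph() to sync the weights from CPU to RDU.
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])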

test() function

As part of an AI model, you need functions that perform training, testing, and inference.

  • Testing passes a test dataset to a trained model, and determines how accurate the model’s results are. The test dataset has a parallel structure to the training dataset, for example, images with labels or text with labels.

  • Inference passes a validation dataset to a trained model. In this tutorial, we pass in a checkpoint of a trained model, and we pass in a dataset that includes images but no labels.

test() function
def test(model, dataset_name, params):
    """
    Calculates the test accuracy and loss for the given model and dataset.

    Parameters:
        model (object): The model to be tested.
        dataset_name (str): The name of the dataset to be used.
        params (dict): A dictionary of parameters.

    Returns:
        None
    """
    if dataset_name is None:
        dataset_name = "t10k"
    data_dir = Path(params['data_dir'])
    print(f"Using dataset: {data_dir / dataset_name}")
    test_dataset = get_dataset(dataset_name, params)
    test_loader = DataLoader(test_dataset,
                             drop_last=True,
                             batch_size=params['batch_size'])

    samba.session.to_cpu(model)
    test_acc = 0.0
    with torch.no_grad():
        correct = 0
        total = 0
        total_loss = 0
        for images, labels in test_loader:
            loss, outputs = model(images, labels)
            loss, outputs = samba.to_torch(loss), samba.to_torch(outputs)
            total_loss += loss.mean()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()

        test_acc = 100.0 * correct / total
        print('Test Accuracy: {:.2f}'.format(test_acc),
              ' Loss: {:.4f}'.format(total_loss.item() / (len(test_loader))))

batch_predict() function

The batch_predict() function performs inference with the trained model. We pass in data that the model hasn’t seen before. The function is similar to the test() function.

batch_predict() function
def batch_predict(model, dataset_name: str, params):
    """
    Generates the predictions for a given model on a dataset.

    Args:
        model (object): The trained model to use for prediction.
        dataset_name (str): The name of the dataset to use for prediction.
        params (dict): Additional parameters for the prediction.

    Returns:
        None
    """
    if dataset_name is None:
        dataset_name = "inference"
    data_dir = Path(params['data_dir'])
    print(f"Using dataset: {data_dir / dataset_name}")
    dataset = get_dataset(dataset_name, params)

    loader = DataLoader(dataset,
                        batch_size=params.get('batch_size', 32),
                        drop_last=True,
                        shuffle=False)

    predicted_labels = []
    for _, (images, labels) in enumerate(loader):
        sn_images = samba.from_torch_tensor(images, name='image', batch_dim=0)
        sn_labels = samba.from_torch_tensor(labels, name='label', batch_dim=0)

        loss, predictions = samba.session.run(
            input_tensors=[sn_images, sn_labels],
            output_tensors=model.output_tensors,
            section_types=['fwd'])
        loss, predictions = samba.to_torch(loss), samba.to_torch(predictions)
        _, predicted_indices = torch.max(predictions, axis=1)  # type: ignore

        predicted_labels += predicted_indices.tolist()

    # write to the file in the same format labels are stored
    results_dir = Path(params['results_dir'])
    results_dir.mkdir(parents=True, exist_ok=True)
    write_labels(predicted_labels,
                 str(results_dir / "prediction-labels-idx1-ubyte"))
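
The write_labels() helper is imported from mnist_utils and isn’t shown here; it saves the predictions in the same idx1-ubyte layout that MNIST-style label files use. The sketch below shows what writing that layout looks like; it’s for illustration only, and the tutorial’s own helper may differ.

Writing labels in idx1-ubyte format (illustrative sketch)
import struct

def write_labels_sketch(labels, path):
    # idx1-ubyte layout: magic number 2049, item count, then one unsigned byte per label.
    with open(path, 'wb') as f:
        f.write(struct.pack('>II', 2049, len(labels)))
        f.write(bytes(labels))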

A note on imports

We’re including PyTorch imports as well as SambaFlow imports. Which imports your model needs depends on the model itself.

Imports
import argparse
import os
from pathlib import Path
from typing import Tuple

import sambaflow.samba.utils as utils
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from mnist_utils import CustomMNIST, write_labels
from sambaflow import samba
from sambaflow.samba.utils.argparser import parse_app_args
from sambaflow.samba.utils.pef_utils import get_pefmeta
from torch.utils.data.dataloader import DataLoader

See Imports required by SambaFlow for a discussion of some of the imports.