Code elements of the inference program

The Generative NLP tutorial is available on our public GitHub repo. It includes two main code files, one for training and one for inference.

  • The training code supports compiling and training a model. As part of training, you can generate checkpoints. See Code elements of the training program.

  • The inference code supports compiling for inference and performing an inference run.

This doc page explores the code.


Inference is the process of making predictions on new data using a trained model. Our code for inference is similar to the code for training; we made only a few tweaks. You can expect to make similar tweaks to the training code for your own model.

Several pieces of the code are the same (or slightly different, e.g. different inputs):

  • Configure input arguments for compilation and running inference.

  • Create dummy inputs for graph tracing.

  • Convert the torch tensors (which come from the checkpoint) to SambaTensor instances.

Here’s where the inference model differs:

  • No input data preparation for an inference model.

  • No data loaders, optimizers, or training loop.

  • We call compile --inference, which is faster and results in a smaller PEF than compile for training.

  • Weights are loaded both from the Hugging Face model (as before) and from the checkpoint that was created during model training.

  • The model code processes the prompts that are used to make predictions (in our examples, movie reviews).

It all starts with main()

To support inference, a SambaNova model must go through both compilation and training. You have a choice:

  • Theoretically, you can run inference using the output of compilation for training (PEF file).

  • However, if you compile with the --inference argument, the compiler performs only the forward pass, so the PEF file is much smaller (and compilation is faster).

The following table gives an overview of what main() does.

Function | Description | See
parse_app_args() | Collect the arguments coming from add_common_args() and add_run_args(). When users run the model, they can specify any of those arguments, plus arguments that come from the compiler (e.g. o0) and arguments that are supported by the SambaFlow framework. | Parse input arguments
AutoConfig.from_pretrained(), AutoModelForCausalLM.from_config() | Download the pretrained model from Hugging Face. We use the AutoConfig.from_pretrained() and AutoModelForCausalLM.from_config() Hugging Face functions because we want to use our own configuration and not the configuration prespecified by Hugging Face. | Pull the pretrained models
patch_model() | Patch some parts of the Hugging Face model to improve performance on RDU. | Improve efficiency on RDU by patching
samba.from_torch_model_() | Convert the model to use the SambaFlow framework. The function performs some initialization and related tasks. | Convert the model to use the SambaFlow framework
get_model_trace_inputs() | Convert torch tensors to SambaTensor instances. The checkpoint file that we pass in contains torch tensor instances. We have to convert them for inference on RDU. | Create dummy tensors for compilation
samba.session.compile() | If the user specifies compile --inference on the command line, compile the model for inference, passing in the checkpoint (saved state of the trained model), inputs, etc. | Define model compilation for inference
generate() | If the user specifies run --inference and an input file on the command line, perform inference on the input data by calling generate(). In this example, we pass in a set of unlabeled movie reviews. We can then examine the output to see if our training run with labeled movie reviews has made the model smart enough to generate good results. | Define the inference process

Here’s the code for main().

main() for inference (compile and run)
def main(argv: List[str]) -> None:
    # Parse the args
    args = parse_app_args(argv=argv, common_parser_fn=add_common_args, run_parser_fn=add_run_args)

    # Download the model from Hugging Face
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)
        model = AutoModelForCausalLM.from_config(config)
    elif args.model_name_or_path:
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
    else:
        raise RuntimeError("Must provide --model_name_or_path or --config_name")

    # Patch the model here
    model = patch_model(model, args)

    # Convert the model to use the SambaFlow framework
    samba.from_torch_model_(model)

    # Create dummy inputs for graph tracing
    inputs = get_model_trace_inputs(args)

    if args.command == 'compile':
        # Further keyword arguments to compile() are omitted from this excerpt
        samba.session.compile(model, inputs, app_dir=samba.utils.get_file_dir(__file__))
    elif args.command == 'run':
        traced_outputs = utils.trace_graph(model, inputs, pef=args.pef)
        predictions = generate(args, model, traced_outputs)
        print(*predictions, sep=f"\n{'-' * 20}\n")


if __name__ == "__main__":
    main(sys.argv[1:])

Parse input arguments

Users can call the inference program with input arguments that affect its behavior.

The list of arguments for compiling for inference is shorter than for training because we don’t worry about weights or other training-specific items. Some arguments mean something different. For example, during training, checkpoint_name is the checkpoint to generate; during inference, checkpoint_name is the checkpoint to pass in.

Here’s the code we use to support model-specific arguments for inference:

Adding arguments for compile and run
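def add_common_args(parser: argparse.ArgumentParser):
    """Adds common arguments to an ArgumentParser object

    Args:
        parser (argparse.ArgumentParser): The argument parser object to add arguments to
    """
    # Argument types are inferred from how the rest of the code uses each value.
    # Default values are not shown in this excerpt.
    parser.add_argument('--model_name_or_path', type=str,
                        help='Path to pretrained model or model identifier')
    parser.add_argument('--config_name', type=str,
                        help='Path to pretrained model config or model identifier')
    parser.add_argument('--cache_dir', type=str,
                        help='Where to store pretrained models and data downloaded')
    parser.add_argument('--max_seq_length', type=int,
                        help='The maximum total input sequence length after tokenization. '
                        'Data in your data dir will be truncated or padded to this length. ')
    parser.add_argument('--examples_to_generate', type=int,
                        help='The number of prompts to run generation on')


def add_run_args(parser: argparse.ArgumentParser):
    """Adds arguments used at runtime to an argument parser object

    Args:
        parser (argparse.ArgumentParser): The argument parser object to add arguments to
    """
    parser.add_argument(
        '--data_dir', type=str,
        help='Path to a .json file, .jsonl file or a directory containing .jsonl files. '
        'Each json object should contain a "prompt" key of text used '
        'for prompting text generation.')
    parser.add_argument('--max_tokens_to_generate', type=int,
                        help='Maximum number of tokens to generate after each prompt.')
    parser.add_argument(
        '--checkpoint_name', type=str,
        help='Path to a checkpoint containing weights with names matching those provided '
        'by the --model_name_or_path')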
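
To make the split between common and run-only arguments concrete, here is a minimal sketch that uses only the standard argparse module. The build_parser() helper, the subcommand wiring, and the sample values such as ckpt.pt are assumptions made for illustration. In the tutorial, parse_app_args() does this work and also adds the compiler and framework arguments.

```python
import argparse


def add_common_args(parser: argparse.ArgumentParser) -> None:
    # Arguments shared by `compile` and `run` (a subset of the tutorial's list)
    parser.add_argument('--model_name_or_path', type=str)
    parser.add_argument('--max_seq_length', type=int)


def add_run_args(parser: argparse.ArgumentParser) -> None:
    # Arguments that only make sense at run time
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--checkpoint_name', type=str)


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical stand-in for parse_app_args(): one subparser per command."""
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='command', required=True)

    compile_parser = subparsers.add_parser('compile')
    add_common_args(compile_parser)

    run_parser = subparsers.add_parser('run')
    add_common_args(run_parser)
    add_run_args(run_parser)
    return parser


parser = build_parser()
run_args = parser.parse_args(['run', '--max_seq_length', '1024', '--checkpoint_name', 'ckpt.pt'])
compile_args = parser.parse_args(['compile', '--max_seq_length', '1024'])

print(run_args.command, run_args.checkpoint_name)
print(compile_args.command, hasattr(compile_args, 'checkpoint_name'))
```

With this layout, the compile command never sees the run-only arguments, which mirrors why the inference argument list is shorter at compile time.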

Pull the pretrained models

Hugging Face supports two functions for pulling a model:

  • AutoModelForCausalLM.from_pretrained() uses the Hugging Face model and its configuration.

  • AutoConfig.from_pretrained() and AutoModelForCausalLM.from_config(), used together, allow us to use our own configuration file.

Pull the model
    if args.model_name_or_path:
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
                                                     cache_dir=args.cache_dir)
    elif args.config_name:
        config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)
        # Read dropout rate from config
        args.dropout = config.resid_pdrop
        model = AutoModelForCausalLM.from_config(config)

For this tutorial, the configuration is stored as config/gpt2_small_config.json. We’ve fine-tuned those numbers to work well on the RDU architecture.
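
As a quick check on what these settings imply, the following sketch parses a few of the configuration fields with the standard json module. The inlined config_text and the head_dim name are assumptions for illustration. The tutorial itself loads the file through AutoConfig.from_pretrained() and reads config.resid_pdrop.

```python
import json

# A subset of the fields from gpt2_small_config.json, inlined so the sketch is self-contained
config_text = """
{
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "vocab_size": 50257
}
"""

config = json.loads(config_text)

# Each attention head works on an equal slice of the embedding
assert config["n_embd"] % config["n_head"] == 0
head_dim = config["n_embd"] // config["n_head"]

# The residual dropout rate is the value the tutorial copies into args.dropout
dropout = config["resid_pdrop"]

print(f"head_dim={head_dim}, dropout={dropout}, max positions={config['n_positions']}")
```

The n_positions value is also the upper bound on the sequence length that this configuration can represent.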

Configuration file: gpt2_small_config.json
{
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "vocab_size": 50257
}

Improve efficiency on RDU by patching

Hugging Face supports the concept of patching a model. In our tutorial, we use patching to make the model run more efficiently on RDUs. In the tutorial code, we’ve defined gpt2_patch_helper(), which patches module forward calls within a GPT-2-based transformer model.

The model would still run without patching, but the optimization improves performance. The function is the same for training and inference.

Patch the model
def patch_model(model: nn.Module, args: argparse.Namespace) -> nn.Module:
    """Patch the Hugging Face model to make it more efficient when running on RDU.

    Args:
        model (nn.Module): The Hugging Face model instance
        args (argparse.Namespace): The parsed command line args

    Returns:
        nn.Module: The patched model instance
    """
    return gpt2_patch_helper(model, ...)  # remaining arguments are omitted from this excerpt

Convert the model to use the SambaFlow framework

To convert the model to use the SambaFlow framework, we call samba.from_torch_model_(). You can find the documentation in the API reference.

Create dummy tensors for compilation

The SambaFlow compiler maps the model graph onto an RDU. The compiler traces how the model’s input tensors change shape and produces the final output tensors. For tracing, the compiler doesn’t require actual data, but it does need tensors of the same shape.

We achieve this with a call to get_model_trace_inputs() which:

  • Calls samba.from_torch_tensor(), which takes torch tensors as input and generates SambaTensor instances as output. The RDU manipulates SambaTensor instances, not torch tensors.

  • Performs other conversion tasks.

Unlike the training version of get_model_trace_inputs(), this function does not include labels but does include the attention mask.

The conversion from torch tensor to SambaTensor is a minimum requirement for any model you want to run on RDU.
Create dummy tensors for tracing
def get_model_trace_inputs(args: argparse.Namespace) -> Tuple[Any]:
    """Get input tensors to use for tracing the model.

    Since they're only used for tracing, these tensors are composed of dummy data.

    Args:
        args (argparse.Namespace): Parsed command line arguments

    Returns:
        Tuple[Any]: Inputs to use for tracing
    """
    batch_size = args.batch_size
    length = args.max_seq_length

    assert batch_size == 1, "Only batch size 1 is supported at the moment"

    # Input IDs
    input_ids = torch.randint(0, 5000, (batch_size, length)).int()
    input_ids = samba.from_torch_tensor(input_ids, name='input_ids')

    # Position IDs
    position_ids = torch.arange(length)
    position_ids = position_ids.short()
    position_ids = samba.from_torch_tensor(
        position_ids.unsqueeze(0).expand(input_ids.shape), name='input_position_ids')

    # Attention Mask
    # Prepare the attention mask for the Hugging Face Module
    attention_mask = torch.randint(2, (batch_size, length), dtype=torch.bfloat16)
    attention_mask = attention_mask[:, None, :].to(torch.float32)
    attention_mask_name = 'attention_mask'
    attention_mask = samba.from_torch_tensor(attention_mask, name=attention_mask_name)

    # Items in traced_inputs match the order of inputs to forward() for the model
    traced_inputs = (input_ids, None, attention_mask, None, position_ids, None, None, None)

    return traced_inputs
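
The comment about matching the order of inputs to forward() can be illustrated with a small sketch. It binds the tuple positionally to a stand-in forward() function by using the standard inspect module. The parameter names are an assumption based on the GPT-2 forward signature in Hugging Face Transformers and can differ between releases, and plain strings stand in for the SambaTensor instances.

```python
import inspect


def forward(input_ids=None, past_key_values=None, attention_mask=None,
            token_type_ids=None, position_ids=None, head_mask=None,
            inputs_embeds=None, encoder_hidden_states=None):
    """Stand-in with an assumed GPT-2 style parameter order; it does no computation."""


# Placeholder strings stand in for the SambaTensor instances
traced_inputs = ("input_ids", None, "attention_mask", None, "position_ids", None, None, None)

# Bind the tuple positionally, exactly as a call forward(*traced_inputs) would
bound = inspect.signature(forward).bind(*traced_inputs)
provided = {name: value for name, value in bound.arguments.items() if value is not None}

print(provided)
```

Only three parameters receive a value. Every None in the tuple is a placeholder that keeps the later tensors in the right position.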

Define model compilation for inference

The compile function is defined in SambaSession.compile(). See the API Reference for details.

We call compile() with the following arguments:

  • model is the RDU-ready model converted by samba.from_torch_model_().

  • inputs are returned by get_model_trace_inputs().

  • app_dir is the location of the model.

  • inference is deprecated because the user input determines if we compile for inference.

Compile the model
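    if args.command == 'compile':
        # Further keyword arguments to compile() are omitted from this excerpt
        samba.session.compile(model, inputs, app_dir=samba.utils.get_file_dir(__file__))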

Define the inference process

The goal of our tutorial model is inference (the process of making predictions on new data using a trained model), not generation (the creation of new data using a generative model). However, the GPT-2 model we are using in this tutorial uses the generate() function for both inference and generation.

To define the inference process, we call generate() with the user-specified arguments, the model, and certain outputs of compilation. We then print prediction information.

Perform inference
    elif args.command == 'run':
        traced_outputs = utils.trace_graph(model, inputs, pef=args.pef)
        predictions = generate(args, model, traced_outputs)
        print(*predictions, sep=f"\n{'-' * 20}\n")

Generate function overview

The generate() function replaces the model’s internal forward call with a call to model_rdu_step(). We patch the Hugging Face forward function with model_rdu_step() so that it always runs on RDU.

The function performs these tasks:

  1. Load checkpoints.

  2. Define a single model step on RDU (model_rdu_step()).

    • In contrast to the training loop, which defines a separate and more complex model_step() function, we define model_rdu_step inside generate()

    • generate() handles preparing inputs by calling get_runtime_inputs()

    • The call to samba.session.run() runs only the forward pass to return the logits.

    • Finally, the single step returns a CausalLMOutputWithCrossAttentions object.

  3. Use the GPT-2 tokenizer to convert raw text from the validation file we pass in into token sequences.

  4. Call GenerativeDataset() to convert the validation file, which is in .jsonl format, into a Torch dataset.

  5. Iterate over the validation dataset and return tensors.

  6. Decode the generated tokens to generate text output in which a sentiment has been assigned to each movie review. See Compile, train, and perform inference with a Hugging Face GPT model for example output.
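
Steps 4 and 5 can be sketched with the standard json module. The load_prompts() function is a hypothetical stand-in for GenerativeDataset, whose implementation is not shown on this page, and the sample reviews are made up for illustration.

```python
import io
import json
from typing import Dict, Iterator, List


def load_prompts(lines: io.TextIOBase) -> Iterator[Dict[str, str]]:
    """Hypothetical stand-in for GenerativeDataset: yield one JSON object per line."""
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        if "prompt" not in record:
            raise KeyError('each JSON object needs a "prompt" key')
        yield record


# Made-up sample data in .jsonl form: one JSON object per line
sample = io.StringIO(
    '{"prompt": "The plot dragged, but the acting was superb."}\n'
    '{"prompt": "I walked out after twenty minutes."}\n'
    '{"prompt": "A warm, funny film."}\n'
)

examples_to_generate = 2
prompts: List[str] = []
for k, example in enumerate(load_prompts(sample)):
    if k >= examples_to_generate:
        break  # stop once the requested number of prompts is reached
    prompts.append(example["prompt"])

print(prompts)
```

Because the loader is a generator, the third review is never parsed once the limit of two prompts is reached.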

Generate function
def generate(args: argparse.Namespace, model: nn.Module, traced_outputs: Tuple[SambaTensor]):
    """Generate some outputs from the model, hooking into the Hugging Face generate function.

    Args:
        args (argparse.Namespace): The parsed command line arguments
        model (nn.Module): The transformer model instance
        traced_outputs (Tuple[SambaTensor]): Output tensors generated by the tracing process

    Returns:
        List[str]: A list of predictions from the model
    """
    # Load the checkpoint
    if args.checkpoint_name:
        load_checkpoint(model, args.checkpoint_name)

    # Define the internal forward pass in terms of a single step on RDU
    def model_rdu_step(self, *input, **kwargs):
        input_id_length = kwargs['input_ids'].shape[1]
        samba_inputs = get_runtime_inputs(kwargs, args.max_seq_length)

        # Assumption: the call target and the input_tensors, output_tensors, and
        # section_types arguments are not shown in this excerpt and are filled in
        # from common SambaFlow usage. Only the forward section runs.
        output_logits = samba.session.run(input_tensors=samba_inputs,
                                          output_tensors=traced_outputs,
                                          hyperparam_dict={'p': 0.0},
                                          section_types=['fwd'])[0]
        logits = samba.to_torch(output_logits)[:, :input_id_length, :].float()
        return CausalLMOutputWithCrossAttentions(loss=None, logits=logits)

    # Replace the model's internal forward call with the RDU step call so
    # model_rdu_step is automatically called during generation.
    # The Hugging Face model generate function calls the model's forward function
    # to generate text. By default, that forward function runs the model on CPU.
    # To make it run on RDU, we patch the forward function with model_rdu_step.
    base_model_class = model.__class__
    # Assumption: keep the original call reachable, then route calls to the RDU step
    base_model_class.__torch_call__ = base_model_class.__call__
    base_model_class.__call__ = model_rdu_step

    # Make a tokenizer. The model checkpoint folder has vocab.json (tokenizer info)
    # and merges.txt files
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)

    # Make a dataset from a .jsonl file or folder of .jsonl files
    dataset = GenerativeDataset(args.data_dir)
    predictions = []

    # Generate predictions from the model
    for k, example in enumerate(dataset):
        if k >= args.examples_to_generate:
            break
        # Tokenize inputs
        model_inputs = tokenizer(example['prompt'], return_tensors='pt')
        input_ids = model_inputs['input_ids']
        input_length = input_ids.shape[-1]

        # Hook into HF model.generate to generate predictions.
        # The call patching above ensures the model runs on RDU
        generated_ids = model.generate(model_inputs['input_ids'],
                                       max_length=input_length + args.max_tokens_to_generate)
        generated_text = tokenizer.decode(generated_ids.squeeze(0))
        predictions.append(generated_text)

    return predictions
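
The patching step in generate() can be hard to picture, so here is a minimal sketch in plain Python. TinyModel and fake_rdu_step() are hypothetical stand-ins, not part of the tutorial code. The sketch shows one likely reason for patching the class rather than the instance: when an instance is called, Python looks up __call__ on its class.

```python
class TinyModel:
    """Hypothetical stand-in for a Hugging Face model; it only reports where it ran."""

    def __call__(self, **kwargs):
        return {"device": "cpu", "length": len(kwargs["input_ids"])}

    def generate(self, input_ids):
        # Like Hugging Face generate(), this goes through the model's own call
        return self(input_ids=input_ids)


def fake_rdu_step(self, *inputs, **kwargs):
    """Replacement step; a real one would run the forward pass on RDU."""
    return {"device": "rdu", "length": len(kwargs["input_ids"])}


model = TinyModel()
before = model.generate([1, 2, 3])

# Assigning on the instance has no effect, because model(...) looks up
# __call__ on the class, not on the instance
model.__call__ = fake_rdu_step
still_cpu = model.generate([1, 2, 3])
del model.__call__

# Patch the class instead, keeping the original call reachable under another name
model_class = model.__class__
model_class.original_call = model_class.__call__
model_class.__call__ = fake_rdu_step
after = model.generate([1, 2, 3])

print(before["device"], still_cpu["device"], after["device"])
```

After the class is patched, every call made inside generate() reaches the replacement step without any change to generate() itself.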

get_runtime_inputs() function

As part of generate(), we call get_runtime_inputs(). This function is similar to get_model_trace_inputs(). It performs these tasks:

  • Creates input IDs

  • Creates the attention mask

  • Creates position IDs

  • Converts torch tensors to SambaTensor instances

get_runtime_inputs() function
def get_runtime_inputs(inputs: Dict[str, List[Any]],
                       max_seq_length: int) -> Sequence[Optional[samba.SambaTensor]]:
    """Given inputs from the dataset, create inputs for running the model on RDU.

    These inputs must be the same dtype and shape as the compile inputs

    Args:
        inputs (Dict[str, List[Any]]): Inputs from the data loader
        max_seq_length (int): The max sequence length that the PEF supports

    Returns:
        Sequence[Optional[samba.SambaTensor]]: The named input tensors to use
        in running the model
    """
    # Create input_ids
    input_ids = inputs['input_ids']

    # Pad the inputs to the appropriate max sequence length
    input_ids = F.pad(input_ids, (0, max_seq_length - input_ids.shape[1]))
    # Assumption: cast to int32 so the dtype matches the traced input_ids
    input_ids = samba.from_torch_tensor(input_ids.int(), name="input_ids")

    # Create attention_mask
    attention_mask = inputs['attention_mask']
    attention_mask = F.pad(attention_mask, (0, max_seq_length - attention_mask.shape[1]))
    attention_mask = attention_mask[:, None, :].to(torch.float32)
    attention_mask = samba.from_torch_tensor(attention_mask, name='attention_mask')

    # Create position_ids
    position_ids_torch = torch.arange(max_seq_length).short()
    position_ids = samba.from_torch_tensor(
        position_ids_torch.unsqueeze(0).expand(input_ids.shape), name='input_position_ids')

    # Runtime traced inputs match the compile time inputs
    traced_inputs = (input_ids, None, attention_mask, None, position_ids, None, None, None)

    return traced_inputs
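
The padding in get_runtime_inputs() and the logits slice in model_rdu_step() work as a pair: inputs are padded up to the compiled sequence length, and the output is cut back to the real prompt length. The sketch below shows that round trip with plain Python lists. The pad_to_length() helper and the token IDs are hypothetical, and the lists stand in for tensors.

```python
from typing import List


def pad_to_length(values: List[int], max_seq_length: int, pad_value: int = 0) -> List[int]:
    """Right-pad a sequence to a fixed length (hypothetical helper for illustration)."""
    if len(values) > max_seq_length:
        raise ValueError("sequence is longer than the compiled max_seq_length")
    return values + [pad_value] * (max_seq_length - len(values))


max_seq_length = 8
input_ids = [464, 3807, 373, 1049]      # made-up token IDs for a short prompt
attention_mask = [1] * len(input_ids)   # 1 marks a real token

padded_ids = pad_to_length(input_ids, max_seq_length)
padded_mask = pad_to_length(attention_mask, max_seq_length)
position_ids = list(range(max_seq_length))

# Pretend the model returned one value per position; keep only the real positions,
# mirroring the [:, :input_id_length, :] slice in model_rdu_step()
fake_logits = [float(i) for i in range(max_seq_length)]
kept_logits = fake_logits[:len(input_ids)]

print(padded_ids)
print(padded_mask)
print(kept_logits)
```

The zeros in the padded mask tell the model to ignore the padded positions, so only the original four tokens influence the logits that are kept.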