Model with an external loss function

Model functions and changes discusses how to convert a PyTorch model that contains a loss function in the model definition (as part of the forward() method). Here, we discuss how to use a loss function outside the model definition. You might want to do this if your loss function isn't currently supported by SambaFlow or if you are using a custom loss function.

With an external loss function, we use the host CPU to compute the loss and the gradients for backpropagation. We'll make changes to transport tensors between the RDU and the CPU.

Most of the changes to the original model are the ones we discuss in Model functions and changes. The functions below are the ones that differ when you use an external loss function.

forward() used with external loss function

The first change is to the model's forward() method. Because the loss is no longer computed on the RDU, you don't need to pass the label tensor to this method, so you can remove that parameter and the loss computation itself:

forward()
   def forward(self, x: torch.Tensor):
       # Since loss isn't part of the model, we don't pass a label to forward()
       out = self.layer1(x)
       out = self.layer2(out)
       out = out.reshape(out.size(0), -1)
       out = self.drop_out(out)
       out = self.fc1(out)
       out = self.fc2(out)
       return out

Compare this to the original forward() method in the ConvNet class discussed in Define the model.
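
For reference, here is a minimal sketch of what the in-model-loss version of forward() looks like (the exact method is shown in Define the model; the attribute name self.criterion and the return order are assumptions here):

   def forward(self, x: torch.Tensor, labels: torch.Tensor):
       out = self.layer1(x)
       out = self.layer2(out)
       out = out.reshape(out.size(0), -1)
       out = self.drop_out(out)
       out = self.fc1(out)
       out = self.fc2(out)
       # With the loss inside the model, the labels are passed in and the loss
       # is computed on the RDU and returned alongside the raw output.
       loss = self.criterion(out, labels)  # self.criterion would be set in __init__ (assumed name)
       return loss, out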

prepare_dataloader() used with external loss function

The changes in prepare_dataloader() involve the SambaLoaders. By default, a SambaLoader converts each tensor given to it by a DataLoader and passes all of them along. In the case of MNIST, this means both the image and label tensors.

However, our model no longer takes in label tensors, so we filter them out. Labels are still needed during training: the loss computed from them produces the gradients that adjust the model's weights and biases during the backward pass.

  • When we include the loss function in the model, we have to pass in the labels to the initializer and forward methods of the model so that the gradients and backward pass can be computed on the RDU.

  • When we use an external loss function, we no longer pass the labels to the forward method because the loss isn't computed on the RDU.

We provide an anonymous function to the SambaLoader via the function_hook parameter. The function acts as a filter, removing the tensors that you don't want. The function must return a list, and that list must contain the same number of tensors as there are names in the names parameter.

It is possible to retain the original tensors from the DataLoader. If you set the return_original_batch parameter to True, the SambaLoader returns a list that contains the tensors you filtered for and the original tensors, in that order. This allows us to preserve the MNIST labels for use in the loss calculation.

Compare this to the original prepare_dataloader() function in Load data with prepare_dataloader().

prepare_dataloader()
def prepare_dataloader(
    args: argparse.Namespace) -> Tuple[sambaflow.samba.sambaloader.SambaLoader, ...]:
   """
   Transforms MNIST input to tensors and creates training/test dataloaders.

   Downloads the MNIST dataset (if necessary); splits the data into training and test sets;
   transforms the data to tensors; then creates torch DataLoader instances over those sets.
   Each torch DataLoader is wrapped in a SambaLoader.

   Input:
       args: User- and system-defined command line arguments

   Returns:
       A tuple of SambaLoaders over the training and test sets.
   """

   # Transform the raw MNIST data into PyTorch Tensors, which will be converted to SambaTensors
   transform = transforms.Compose(
       [
           transforms.ToTensor(),
           transforms.Normalize((0.1307,), (0.3081,)),
       ]
   )

   # Get the train & test data (images and labels) from the MNIST dataset
   train_dataset = datasets.MNIST(
       root=args.data_path,
       train=True,
       transform=transform,
       download=True,
   )
   test_dataset = datasets.MNIST(root=args.data_path, train=False, transform=transform)

   # Set up the train & test data loaders (input pipeline)
   train_loader = DataLoader(
       dataset=train_dataset, batch_size=args.bs, shuffle=True
   )
   test_loader = DataLoader(
       dataset=test_dataset, batch_size=args.bs, shuffle=False
   )

   # Create SambaLoaders
   sn_train_loader = SambaLoader(
    dataloader=train_loader, names=["image"],
    function_hook=lambda t: [t[0]],
    return_original_batch=True)
   sn_test_loader = SambaLoader(
    dataloader=test_loader, names=["image"],
    function_hook=lambda t: [t[0]],
    return_original_batch=True)

   return sn_train_loader, sn_test_loader
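
Because return_original_batch is set to True, each batch a SambaLoader yields consists of the filtered tensors followed by the original DataLoader batch. A quick, illustrative sketch of how that unpacks (the variable names here are ours):

sn_train_loader, sn_test_loader = prepare_dataloader(args)
images, original_batch = next(iter(sn_train_loader))  # filtered image tensor, then the original batch
labels = original_batch[1]                            # labels are the second element of the original batch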

train() used with external loss function

We change the train() function to accommodate an external loss function.

  1. Add a new parameter to the function, allowing a loss function to be passed into it (we will do this in the main() function).

  2. Change the inner training loop: we still loop over the enumerated output from a SambaLoader, but we take an extra step to extract the labels from the original batch.

  3. Change the computation of the model’s forward and backward sections.

    • Modify the samba.session.run() call so that it takes only the image tensors (via the input_tensors parameter) and computes only the forward section (by setting the section_types parameter to ["FWD"]). The raw output of the model's forward() method is captured in the first element of the tuple returned by samba.session.run().

    • We use this output to compute the loss and gradients on the CPU. We pass the output to the CPU via samba.to_torch().

  4. The next few operations are pure PyTorch: set requires_grad to True, call the loss function on the output and labels, and then compute the backward pass.

  5. To finish the computation, we pass the output gradients back from the CPU to the RDU via another call to samba.session.run(). We use the grad_of_outputs parameter, which takes a list of gradients to be applied in the model's backward pass on the RDU. We set this parameter by calling samba.from_torch_tensor() to convert the output gradients to SambaTensors.

  6. We set the section_types parameter to a list containing “BCKWD” and “OPT” to run only those model sections on the RDU, thus completing one iteration of the training loop.

Compare this to the original train() function in Train the model.

train()
def train(args: argparse.Namespace, model: nn.Module, criterion: Callable) -> None:
   """
   Trains the model.

   Prepares and loads the data, then runs the training loop with the hyperparameters specified
   by the input arguments with a given loss function.  Calculates loss and accuracy
   over the course of training.

   Inputs:
       args: User- and system-defined command line arguments
       model: ConvNet model
       criterion: Loss function

   Returns:
       None
   """

   sn_train_loader, sn_test_loader = prepare_dataloader(args)
   hyperparam_dict = {"lr": args.learning_rate}

   total_step = len(sn_train_loader)
   loss_list = []
   acc_list = []

   for epoch in range(args.num_epochs):
       for i, (images, original_batch) in enumerate(sn_train_loader):

           # The label tensor is the second element of the original batch
           labels = original_batch[1]

           # Run only the forward pass on RDU and note the section_types argument
           # The first element of the returned tuple contains the raw outputs of forward()
           outputs = samba.session.run(
               input_tensors=(images,),
               output_tensors=model.output_tensors,
               hyperparam_dict=hyperparam_dict,
               section_types=["FWD"]
           )[0]

           # Convert SambaTensors back to torch tensors to carry out loss calculation
           # on the host CPU.  Be sure to set the requires_grad attribute for torch.
           outputs = samba.to_torch(outputs)
           outputs.requires_grad = True

           # Compute loss on host CPU and store it for later tracking
           loss = criterion(outputs, labels)

           # Compute gradients on CPU
           loss.backward()
           loss_list.append(loss.tolist())

           # Run the backward pass and optimizer step on RDU and note the grad_of_outputs
           # and section_types arguments
           samba.session.run(
               input_tensors=(images,),
               output_tensors=model.output_tensors,
               hyperparam_dict=hyperparam_dict,
               grad_of_outputs=[samba.from_torch_tensor(outputs.grad)], # Bring grads back from CPU to RDU
               section_types=["BCKWD", "OPT"])

           # Compute and track the accuracy
           total = labels.size(0)
           _, predicted = torch.max(outputs.data, 1)
           correct = (predicted == labels).sum().item()
           acc_list.append(correct / total)

           if (i + 1) % 100 == 0:
               print(
                   "Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%".format(
                       epoch + 1,
                       args.num_epochs,
                       i + 1,
                       total_step,
                       torch.mean(loss),
                       (correct / total) * 100,
                   )
               )

main() used with external loss function

We only have to make small changes to the main() function.

  • Define the loss function. This can be a built-in PyTorch loss function or a user-defined callable (see the sketch after this list). In this example, we call it criterion.

  • Pass this loss function to the training function.
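
Here is a minimal sketch of a user-defined loss that could stand in for nn.CrossEntropyLoss(). Because train() only ever calls criterion(outputs, labels) and loss.backward(), any callable that maps an (outputs, labels) pair to a scalar torch tensor works. The name custom_cross_entropy is ours, not part of SambaFlow:

import torch
import torch.nn.functional as F

def custom_cross_entropy(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Mean negative log-likelihood over the batch, equivalent to nn.CrossEntropyLoss()
    log_probs = F.log_softmax(outputs, dim=1)
    return F.nll_loss(log_probs, labels)

# In main(), you would then write: criterion = custom_cross_entropy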

Compare this to the original main() function in Tie the pieces together with main().

main()
def main(argv):

   args = parse_app_args(argv=argv, common_parser_fn=add_user_args)

   # Create the CNN model
   model = ConvNet()

   # Convert model to SambaFlow (SambaTensors)
   samba.from_torch_model_(model)

   # Create optimizer
   # Note that SambaFlow currently supports AdamW, not Adam, as an optimizer
   optimizer = samba.optim.AdamW(model.parameters(), lr=args.learning_rate)

   ###################################################################
   # Define loss function here to be used in the forward pass on CPU #
   ###################################################################
   criterion = nn.CrossEntropyLoss()

   # Create dummy SambaTensor for graph tracing
   inputs = get_inputs(args)

   # The common_app_driver() handles model compilation and various other tasks, e.g.,
   # measure-performance.  Running, or training, a model must be explicitly carried out
   if args.command == "run":
       utils.trace_graph(
        model,
        inputs,
        optimizer,
        init_output_grads=not args.inference,
        pef=args.pef,
        mapping=args.mapping)
       train(args, model, criterion)
   else:
       common_app_driver(
        args=args,
        model=model,
        inputs=inputs,
        optim=optimizer,
        name=model.__class__.__name__,
        init_output_grads=not args.inference,
        app_dir=utils.get_file_dir(__file__))
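
To run the script, a standard Python entry point passes the command-line arguments to main(); a minimal sketch:

import sys

if __name__ == "__main__":
    main(sys.argv[1:])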