samba

from_torch_model_(model: Module)

Converts a PyTorch model to a SambaFlow model in place: PyTorch tensors are converted to SambaTensors, and parameters are converted to SambaParameters. Call this function before compile-time tracing (sambaflow.samba.session.compile()) and before runtime tracing (sambaflow.samba.utils.utils.trace_graph()), because the compiler requires SambaTensors and SambaParameters to perform tracing. This function also registers a SambaFlow name for each submodule and for each parameter in the module.

Model buffers are not converted to SambaTensors. Instead, they remain PyTorch tensors, and sambaflow.samba.session.buffer_dict is populated with a mapping from each buffer name to the corresponding SambaTensor for that buffer. You can reference buffers via this dictionary.
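
Because parameters and buffers are handled differently, it can help to check which tensors of a model fall into each group before converting it. The following is a minimal sketch that uses only standard PyTorch introspection (named_parameters() and named_buffers()); it makes no SambaFlow calls, and the helper name split_params_and_buffers is hypothetical, introduced here for illustration only.

```python
import torch.nn as nn


def split_params_and_buffers(model: nn.Module):
    """Return (parameter names, buffer names) for a PyTorch module.

    Hypothetical helper for illustration; not part of SambaFlow.
    """
    param_names = [name for name, _ in model.named_parameters()]
    buffer_names = [name for name, _ in model.named_buffers()]
    return param_names, buffer_names


# BatchNorm1d registers running statistics as buffers, not parameters.
model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
param_names, buffer_names = split_params_and_buffers(model)

print(param_names)   # learnable tensors
print(buffer_names)  # non-learnable state such as running_mean
```

In this sketch, the names in buffer_names identify the tensors that, per the description above, would remain PyTorch tensors after conversion and be reachable through buffer_dict. The exact keys used in buffer_dict are not shown here and may differ from the plain PyTorch names.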

Parameters:

model – the PyTorch model to convert to a SambaFlow model.

Example:

>>> import torch.nn as nn
>>> import sambaflow.samba as samba
>>> model = nn.Linear(10, 10)
>>> samba.from_torch_model_(model)
>>> model.name
'linear'
>>> model.weight.sn_name
'linear__weight'