samba
- from_torch_model_(model: Module)
Converts a PyTorch model to a SambaFlow model. PyTorch tensors are converted in place to SambaTensors, and parameters are converted to SambaParameters. Call this function before compile-time tracing (sambaflow.samba.session.compile()) and runtime tracing (sambaflow.samba.utils.utils.trace_graph()), because the compiler requires SambaTensors and SambaParameters to perform tracing. This function also registers a SambaFlow name for each submodule and for each parameter in the module.

Model buffers are not converted to SambaTensors. Instead, they remain PyTorch tensors, and sambaflow.samba.session.buffer_dict is populated with a mapping from buffer names to the corresponding SambaTensor for each buffer. You can reference buffers through this dictionary.

Parameters:
model – the PyTorch model to convert to a SambaFlow model.
Example:
>>> model = nn.Linear(10, 10)
>>> samba.from_torch_model_(model)
>>> model.name
'linear'
>>> model.weight.sn_name
'linear__weight'
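Because from_torch_model_ treats parameters and buffers differently (parameters become SambaParameters, buffers stay PyTorch tensors and are reachable through buffer_dict), it can help to check which tensors fall into each group before converting a model. The sketch below uses only standard PyTorch APIs (named_parameters and named_buffers) and does not call SambaFlow. The model is an arbitrary illustration, and the dotted names it prints are PyTorch's own; the exact key format used by buffer_dict is not shown in this entry, so it may differ.

```python
import torch.nn as nn

# Illustrative model: BatchNorm1d registers running_mean, running_var and
# num_batches_tracked as buffers, alongside its weight/bias parameters.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# Parameters: per the description above, these are the tensors that
# from_torch_model_ converts to SambaParameters.
param_names = [name for name, _ in model.named_parameters()]

# Buffers: these remain PyTorch tensors after conversion; SambaFlow exposes
# them through sambaflow.samba.session.buffer_dict (key format not assumed here).
buffer_names = [name for name, _ in model.named_buffers()]

print(param_names)   # ['0.weight', '0.bias', '1.weight', '1.bias']
print(buffer_names)  # ['1.running_mean', '1.running_var', '1.num_batches_tracked']
```

A model with no non-empty buffer list (such as the nn.Linear in the example above) has nothing to look up in buffer_dict after conversion.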