Special Lambda Operations
- sn_fma(a: SambaTensor, b: SambaTensor, c: SambaTensor) → SambaTensor
Performs an elementwise fused multiply-add operation. Equivalent to:
\[a \times b + c\]
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_fma
>>> # The lambda receives the attrs dict, then one argument per input tensor.
>>> compute_fma = lambda attrs, x, y, z: sn_fma(x, y, z)
>>> samba.set_seed(1)
>>> x = samba.randn((2, 5), dtype=torch.bfloat16)
>>> y = samba.randn((2, 5), dtype=torch.bfloat16)
>>> z = samba.randn((2, 5), dtype=torch.bfloat16)
>>> out = sn_zipmapreduce(compute_fma, [x, y, z])
>>> out.data
tensor([[ 3.0000, -1.8438,  1.5859, -1.8750,  0.1357],
        [-0.5234,  2.0312,  2.3438,  2.0312, -0.9180]], dtype=torch.bfloat16)
Note
This operation is only supported when used inside a Python function/lambda that is passed to sn_zipmapreduce().
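For checking sn_fma() results on small inputs, its elementwise semantics can be modelled in plain Python. This is a minimal sketch under stated assumptions: reference_fma is a hypothetical helper, not part of SambaFlow; it takes three equally shaped 2-D lists; and it uses Python floats, so it does not reproduce bfloat16 rounding.

```python
# Hypothetical pure-Python reference model (not the SambaFlow API) of what
# sn_fma computes inside sn_zipmapreduce: out[i][j] = a[i][j] * b[i][j] + c[i][j].
# Uses Python floats, so bfloat16 rounding is NOT modelled.
from typing import List

Matrix = List[List[float]]


def reference_fma(a: Matrix, b: Matrix, c: Matrix) -> Matrix:
    """Compute a * b + c elementwise for three equally shaped 2-D lists."""
    if not (len(a) == len(b) == len(c)):
        raise ValueError("inputs must have the same number of rows")
    out: Matrix = []
    for row_a, row_b, row_c in zip(a, b, c):
        if not (len(row_a) == len(row_b) == len(row_c)):
            raise ValueError("inputs must have the same number of columns")
        out.append([x * y + z for x, y, z in zip(row_a, row_b, row_c)])
    return out


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[0.5, 0.5], [2.0, -1.0]]
c = [[1.0, 0.0], [-6.0, 4.0]]
print(reference_fma(a, b, c))  # [[1.5, 1.0], [0.0, 0.0]]
```

Comparing such a model against out.data (after converting both to float) with a tolerance suited to bfloat16 is one way to sanity-check a lambda.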
- sn_imm(input: float | int, dtype: torch.dtype | SNType) → SambaTensor
Creates a tensor containing a constant.
- Parameters:
input – the constant value (a Python float or int).
dtype – data type of the constant.
- sn_iteridx(attrs: dict, dim: int, dtype: SNType | None) → SambaTensor
Creates an iterator within the body of a sn_zipmapreduce() func. dim specifies the dimension of the broadcasted shape that is iterated over. For example, if sn_zipmapreduce() has two inputs with shapes [A,1] and [1,B] and reduce_dim is set to 1, the tensor shape after broadcast is [A,B] and the tensor shape after the reduction is [A,1]. Two iterators are then available within the func body:
sn_iteridx(dim=0) = (0 until A by 1)
sn_iteridx(dim=1) = (0 until B by 1)
sn_iteridx() can be treated as an input to the sn_zipmapreduce() func and used in the func body.
- Parameters:
attrs – the attrs dictionary received as the first argument of the calling sn_zipmapreduce() func.
dim – dimension of the broadcasted shape to iterate over.
dtype (optional) – datatype of the iterator. dtype must be a signed or unsigned integer datatype, and its bit-width must match the bit-width of the inputs to the sn_zipmapreduce() that the operation belongs to. If unspecified, it is inferred from the inputs to the func passed to that sn_zipmapreduce(). The sign is inferred from the inputs as well: signed inputs give an int iterator and unsigned inputs give a uint iterator. Thus, for int16 inputs the iterator can only count up to \(2^{15} - 1\), so the iterated dimension can have at most \(2^{15}\) elements. The following types are supported:
SNType.UINT16
SNType.INT16
SNType.INT32
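The bit-width rule above can be turned into a quick pre-flight check on dimension sizes. The sketch below is hypothetical: max_dim_size and fits are not SambaFlow functions, and the type names are plain strings standing in for the SNType members listed above. It assumes the iterator counts from 0 and can use the full non-negative range of its type, which reproduces the \(2^{15}\) limit stated for INT16.

```python
# Hypothetical pre-flight check (not the SambaFlow API): can an sn_iteridx()
# iterator of the given type index every position of a dimension of this size?
# String keys stand in for SNType.UINT16 / SNType.INT16 / SNType.INT32.

ITERATOR_BITS = {
    "UINT16": (16, False),  # (bit-width, signed?)
    "INT16": (16, True),
    "INT32": (32, True),
}


def max_dim_size(type_name: str) -> int:
    """Largest dimension size whose indices 0 .. size-1 fit in the type."""
    bits, signed = ITERATOR_BITS[type_name]
    # A signed type spends one bit on the sign, so its largest index is 2**(bits-1) - 1.
    value_bits = bits - 1 if signed else bits
    return 2 ** value_bits


def fits(dim_size: int, type_name: str) -> bool:
    """True if a dimension of dim_size elements can be fully iterated."""
    return 0 <= dim_size <= max_dim_size(type_name)


print(fits(32768, "INT16"))   # True: indices 0 .. 32767
print(fits(40000, "INT16"))   # False
print(fits(40000, "UINT16"))  # True
```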
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_iteridx, sn_select, sn_imm
>>> from sambaflow.samba.utils import SNType
>>> x = samba.SambaTensor(shape=(3, 3), dtype=torch.int32)
>>> # Example: creating an upper-triangular matrix
>>> def upper_tri(attrs, x):
...     # If specified, dtype must match the bit-width of the other dtypes in the lambda.
...     dim_r = sn_iteridx(attrs=attrs, dim=0, dtype=SNType.INT32)  # row index
...     dim_c = sn_iteridx(attrs=attrs, dim=1, dtype=SNType.INT32)  # column index
...     mask = dim_r <= dim_c
...     one_imm = sn_imm(1, dtype=torch.int32)
...     zero_imm = sn_imm(0, dtype=torch.int32)
...     return sn_select(mask, one_imm, zero_imm)
>>> triu_matrix = sn_zipmapreduce(upper_tri, [x])
>>> triu_matrix.data
tensor([[1, 1, 1],
        [0, 1, 1],
        [0, 0, 1]], dtype=torch.int32)
Note
This operation is only supported when used inside a Python function/lambda that is passed to sn_zipmapreduce().
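The iterator semantics can be mirrored in plain Python: at each element of the broadcasted shape, sn_iteridx(dim=d) holds that element's index along d. In this sketch, broadcast_shape and upper_tri_model are hypothetical helpers that only model the behaviour described above (they are not SambaFlow calls), and for simplicity broadcast_shape assumes all input shapes have the same rank, with each dimension either 1 or a shared size.

```python
# Hypothetical pure-Python model (not the SambaFlow API) of sn_iteridx():
# at every element of the broadcasted shape, the iterator for dim d equals
# that element's index along d.
from itertools import product
from typing import List, Sequence, Tuple


def broadcast_shape(*shapes: Sequence[int]) -> Tuple[int, ...]:
    """Broadcast equal-rank shapes where each dim is 1 or a common size."""
    if len({len(s) for s in shapes}) != 1:
        raise ValueError("this sketch only handles shapes of equal rank")
    result = []
    for dims in zip(*shapes):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible dimensions: {dims}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


def upper_tri_model(rows: int, cols: int) -> List[List[int]]:
    """Mirror upper_tri above: 1 where iteridx(dim=0) <= iteridx(dim=1), else 0."""
    out = [[0] * cols for _ in range(rows)]
    for r, c in product(range(rows), range(cols)):  # r plays dim 0, c plays dim 1
        out[r][c] = 1 if r <= c else 0
    return out


print(broadcast_shape((4, 1), (1, 3)))  # (4, 3): A=4, B=3
print(upper_tri_model(3, 3))  # [[1, 1, 1], [0, 1, 1], [0, 0, 1]]
```

With inputs [A,1] and [1,B], the dim=0 iterator ranges over 0 until A and the dim=1 iterator over 0 until B, matching the description of sn_iteridx().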
- sn_select(cond: SambaTensor, true_val: SambaTensor, false_val: SambaTensor) → SambaTensor
Performs a select operation on tensors, similar to the torch.where() function. This function selects elements from true_val or false_val based on the condition specified in cond.
- Parameters:
cond – a mask tensor where each element is a condition. If the condition is true, the corresponding element from true_val is selected; otherwise, the corresponding element from false_val is chosen.
true_val – the tensor from which elements are selected when the corresponding condition in cond is true.
false_val – the tensor from which elements are selected when the corresponding condition in cond is false.
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_select
>>> # Example showing the use of sn_select with tensors.
>>> mask = samba.SambaTensor(torch.tensor([1, 0, 1], dtype=torch.int32))
>>> x = samba.SambaTensor(torch.tensor([1, 2, 3], dtype=torch.float))
>>> y = samba.SambaTensor(torch.tensor([4, 5, 6], dtype=torch.float))
>>> f = lambda attrs, mask, x, y: sn_select(mask, x, y)
>>> result = sn_zipmapreduce(f, [mask, x, y])
>>> result.data
tensor([1., 5., 3.])
>>> # result contains elements from x or y based on the mask.
>>> # Here 1 and 3 are selected from x, and 5 from y.
Note
This operation is only supported when used inside a Python function/lambda that is passed to sn_zipmapreduce().
See also
For more details, see torch.where().
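A plain-Python model of the same selection can help when reasoning about masks before running on hardware. select_model is a hypothetical helper, not the SambaFlow API; it assumes any nonzero mask element counts as true, which is consistent with the 1/0 mask in the example but is not stated by the source for other values.

```python
# Hypothetical pure-Python model (not the SambaFlow API) of sn_select:
# out[i] = true_val[i] if cond[i] is truthy (nonzero) else false_val[i].
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def select_model(cond: Sequence[int], true_val: Sequence[T], false_val: Sequence[T]) -> List[T]:
    """Elementwise select over three equally sized 1-D sequences."""
    if not (len(cond) == len(true_val) == len(false_val)):
        raise ValueError("cond, true_val and false_val must have the same length")
    # Assumption: a nonzero mask element means "true", as with the 1/0 int32 mask above.
    return [t if c else f for c, t, f in zip(cond, true_val, false_val)]


print(select_model([1, 0, 1], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))  # [1.0, 5.0, 3.0]
```

The printed list matches result.data in the sn_select example above.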