Special Lambda Operations

sn_fma(a: SambaTensor, b: SambaTensor, c: SambaTensor) -> SambaTensor

Performs a fused multiply-add operation, equivalent to

\[a \times b + c\]

Example

>>> import torch
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_fma
>>> import sambaflow.samba as samba
>>> compute_fma = lambda attrs, x, y, z: sn_fma(x, y, z)
>>> samba.set_seed(1)
>>> x = samba.randn((2,5), dtype=torch.bfloat16)
>>> y = samba.randn((2,5), dtype=torch.bfloat16)
>>> z = samba.randn((2,5), dtype=torch.bfloat16)
>>> out = sn_zipmapreduce(compute_fma, [x, y, z])
>>> out.data
tensor([[ 3.0000, -1.8438,  1.5859, -1.8750,  0.1357],
        [-0.5234,  2.0312,  2.3438,  2.0312, -0.9180]], dtype=torch.bfloat16)

Note

This operation is only supported when used inside a Python function/lambda that is passed to sn_zipmapreduce().
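To make the element-wise semantics concrete, the following plain-Python sketch models what the lambda above computes at each position of three same-shape 2-D inputs. `zip_map` and `fma_reference` are hypothetical helpers written for illustration, not SambaFlow APIs. The sketch assumes the inputs already share one shape (no broadcasting) and, unlike a true fused multiply-add, Python rounds the product before adding `c`.

```python
from typing import Callable, List

Matrix = List[List[float]]


def zip_map(fn: Callable[..., float], *inputs: Matrix) -> Matrix:
    """Apply fn element by element across same-shape 2-D inputs (hypothetical helper)."""
    rows = len(inputs[0])
    cols = len(inputs[0][0])
    for m in inputs:
        if len(m) != rows or any(len(r) != cols for r in m):
            raise ValueError("all inputs must have the same shape")
    return [
        [fn(*(m[i][j] for m in inputs)) for j in range(cols)]
        for i in range(rows)
    ]


def fma_reference(a: float, b: float, c: float) -> float:
    # a * b + c; Python rounds a * b before adding c, so this is not truly fused
    return a * b + c


x = [[1.0, 2.0], [3.0, 4.0]]
y = [[0.5, 0.5], [2.0, -1.0]]
z = [[1.0, 0.0], [-1.0, 4.0]]
out = zip_map(fma_reference, x, y, z)
print(out)  # [[1.5, 1.0], [5.0, 0.0]]
```

The point of the sketch is that the function body sees one element from each input at a time; sn_zipmapreduce() is responsible for applying it across the whole tensor.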

sn_imm(input: float | int, dtype: torch.dtype | SNType) -> SambaTensor

Creates a tensor containing a constant.

Parameters:
  • input – the constant value (a Python float or int) to store in the tensor.

  • dtype – dtype of the resulting constant tensor.

sn_iteridx(attrs: dict, dim: int, dtype: SNType | None) -> SambaTensor

Creates an iterator within the body of an sn_zipmapreduce() func. dim specifies which dimension of the broadcast shape the iterator runs over. For example, if sn_zipmapreduce() has two inputs with shapes [A,1] and [1,B] and reduce_dim is set to 1, the tensor shape after broadcasting is [A,B] and the tensor shape after the reduction is [A,1]. Two iterators are then available within the func body:

sn_iteridx(dim=0) takes the values 0, 1, ..., A-1 (step 1)
sn_iteridx(dim=1) takes the values 0, 1, ..., B-1 (step 1)

The value returned by sn_iteridx() can be treated like any other input to the sn_zipmapreduce() func and used in expressions in the func body.
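The shape bookkeeping in the [A,1] / [1,B] example can be sketched in plain Python. `broadcast_shape`, `reduced_shape`, and `iterator_ranges` are hypothetical helpers for illustration only, not SambaFlow APIs. They assume NumPy-style broadcasting of same-rank shapes (each dimension pair is either equal or contains a 1) and that the reduced dimension is kept with size 1, as in the [A,1] result described above.

```python
from typing import Dict, List, Sequence


def broadcast_shape(*shapes: Sequence[int]) -> List[int]:
    """Broadcast same-rank shapes: each dimension must match or be 1 (assumed rule)."""
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        raise ValueError("this sketch assumes all inputs have the same rank")
    out = []
    for dims in zip(*shapes):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible dimensions: {dims}")
        out.append(sizes.pop() if sizes else 1)
    return out


def reduced_shape(shape: Sequence[int], reduce_dim: int) -> List[int]:
    # The reduced dimension is kept with size 1, e.g. [A, B] -> [A, 1]
    return [1 if i == reduce_dim else d for i, d in enumerate(shape)]


def iterator_ranges(shape: Sequence[int]) -> Dict[int, range]:
    # One iterator per dimension of the broadcast shape: 0 until size, step 1
    return {dim: range(size) for dim, size in enumerate(shape)}


A, B = 4, 3
bshape = broadcast_shape([A, 1], [1, B])
print(bshape)                    # [4, 3]
print(reduced_shape(bshape, 1))  # [4, 1]
print(iterator_ranges(bshape))   # {0: range(0, 4), 1: range(0, 3)}
```

Note that the dim=1 iterator still spans 0..B-1 even though the reduced output has size 1 in that dimension, because iterators follow the broadcast shape.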

Parameters:
  • attrs – the attrs dictionary that the calling sn_zipmapreduce() passes to the func.

  • dim – dimension of the broadcasted shape.

  • dtype (optional) – datatype of the iterator. dtype must be a signed or unsigned integer datatype, and its bit-width must match the bit-width of the inputs to the sn_zipmapreduce() call that the operation belongs to. If unspecified, the dtype is inferred from the inputs to the func passed to that sn_zipmapreduce(). The signedness is also inferred from the inputs, i.e. signed inputs yield a signed int iterator and unsigned inputs yield a uint iterator. Consequently, for int16 inputs the iterator can count only up to \(2^{15} - 1\), so the dimension being iterated over must have at most \(2^{15}\) elements. The following types are supported:

    • SNType.UINT16

    • SNType.INT16

    • SNType.INT32
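The bit-width constraint translates into a maximum dimension size for each iterator dtype. The sketch below works through that arithmetic. `max_index` and `max_extent` are hypothetical helpers, and the mapping from SNType names to (bit-width, signedness) is an assumption inferred from the type names listed above.

```python
# Assumed (bit-width, signed) for each supported iterator type, inferred from the names
ITER_TYPES = {
    "UINT16": (16, False),
    "INT16": (16, True),
    "INT32": (32, True),
}


def max_index(bits: int, signed: bool) -> int:
    """Largest index value the iterator type can hold."""
    return 2 ** (bits - 1) - 1 if signed else 2 ** bits - 1


def max_extent(bits: int, signed: bool) -> int:
    """Largest dimension size whose indices 0..size-1 all fit in the type."""
    return max_index(bits, signed) + 1


for name, (bits, signed) in ITER_TYPES.items():
    print(name, max_index(bits, signed), max_extent(bits, signed))
# UINT16 65535 65536
# INT16 32767 32768
# INT32 2147483647 2147483648
```

The INT16 row reproduces the \(2^{15} - 1\) limit stated above: the largest index is 32767, so the iterated dimension can have at most 32768 elements.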

Example

>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_iteridx, sn_select, sn_imm
>>> from sambaflow.samba.utils import SNType
>>> x = samba.SambaTensor(shape=(3, 3), dtype=torch.int32)
>>> # Example: Creating an upper triangular matrix
>>> def upper_tri(attrs, x):
...     # if specified, dtype must match bitwidth of other dtypes in the lambda
...     dim_r = sn_iteridx(attrs=attrs, dim=0, dtype=SNType.INT32)
...     dim_c = sn_iteridx(attrs=attrs, dim=1, dtype=SNType.INT32)
...     mask = dim_r <= dim_c
...     one_imm = sn_imm(1, dtype=torch.int32)
...     zero_imm = sn_imm(0, dtype=torch.int32)
...     return sn_select(mask, one_imm, zero_imm)
>>> upper_tri_matrix = sn_zipmapreduce(upper_tri, [x])
>>> upper_tri_matrix.data
tensor([[1, 1, 1],
        [0, 1, 1],
        [0, 0, 1]], dtype=torch.int32)

Note

This operation is only supported when used inside a Python function/lambda that is passed to sn_zipmapreduce().
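As a cross-check of the expected output in the upper-triangular example, the following plain-Python sketch evaluates the same per-element logic on the CPU: each output position receives its row and column index (the values sn_iteridx would supply for dim=0 and dim=1), and the mask dim_r <= dim_c selects 1 or 0. `upper_tri_reference` is a hypothetical helper; it models only the semantics, not how SambaFlow executes the lambda.

```python
from typing import List


def upper_tri_reference(rows: int, cols: int) -> List[List[int]]:
    out = []
    for dim_r in range(rows):          # plays the role of sn_iteridx(dim=0)
        row = []
        for dim_c in range(cols):      # plays the role of sn_iteridx(dim=1)
            mask = dim_r <= dim_c
            row.append(1 if mask else 0)  # like sn_select(mask, one_imm, zero_imm)
        out.append(row)
    return out


for row in upper_tri_reference(3, 3):
    print(row)
# [1, 1, 1]
# [0, 1, 1]
# [0, 0, 1]
```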

sn_select(cond: SambaTensor, true_val: SambaTensor, false_val: SambaTensor) -> SambaTensor

Performs a select operation on tensors, similar to the torch.where() function. This function selects elements from true_val or false_val based on the condition specified in cond.

Parameters:
  • cond – a mask tensor where each element is a condition. If the condition is true, the corresponding element from true_val is selected; otherwise the element from false_val is chosen.

  • true_val – the tensor from which elements are selected when the corresponding condition in cond is true.

  • false_val – the tensor from which elements are selected when the corresponding condition in cond is false.

Example

>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_select
>>> # Example showing the use of sn_select with tensors.
>>> mask = samba.SambaTensor(torch.tensor([1, 0, 1], dtype=torch.int32))
>>> x = samba.SambaTensor(torch.tensor([1, 2, 3], dtype=torch.float))
>>> y = samba.SambaTensor(torch.tensor([4, 5, 6], dtype=torch.float))
>>> f = lambda attrs, mask, x, y: sn_select(mask, x, y)
>>> result = sn_zipmapreduce(f, [mask, x, y])
>>> result.data
tensor([1., 5., 3.])
>>> # result contains elements from x or y based on the mask.
>>> # Here 1 and 3 are selected from x and 5 from y.

Note

This operation is only supported when used inside a Python function/lambda that is passed to sn_zipmapreduce().

See also

For more details, see torch.where().
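The per-element rule of sn_select can be modeled with a short plain-Python sketch that reproduces the sn_select example. `select_reference` is a hypothetical helper, not a SambaFlow API. It assumes one-dimensional inputs of equal length and treats any nonzero mask element as true, which matches the 1/0 mask in the example but is an assumption for other mask values.

```python
from typing import List, Sequence


def select_reference(cond: Sequence[int], true_val: Sequence[float],
                     false_val: Sequence[float]) -> List[float]:
    if not (len(cond) == len(true_val) == len(false_val)):
        raise ValueError("inputs must have the same length")
    # Take from true_val where the mask is nonzero, otherwise from false_val
    return [t if c else f for c, t, f in zip(cond, true_val, false_val)]


mask = [1, 0, 1]
x = [1.0, 2.0, 3.0]
y = [4.0, 5.0, 6.0]
print(select_reference(mask, x, y))  # [1.0, 5.0, 3.0]
```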