samba.parallel_patterns

Parallel Patterns

sn_fma(a: SambaTensor, b: SambaTensor, c: SambaTensor) → SambaTensor

New in version 1.18.

Performs a fused multiply-add (FMA) operation, equivalent to

\[a \times b + c\]

Example

>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_fma
>>> compute_fma = lambda attrs, x, y, z: sn_fma(x, y, z)
>>> samba.set_seed(1)
>>> x = samba.randn((2,5), dtype=torch.bfloat16)
>>> y = samba.randn((2,5), dtype=torch.bfloat16)
>>> z = samba.randn((2,5), dtype=torch.bfloat16)
>>> out = sn_zipmapreduce(compute_fma, [x, y, z])
>>> out.data
tensor([[ 3.0000, -1.8438,  1.5859, -1.8750,  0.1357],
        [-0.5234,  2.0312,  2.3438,  2.0312, -0.9180]], dtype=torch.bfloat16)

Note

This operation is only supported inside a Python function or lambda passed to sn_zipmapreduce().

sn_gather(input_tensor: SambaTensor, start_indices: SambaTensor, gather_dims: List[int], gather_lengths: List[int]) → SambaTensor

New in version 1.18.

Performs a general gather operation. Supports multiple gather dimensions, a different start index for each gathered block, and a different gather length for each gather dimension.

Parameters:
  • input_tensor (SambaTensor) – input tensor to gather from

  • start_indices (SambaTensor) – tensor of start indices, of shape MxNx…xD, where D is the number of dimensions in gather_dims. Each length-D entry gives the start index along each gather dimension for one gathered block.

  • gather_dims (List[int]) – list of the D dimensions to gather along. Dimensions not listed are gathered in full.

  • gather_lengths (List[int]) – list of D lengths, one for each dimension in gather_dims, giving how many values to gather along that dimension starting from the corresponding start index in start_indices. Dimensions not listed are gathered in full.

Example

>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_gather
>>> # In this example, the input tensor, x, has a shape of (5, 5).
>>> # sn_gather gathers along dimensions 0 and 1 with lengths 2 and 3, respectively,
>>> # starting at each row of start indices in s_in.
>>> # The result, y, holds three 2x3 blocks, giving a shape of (3, 2, 3).
>>> x = samba.SambaTensor(torch.tensor([[0,1,2,3,4],
...                                    [5,6,7,8,9 ],
...                                    [10,11,12,13,14],
...                                    [15,16,17,18,19],
...                                    [20,21,22,23,24]]))
>>> s_in = samba.SambaTensor(torch.tensor([[0,0],[1,1],[3,2]]))
>>> g_dims = [0,1]
>>> g_len = [2,3]
>>> y = sn_gather(x, s_in, g_dims, g_len)
>>> y.data
tensor([[[ 0,  1,  2],
         [ 5,  6,  7]],

        [[ 6,  7,  8],
         [11, 12, 13]],

        [[17, 18, 19],
         [22, 23, 24]]])
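
To make the indexing rule concrete, here is a minimal pure-Python reference model of the gather in the example above. It is a sketch under stated assumptions, not the sambaflow implementation: gather_2d_ref is a hypothetical name, it handles only 2-D inputs, dimensions missing from gather_dims are assumed to start at index 0 and be taken in full, and it performs no bounds checking. Each row (r, c) of start_indices yields the block x[r:r+len0, c:c+len1].

```python
from typing import List, Sequence


def gather_2d_ref(x: List[List[int]],
                  start_indices: Sequence[Sequence[int]],
                  gather_dims: Sequence[int],
                  gather_lengths: Sequence[int]) -> List[List[List[int]]]:
    """Hypothetical reference model of sn_gather for a 2-D input (not the sambaflow API).

    Each row of start_indices produces one output block. Dimensions missing from
    gather_dims are assumed to start at 0 and be gathered in full.
    No bounds checking: Python slicing silently truncates out-of-range blocks.
    """
    full_sizes = [len(x), len(x[0])]
    blocks = []
    for starts in start_indices:
        begin = [0, 0]
        length = list(full_sizes)
        # Override start and length only for the listed gather dimensions.
        for dim, start, n in zip(gather_dims, starts, gather_lengths):
            begin[dim], length[dim] = start, n
        r0, c0 = begin
        nr, nc = length
        blocks.append([row[c0:c0 + nc] for row in x[r0:r0 + nr]])
    return blocks


# Same 5x5 input as the example: values 0..24 in row-major order.
x = [[5 * r + c for c in range(5)] for r in range(5)]
y = gather_2d_ref(x, [[0, 0], [1, 1], [3, 2]], gather_dims=[0, 1], gather_lengths=[2, 3])
print(y)
```

Running it prints the same three 2x3 blocks as y.data above, as nested lists.
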

sn_imm(input: float | int, dtype: torch.dtype | SNType) → SambaTensor

New in version 1.18.

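Creates an immediate tensor, that is, a constant scalar value for use as an operand inside a sn_zipmapreduce() func (see the sn_iteridx() example below).

Parameters:
  • input (float or int) – the scalar value of the immediate

  • dtype (torch.dtype or SNType) – dtype of the resulting immediate tensor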

sn_iteridx(attrs: dict, dim: int, dtype: SNType | None) → SambaTensor

New in version 1.18.

Creates an iterator within the body of a sn_zipmapreduce() func. dim specifies which dimension of the broadcast output tensor is iterated over. For example, if sn_zipmapreduce() has two inputs of shapes [A, 1] and [1, B] and reduce_dim is set to 1, the tensor shape after broadcasting is [A, B] and the shape after reduction is [A, 1]. Two iterators are then available within the func body:

sn_iteridx(dim=0) = (0 until A by 1)
sn_iteridx(dim=1) = (0 until B by 1)

The value returned by sn_iteridx() can be used like an additional input tensor anywhere in the func body.

Parameters:
  • attrs (dict) – attrs dictionary passed to the calling sn_zipmapreduce()

  • dim (int) – dimension of the broadcast (pre-reduction) shape to iterate over

  • dtype (Optional[SNType]) –

    datatype of the iterator. dtype must be a signed or unsigned integer datatype whose bit-width matches the bit-width of the inputs to the sn_zipmapreduce() the operation belongs to. If unspecified, it is inferred from those inputs, including its signedness: signed inputs give a signed (int) iterator and unsigned inputs give an unsigned (uint) one. Consequently, for int16 inputs the iterator can count only up to \(2^{15} - 1\), so the iterated output dimension must not exceed \(2^{15}\) elements. The following types are supported:

    • SNType.UINT16

    • SNType.INT16

    • SNType.INT32
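
The range limit above is simple arithmetic: an n-bit signed iterator counts up to 2^(n-1) - 1, so the iterated dimension can hold at most 2^(n-1) elements. The helper below is a hypothetical sketch for checking a shape against an iterator dtype; the unsigned case assumes the full range 0 to 2^n - 1 is usable, which this page does not state explicitly.

```python
def max_iterable_dim(bits: int, signed: bool) -> int:
    """Largest dimension size whose indices 0 .. size-1 fit in an integer iterator.

    Hypothetical helper for illustration. The signed case follows the text above
    (int16 counts up to 2**15 - 1); the unsigned case ASSUMES the full range
    0 .. 2**bits - 1 is usable, which is not stated on this page.
    """
    largest_index = 2 ** (bits - 1) - 1 if signed else 2 ** bits - 1
    return largest_index + 1


for name, bits, signed in [("UINT16", 16, False), ("INT16", 16, True), ("INT32", 32, True)]:
    print(name, max_iterable_dim(bits, signed))
```

This prints 65536 for UINT16, 32768 for INT16, and 2147483648 for INT32.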

Example

>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_iteridx, sn_select, sn_imm
>>> from sambaflow.samba.utils import SNType
>>> x = samba.SambaTensor(shape=(3, 3), dtype=torch.int32)
>>> # Example: creating an upper triangular matrix
>>> def upper_tri(attrs, x):
...     # if specified, dtype must match bitwidth of other dtypes in the lambda
...     dim_r = sn_iteridx(attrs=attrs, dim=0, dtype=SNType.INT32)
...     dim_c = sn_iteridx(attrs=attrs, dim=1, dtype=SNType.INT32)
...     mask = dim_r <= dim_c
...     one_imm = sn_imm(1, dtype=torch.int32)
...     zero_imm = sn_imm(0, dtype=torch.int32)
...     return sn_select(mask, one_imm, zero_imm)
>>> upper_tri_matrix = sn_zipmapreduce(upper_tri, [x])
>>> upper_tri_matrix.data
tensor([[1, 1, 1],
        [0, 1, 1],
        [0, 0, 1]], dtype=torch.int32)
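
The iterator semantics can be modelled in plain Python: for every element of the broadcast output, sn_iteridx(attrs, dim=d) evaluates to that element's coordinate along d. The sketch below is a hypothetical 2-D reference model (iteridx_map_ref is an illustrative name, not part of sambaflow); it rebuilds the upper-triangular result of the upper_tri example above and shows the dim=1 iterator counting 0 to B - 1 over a broadcast [A, B] = [2, 4] shape.

```python
from typing import Callable, List, Tuple


def iteridx_map_ref(shape: Tuple[int, int], body: Callable[[int, int], int]) -> List[List[int]]:
    """Hypothetical 2-D reference model of sn_iteridx inside a zip/map body.

    body is called once per element of the broadcast output shape with that
    element's coordinates: r plays the role of sn_iteridx(attrs, dim=0) and
    c the role of sn_iteridx(attrs, dim=1).
    """
    rows, cols = shape
    return [[body(r, c) for c in range(cols)] for r in range(rows)]


# Same logic as upper_tri: 1 where row index <= column index, otherwise 0.
upper = iteridx_map_ref((3, 3), lambda r, c: 1 if r <= c else 0)
# The dim=1 iterator over a broadcast [A, B] = [2, 4] shape counts 0 .. B-1 in every row.
col_iter = iteridx_map_ref((2, 4), lambda r, c: c)
print(upper)
print(col_iter)
```

The first line printed matches the tensor above ([[1, 1, 1], [0, 1, 1], [0, 0, 1]]); the second is [[0, 1, 2, 3], [0, 1, 2, 3]].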

Note

This operation is only supported inside a Python function or lambda passed to sn_zipmapreduce().

sn_reduce(input: SambaTensor, dim: Tuple[int, ...], fn: str) → SambaTensor | Tuple[SambaTensor, ...]

New in version 1.18.

Reduces a tensor by applying a reduction function across one or more of its dimensions (axes). Several reduction functions are supported, selected with fn.

Parameters:
  • input – the tensor to reduce.

  • dim – a tuple specifying the dimensions over which to perform the reduction. As the example shows, the reduced dimensions are kept in the output with size 1.

  • fn – a string that specifies the reduction function to apply, such as "SUM" (used in the example below).

Example

>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_reduce
>>> # In this example, the tensor x has a shape of (2, 2, 2).
>>> # sn_reduce sums over dimensions 1 and 2, the last two dimensions.
>>> # The result, y, holds the sum of each 2x2 block in x. The reduced dimensions
>>> # are kept with size 1, giving a shape of (2, 1, 1).
>>> x = samba.SambaTensor(torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
...                                     [[5.0, 6.0], [7.0, 8.0]]]))
>>> y = sn_reduce(x, (1, 2), "SUM")
>>> y.data
tensor([[[10.]],

        [[26.]]])
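
A minimal pure-Python sketch of a "SUM" reduction with the keep-dims behaviour shown above follows. It is a reference model under stated assumptions, not the sambaflow implementation: reduce_sum_ref and its helpers are hypothetical names, only "SUM" is modelled, negative dimensions are not handled, and inputs are uniformly nested lists.

```python
from itertools import product
from typing import Any, Callable, Sequence, Tuple


def shape_of(t: Any) -> Tuple[int, ...]:
    """Shape of a uniformly nested list; a plain number has shape ()."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def get(t: Any, idx: Sequence[int]) -> Any:
    """Read one element of a nested list."""
    for i in idx:
        t = t[i]
    return t


def build(shape: Tuple[int, ...], fill: Callable[[Tuple[int, ...]], Any]) -> Any:
    """Build a nested list of the given shape, calling fill(index) for each element."""
    def rec(prefix: Tuple[int, ...]) -> Any:
        if len(prefix) == len(shape):
            return fill(prefix)
        return [rec(prefix + (i,)) for i in range(shape[len(prefix)])]
    return rec(())


def reduce_sum_ref(x: Any, dims: Sequence[int]) -> Any:
    """Hypothetical reference for a "SUM" reduction that keeps reduced dims with size 1."""
    in_shape = shape_of(x)
    out_shape = tuple(1 if d in dims else n for d, n in enumerate(in_shape))

    def total(out_idx: Tuple[int, ...]) -> Any:
        # Sweep every index along the reduced dims; hold the other dims fixed.
        ranges = [range(n) if d in dims else (out_idx[d],) for d, n in enumerate(in_shape)]
        return sum(get(x, idx) for idx in product(*ranges))

    return build(out_shape, total)


x = [[[1.0, 2.0], [3.0, 4.0]],
     [[5.0, 6.0], [7.0, 8.0]]]
y = reduce_sum_ref(x, (1, 2))
print(y)
```

It prints [[[10.0]], [[26.0]]], the same values and (2, 1, 1) layout as y.data above.
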

sn_scatter(operand: SambaTensor, update: SambaTensor, start_indices: SambaTensor, scatter_dims: List[int], rmw_op: str = 'kUpdate', unique_indices: bool = False) → SambaTensor

New in version 1.18.

Performs a scatter operation into the operand tensor, given an update tensor and the indices and dimensions to scatter along. The update values can overwrite the operand values, be added to or multiplied with them, or be combined with them by elementwise minimum or maximum (see rmw_op).

update is an (L, A, ..., D) tensor whose values are scattered into operand. L is the batch dimension and gives the number of scatters that take place.

(A, ..., D) matches the dimensions of operand, except that along each scatter dimension the size may be smaller than, but not larger than, the corresponding size in operand.

start_indices has shape (L, len(scatter_dims)); row i gives the start index along each scatter dimension for the i-th scatter.

Parameters:
  • operand (SambaTensor) – The tensor to scatter into

  • update (SambaTensor) – The tensor of values to scatter

  • start_indices (SambaTensor) – the start indices, along each scatter dimension, at which each slice of update is written

  • scatter_dims (List[int]) – The dimensions of operand to scatter along

  • rmw_op (str) –

    The operation to perform between operand & update. Defaults to “kUpdate”. Supported operations include:

    • "kUpdate"

    • "kAdd"

    • "kMin"

    • "kMax"

    • "kMul"

  • unique_indices (bool) – whether start_indices is guaranteed to contain no repeated entries. Defaults to False, meaning start_indices may contain repeated indices (or it is not known whether it does). Set it to True only when start_indices is known to contain no repeated indices.

Example

>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_scatter
>>> # In this example, the operand tensor, x, has a shape of (2, 4, 5).
>>> # sn_scatter scatters the update tensor, u, of shape (3, 1, 1, 5), into x along
>>> # dimensions 0 and 1, using the start indices in s_in and rmw_op.
>>> # The result, y, has the same shape as the operand tensor.
>>> x = samba.SambaTensor(torch.tensor([
...                                    [[0,1,2,3,4],
...                                    [5,6,7,8,9],
...                                    [10,11,12,13,14],
...                                    [15,16,17,18,19]],
...                                    [[0,1,2,3,4],
...                                    [5,6,7,8,9],
...                                    [10,11,12,13,14],
...                                    [15,16,17,18,19]]
...                                    ]))
>>> u = samba.SambaTensor(torch.tensor([
...                                    [[[-1,-2,-3,-4,-5]]],
...                                    [[[-6,-7,-8,-9,-10]]],
...                                    [[[-11,-12,-13,-14,-15]]]
...                                    ]))
>>> s_in = samba.SambaTensor(torch.tensor([[1,1],[0,2],[1,3]]))
>>> s_dims = [0,1]
>>> rmw_op = "kUpdate"
>>> y = sn_scatter(x, u, s_in, s_dims, rmw_op)
>>> y.data
tensor([[[  0,   1,   2,   3,   4],
         [  5,   6,   7,   8,   9],
         [ -6,  -7,  -8,  -9, -10],
         [ 15,  16,  17,  18,  19]],

        [[  0,   1,   2,   3,   4],
         [ -1,  -2,  -3,  -4,  -5],
         [ 10,  11,  12,  13,  14],
         [-11, -12, -13, -14, -15]]])
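
For checking the indexing by hand, here is a minimal pure-Python reference model of the scatter above. It is a sketch under stated assumptions, not the sambaflow implementation: scatter_ref and RMW_OPS are hypothetical names, each update slice is assumed to start at offset 0 along every non-scatter dimension, and slices are applied in order, so repeated indices behave as "last write wins" for kUpdate, which may not match hardware behaviour.

```python
import copy
import operator
from itertools import product
from typing import Any, Callable, Dict, List, Sequence, Tuple

# Hypothetical mapping of the documented rmw_op names to Python operations.
RMW_OPS: Dict[str, Callable[[Any, Any], Any]] = {
    "kUpdate": lambda old, new: new,
    "kAdd": operator.add,
    "kMul": operator.mul,
    "kMin": min,
    "kMax": max,
}


def shape_of(t: Any) -> Tuple[int, ...]:
    """Shape of a uniformly nested list."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def scatter_ref(operand: List[Any], update: List[Any], start_indices: Sequence[Sequence[int]],
                scatter_dims: Sequence[int], rmw_op: str = "kUpdate") -> List[Any]:
    """Hypothetical reference model of sn_scatter on nested lists (not the sambaflow API).

    update[i] has the same rank as operand and is combined into it starting at
    start_indices[i] along scatter_dims (ASSUMED offset 0 along all other dims).
    """
    out = copy.deepcopy(operand)  # leave the caller's operand untouched
    combine = RMW_OPS[rmw_op]
    rank = len(shape_of(operand))
    for piece, starts in zip(update, start_indices):
        offset = [0] * rank
        for dim, start in zip(scatter_dims, starts):
            offset[dim] = start
        for local in product(*(range(n) for n in shape_of(piece))):
            target = [o + j for o, j in zip(offset, local)]
            container = out
            for i in target[:-1]:  # walk to the innermost list holding the target
                container = container[i]
            value = piece
            for j in local:
                value = value[j]
            container[target[-1]] = combine(container[target[-1]], value)
    return out


# Same data as the example: x is (2, 4, 5) and each update slice is (1, 1, 5).
x = [[[5 * r + c for c in range(5)] for r in range(4)] for _ in range(2)]
u = [[[[-(5 * i + k + 1) for k in range(5)]]] for i in range(3)]
y = scatter_ref(x, u, [[1, 1], [0, 2], [1, 3]], scatter_dims=[0, 1])
print(y)
```

With rmw_op="kUpdate" this reproduces y.data above; passing rmw_op="kAdd" instead adds each update slice to the operand values it lands on.
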

sn_select(cond: SambaTensor, true_val: SambaTensor, false_val: SambaTensor) → SambaTensor

New in version 1.18.

Performs a select operation on tensors, similar to the torch.where() function. This function selects elements from true_val or false_val based on the condition specified in cond.

Parameters:
  • cond – a mask tensor where each element is a condition. If the condition is true, the corresponding element from true_val is selected; otherwise, the element from false_val is chosen.

  • true_val – the tensor from which elements are selected when the corresponding condition in cond is true.

  • false_val – the tensor from which elements are selected when the corresponding condition in cond is false.

Example

>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_select
>>> # Example showing the use of sn_select with tensors.
>>> mask = samba.SambaTensor(torch.tensor([1, 0, 1], dtype=torch.int32))
>>> x = samba.SambaTensor(torch.tensor([1, 2, 3], dtype=torch.float))
>>> y = samba.SambaTensor(torch.tensor([4, 5, 6], dtype=torch.float))
>>> f = lambda attrs, mask, x, y: sn_select(mask, x, y)
>>> result = sn_zipmapreduce(f, [mask, x, y])
>>> result.data
tensor([1., 5., 3.])
>>> # result contains elements from x or y based on the mask.
>>> # Here 1 and 3 are selected from x and 5 from y.

Note

This operation is only supported inside a Python function or lambda passed to sn_zipmapreduce().

See also

For more details see torch.where().

sn_zipmapreduce(func: Callable[[Dict[str, Any], Iterable[SambaTensor]], SambaTensor], inputs: List[SambaTensor], reduce_func: str | None = None, reduce_dim: int | List[int] | None = None, stoc_accum: bool | None = False) → SambaTensor

New in version 1.18.

Applies a series of zip/map (elementwise) operations, specified by func, to the inputs, optionally followed by a reduce operation on the result.

Parameters:
  • func

    A Python function or lambda describing the intended elementwise operation.

    The first parameter of func (conventionally named attrs) is a placeholder that sn_zipmapreduce() uses to pass special attributes to operations inside the func body, such as sn_iteridx(). The remaining parameters are bound to the tensors in inputs, in the order they are listed.

    For example, if func is lambda attrs, x, y, z: (x + y) * z and inputs is ipt, the output is (ipt[0] + ipt[1]) * ipt[2], computed elementwise.

    func can take an arbitrary number of SambaTensor arguments and must return a single value. See the table below for the supported operations.

  • inputs – List of SambaTensor(s) to use as inputs to the operation.

  • reduce_func (optional) – Reduction function to apply after applying func to the inputs. (Default = None). See the fn attribute of sn_reduce() for supported values.

  • reduce_dim (optional) – Dimension(s) to reduce along. (Default = None)

  • stoc_accum (optional) – Enable/Disable (True/False) Stochastic Accumulation. (Default = False)
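
The parameter description above amounts to three steps: broadcast the inputs to a common shape, apply func elementwise, then optionally reduce. The sketch below is a hypothetical pure-Python reference model of those steps, not the sambaflow implementation: zipmapreduce_ref and its helpers are illustrative names, only reduce_func="SUM" is modelled, reduced dimensions are kept with size 1 (as in the examples on this page), stoc_accum is ignored, and an empty dict stands in for attrs.

```python
from itertools import product
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union


def shape_of(t: Any) -> Tuple[int, ...]:
    """Shape of a uniformly nested list; a plain number has shape ()."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def broadcast_shape(shapes: Sequence[Tuple[int, ...]]) -> Tuple[int, ...]:
    """PyTorch-style broadcasting: align trailing dims; each size must match or be 1."""
    rank = max(len(s) for s in shapes)
    padded = [(1,) * (rank - len(s)) + tuple(s) for s in shapes]
    out = []
    for sizes in zip(*padded):
        non_one = {n for n in sizes if n != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes are not broadcastable: {shapes}")
        out.append(non_one.pop() if non_one else 1)
    return tuple(out)


def read(t: Any, shape: Tuple[int, ...], idx: Tuple[int, ...]) -> Any:
    """Read t at an index of the broadcast shape; size-1 dims always use index 0."""
    for n, i in zip(shape, idx[len(idx) - len(shape):]):
        t = t[0 if n == 1 else i]
    return t


def zipmapreduce_ref(func: Callable[..., Any], inputs: List[Any],
                     reduce_func: Optional[str] = None,
                     reduce_dim: Union[int, List[int], None] = None) -> Any:
    """Hypothetical reference model: broadcast, apply func elementwise, optionally SUM-reduce."""
    if reduce_func not in (None, "SUM"):
        raise NotImplementedError("this sketch only models reduce_func='SUM'")
    shapes = [shape_of(t) for t in inputs]
    full_shape = broadcast_shape(shapes)
    if reduce_func is None or reduce_dim is None:
        dims: List[int] = []
    else:
        dims = [reduce_dim] if isinstance(reduce_dim, int) else list(reduce_dim)
    attrs: Dict[str, Any] = {}  # stands in for the attributes sn_zipmapreduce() passes

    def mapped(idx: Tuple[int, ...]) -> Any:
        return func(attrs, *(read(t, s, idx) for t, s in zip(inputs, shapes)))

    def element(out_idx: Tuple[int, ...]) -> Any:
        # Sweep the reduced dims (if any) while holding the other dims fixed.
        ranges = [range(n) if d in dims else (out_idx[d],) for d, n in enumerate(full_shape)]
        values = [mapped(idx) for idx in product(*ranges)]
        return sum(values) if dims else values[0]

    # Reduced dims are kept with size 1, matching the examples on this page.
    out_shape = tuple(1 if d in dims else n for d, n in enumerate(full_shape))

    def build(prefix: Tuple[int, ...]) -> Any:
        if len(prefix) == len(out_shape):
            return element(prefix)
        return [build(prefix + (i,)) for i in range(out_shape[len(prefix)])]

    return build(())


x = [[1, 2, 3]]   # shape (1, 3)
y = [[10], [20]]  # shape (2, 1)
z = [2]           # shape (1,)
f = lambda attrs, x, y, z: (x + y) * z
print(zipmapreduce_ref(f, [x, y, z]))
print(zipmapreduce_ref(f, [x, y, z], reduce_func="SUM", reduce_dim=1))
```

The first call prints [[22, 24, 26], [42, 44, 46]], that is, (ipt[0] + ipt[1]) * ipt[2] over the broadcast (2, 3) shape; the second sums along dimension 1 and prints [[72], [132]].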

Supported Operations inside func

Operation                   Syntax

Unary Operations
  Absolute                  torch.abs(x) (see torch.abs())
  Exponentiation            torch.exp(x) (see torch.exp())
  Logarithm                 torch.log(x) (see torch.log())
  Reciprocal                torch.reciprocal(x) (see torch.reciprocal())
  Square Root               torch.sqrt(x) (see torch.sqrt())
  Floor                     torch.floor(x) (see torch.floor())
  Ceil                      torch.ceil(x) (see torch.ceil())
  Negation                  torch.neg(x) (see torch.neg())
  Sigmoid                   torch.sigmoid(x) (see torch.sigmoid())
  Reciprocal Square Root    torch.rsqrt(x) (see torch.rsqrt())
  Tanh                      torch.tanh(x) (see torch.tanh())

Binary Operations
  Addition                  x + y (see torch.add())
  Subtraction               x - y (see torch.sub())
  Multiplication            x * y (see torch.mul())
  Division                  x / y (see torch.div())
  Equal                     x == y (see torch.eq())
  Greater Than              x > y (see torch.gt())
  Greater Than or Equal     x >= y (see torch.ge())
  Less Than                 x < y (see torch.lt())
  Less Than or Equal        x <= y (see torch.le())
  Not Equal                 x != y (see torch.ne())
  Elementwise Maximum       torch.maximum(x, y) (see torch.maximum())
  Elementwise Minimum       torch.minimum(x, y) (see torch.minimum())

Ternary Operations
  Fused Multiply Add        sn_fma(x, y, z) (see sn_fma())
  Conditional Select        sn_select(x, y, z) (see sn_select())

Miscellaneous
  Iter Index                sn_iteridx(attrs, dim) (see sn_iteridx())
  Immediate                 sn_imm(val, dtype) (see sn_imm())

Note

The shapes of the SambaTensor(s) in inputs must be compatible under PyTorch Broadcasting Semantics.

sn_zipmapreduce() expects all input tensors to have datatypes with the same bit-width.

Examples

>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce
>>> # Compute the elementwise squared difference between two tensors of the same shape
>>> samba.set_seed(1)
>>> x = samba.ones((5,5))
>>> y = samba.randn((5,5))
>>> dist_func = lambda attrs, x, y: torch.pow(x - y, 2)
>>> out = sn_zipmapreduce(dist_func, [x, y])
>>> out.data
tensor([[6.3786, 3.0633, 2.7357, 6.8094, 1.2104],
        [2.5895, 3.9195, 6.8074, 2.9314, 0.0293],
        [0.5889, 1.1998, 0.7407, 6.6487, 1.7440],
        [1.6883, 0.7732, 0.4413, 0.5251, 0.5131],
        [1.1154, 0.0080, 5.7235, 2.8530, 1.2344]])
>>> # Example with broadcasting and a reduction
>>> from sambaflow.samba.functional import sn_fma, sn_select
>>> # Apply do_something to the broadcast inputs, then sum over dimensions 0 and 2 of the result
>>> samba.set_seed(1)
>>> ipt_0 = samba.randn((3, 2, 3))
>>> ipt_1 = samba.randn((1, 2, 1))
>>> # ipt_1 is broadcast along dimensions 0 and 2
>>> ipt_2 = samba.randn((1, 2, 3))
>>> # ipt_2 is broadcast along dimension 0
>>> def do_something(attrs, x, y, z):
...     max_x_y = torch.maximum(x, y)
...     max_x_y_z = torch.maximum(max_x_y, z)
...     fma_out = sn_fma(x, y, z)
...     x_gt_y = x > y
...     return sn_select(x_gt_y, fma_out, max_x_y_z)
>>> out = sn_zipmapreduce(do_something, [ipt_0, ipt_1, ipt_2], reduce_func="SUM", reduce_dim=[0, 2])
>>> out.data
tensor([[[ 6.8619],
         [11.1231]]])