samba.parallel_patterns

Parallel Patterns

sn_gather(input_tensor: Tensor | SambaTensor, start_indices: Tensor | SambaTensor, gather_dims: List[int], gather_lengths: List[int]) -> Tensor | SambaTensor

Performs a general gather operation. Supports multiple gather dimensions, a separate set of start indices for each gather, and a separate gather length for each gather dimension.

Parameters:
  • input_tensor – input tensor to gather from, also known as the lookup table (LUT). It can have any shape.

  • start_indices – tensor of indices giving the start index in each gather dimension. Has shape MxNx…xD, where D is the number of gather dimensions (the length of gather_dims) and the leading dimensions MxNx… enumerate the sets of start indices, i.e. the gathers being performed. The gathers can be arranged across any number of leading dimensions.

  • gather_dims – list of the D dimensions to gather along. Dimensions that are not listed are gathered in full.

  • gather_lengths – list of D lengths, one for each dimension in gather_dims, giving how many values to gather along that dimension starting from the corresponding start index in start_indices. Dimensions not listed in gather_dims are gathered in full.
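To make the index arithmetic concrete, here is a minimal NumPy sketch of the gather semantics described above. It is an illustrative reference model, not SambaFlow's implementation: gather_reference is a hypothetical name, and rejecting windows that run past the end of the LUT is an assumption of the sketch, since the behavior of sn_gather in that case is not documented here.

```python
import numpy as np


def gather_reference(lut, start_indices, gather_dims, gather_lengths):
    """Hypothetical NumPy model of the gather semantics (not the SambaFlow kernel)."""
    lut = np.asarray(lut)
    starts = np.asarray(start_indices)
    lead_shape = starts.shape[:-1]                      # MxNx...: one entry per gather
    flat_starts = starts.reshape(-1, starts.shape[-1])  # one row of D start indices per gather

    blocks = []
    for row in flat_starts:
        # Dimensions that are not listed in gather_dims are taken in full.
        window = [slice(None)] * lut.ndim
        for dim, length, start in zip(gather_dims, gather_lengths, row):
            start = int(start)
            # Assumption of this sketch: every window must lie inside the LUT.
            if start < 0 or start + length > lut.shape[dim]:
                raise IndexError(f"window [{start}, {start + length}) out of range for dim {dim}")
            window[dim] = slice(start, start + length)
        blocks.append(lut[tuple(window)])

    # Stack the blocks, then restore the leading (batch-of-gathers) dimensions.
    return np.stack(blocks).reshape(*lead_shape, *blocks[0].shape)


# Same data as the example below: a 5x5 LUT holding 0..24.
x = np.arange(25).reshape(5, 5)
y = gather_reference(x, [[0, 0], [1, 1], [3, 2]], [0, 1], [2, 3])
print(y.shape)  # (3, 2, 3)
```

Splitting the gathers across extra leading dimensions only changes the output's leading shape: start indices of shape (1, 3, 2) yield an output of shape (1, 3, 2, 3) in this model.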

Example

>>> import torch
>>> from sambaflow.samba.functional import sn_gather
>>> # In this example, the input tensor (LUT), x, has a shape of (5, 5).
>>> # sn_gather gathers along dimensions 0 and 1 with lengths 2 and 3, respectively, from each set of start indices.
>>> # The result, y, holds three 2x3 blocks and therefore has a shape of (3, 2, 3).
>>> x = torch.tensor([[0,  1,  2,  3,  4],
...                   [5,  6,  7,  8,  9],
...                   [10, 11, 12, 13, 14],
...                   [15, 16, 17, 18, 19],
...                   [20, 21, 22, 23, 24]])
>>> s_in = torch.tensor([[0,0],[1,1],[3,2]])
>>> g_dims = [0,1]
>>> g_len = [2,3]
>>> y = sn_gather(x, s_in, g_dims, g_len)
>>> y.data
tensor([[[ 0,  1,  2],
         [ 5,  6,  7]],

        [[ 6,  7,  8],
         [11, 12, 13]],

        [[17, 18, 19],
         [22, 23, 24]]])

sn_reduce(input: Tensor, dim: Tuple[int, ...], fn: str) -> Tensor | Tuple[Tensor, ...]

Reduces a tensor by applying a reduction function across one or more of its dimensions (axes). The reduction function is selected with fn.

Parameters:
  • input – the tensor to reduce.

  • dim – a tuple specifying the dimensions over which to perform the reduction.

  • fn – a string that specifies the reduction function to apply, for example "SUM" (used in the example below).
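As a cross-check of the example that follows, the "SUM" case can be modelled in NumPy. This sketch rests on two assumptions: only "SUM" is handled (the only value shown on this page), and the reduced dimensions are kept with size 1, which is what the example's printed result indicates. reduce_reference is a hypothetical name.

```python
import numpy as np


def reduce_reference(x, dims, fn):
    """Hypothetical NumPy model of sn_reduce for fn="SUM" (reduced dims kept with size 1)."""
    if fn != "SUM":
        raise NotImplementedError(f"only 'SUM' is modelled here, got {fn!r}")
    # keepdims=True reproduces the (2, 1, 1) shape of the printed example result.
    return np.sum(np.asarray(x), axis=tuple(dims), keepdims=True)


x = np.array([[[1.0, 2.0], [3.0, 4.0]],
              [[5.0, 6.0], [7.0, 8.0]]])
y = reduce_reference(x, (1, 2), "SUM")
print(y.shape)           # (2, 1, 1)
print(y.ravel().tolist())  # [10.0, 26.0]
```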

Example

>>> import torch
>>> from sambaflow.samba.functional import sn_reduce
>>> # In this example, the tensor x has a shape of (2, 2, 2).
>>> # sn_reduce sums over dimensions 1 and 2, the last two dimensions.
>>> # The result, y, holds the sum of each 2x2 block in x; the reduced dimensions are kept with size 1, giving a shape of (2, 1, 1).
>>> x = torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
...                   [[5.0, 6.0], [7.0, 8.0]]])
>>> y = sn_reduce(x, (1, 2), "SUM")
>>> y.data
tensor([[[10.]],

        [[26.]]])

sn_scatter(operand: Tensor | SambaTensor, update: Tensor | SambaTensor, start_indices: Tensor | SambaTensor, scatter_dims: List[int], rmw_op: str = 'kUpdate', unique_indices: bool = False) -> Tensor | SambaTensor

Scatters the update tensor into the operand tensor, starting at the given indices along the given scatter dimensions. Depending on rmw_op, the update values overwrite, are added to, are multiplied with, or are compared (minimum/maximum) against the values in operand.

The update tensor has shape (G, ..., A, ...) and is scattered into the operand tensor. The leading group dimension(s) G, ... determine how many scatters take place. The trailing dimensions (A, ...) match the dimensions of operand, except that each scatter dimension is smaller than the corresponding dimension of operand.

The start_indices tensor has shape (G, ..., len(scatter_dims)) and holds, for each scatter, the start index along each scatter dimension.

Parameters:
  • operand – The tensor to scatter into. This can be any shape.

  • update – The tensor of values to scatter. Has shape Gx…xAxBx…, where Gx… are the “Group Dimensions”, which allow multiple scatters in a single scatter operation, and AxBx… correspond one-to-one with the dimensions of operand. If a dimension IS a scatter dimension, its size is smaller than the size of the corresponding dimension in operand (it is the number of elements scattered along that scatter dimension). If it is NOT a scatter dimension, its size equals the size of the corresponding dimension in operand.

  • start_indices – The start indices indicating where each scatter begins. Has shape Gx…xD, where Gx… are the same “Group Dimensions” as in update and D is the number of scatter dimensions. There is one start index per scatter dimension.

  • scatter_dims – The dimensions of operand to scatter along. Has length D.

  • rmw_op – Read-Modify-Write operation. Selects the operation performed between operand and update: a value is read from operand, the specified operation is applied to this value and the new value from update, and the result is written back to operand. The default, "kUpdate", skips the read and modify steps and performs a plain scatter.

    Supported operations:

    • "kUpdate" – skip the read and modify steps and perform a plain scatter (default).

    • "kAdd" – add the update value to the value at the destination index in operand.

    • "kMin" – compare the old value from operand with the new value from update, and scatter only if the update value is less than the value in operand.

    • "kMax" – compare the old value from operand with the new value from update, and scatter only if the update value is greater than the value in operand.

    • "kMul" – multiply the update value by the value at the destination index in operand.

  • unique_indices – Whether start_indices is guaranteed to contain no repeated indices. Defaults to False, meaning repeated indices may occur (or it is not known whether they do). Set it to True only when start_indices is known to contain no repeated indices.
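The read-modify-write semantics above can be summarised with a small NumPy reference model. It is a sketch, not SambaFlow's implementation: scatter_reference and RMW_OPS are hypothetical names, and the model applies the scatters one after another in order, so with repeated start indices "kUpdate" keeps the last write while the other operations accumulate. Whether sn_scatter resolves repeated indices the same way is not specified here.

```python
import numpy as np

# Combine step for each rmw_op value: (old value in operand, new value from update) -> result.
RMW_OPS = {
    "kUpdate": lambda old, new: new,      # plain scatter, no read/modify
    "kAdd": lambda old, new: old + new,
    "kMin": np.minimum,                   # keeps the update only where it is smaller
    "kMax": np.maximum,                   # keeps the update only where it is larger
    "kMul": lambda old, new: old * new,
}


def scatter_reference(operand, update, start_indices, scatter_dims, rmw_op="kUpdate"):
    """Hypothetical NumPy model of sn_scatter; scatters are applied sequentially."""
    out = np.array(operand, copy=True)
    update = np.asarray(update)
    starts = np.asarray(start_indices)
    block_shape = update.shape[update.ndim - out.ndim:]  # trailing AxBx... dims
    blocks = update.reshape(-1, *block_shape)             # flatten the group dims
    rows = starts.reshape(-1, starts.shape[-1])           # one row of D indices per scatter
    combine = RMW_OPS[rmw_op]

    for block, row in zip(blocks, rows):
        for dim in range(out.ndim):
            # Non-scatter dimensions must match the operand exactly.
            if dim not in scatter_dims and block.shape[dim] != out.shape[dim]:
                raise ValueError(f"update dim {dim} must match operand size {out.shape[dim]}")
        window = [slice(None)] * out.ndim
        for dim, start in zip(scatter_dims, row):
            window[dim] = slice(int(start), int(start) + block.shape[dim])
        region = tuple(window)
        out[region] = combine(out[region], block)
    return out


# Same data as the example below.
x = np.tile(np.arange(20).reshape(4, 5), (2, 1, 1))  # shape (2, 4, 5)
u = -np.arange(1, 16).reshape(3, 1, 1, 5)            # three 1x1x5 updates
y = scatter_reference(x, u, [[1, 1], [0, 2], [1, 3]], [0, 1])
print(y[1, 1])  # [-1 -2 -3 -4 -5]
```

Swapping "kUpdate" for "kAdd" in this model with two identical start indices adds both updates into the same location, which is the kind of collision that unique_indices=False allows for.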

Example

>>> import torch
>>> from sambaflow.samba.functional import sn_scatter
>>> # In this example, the operand tensor, x, has a shape of (2, 4, 5).
>>> # sn_scatter writes the update tensor, u, into x along dimensions 0 and 1 at the positions given
>>> # by the start indices, combining values according to rmw_op.
>>> # The result, y, has the same shape as the operand tensor.
>>> x = torch.tensor([
...                   [[0,  1,  2,  3,  4],
...                    [5,  6,  7,  8,  9],
...                    [10, 11, 12, 13, 14],
...                    [15, 16, 17, 18, 19]],
...                   [[0,  1,  2,  3,  4],
...                    [5,  6,  7,  8,  9],
...                    [10, 11, 12, 13, 14],
...                    [15, 16, 17, 18, 19]]
...                  ])
>>> u = torch.tensor([
...                   [[[-1,  -2,  -3,  -4,  -5]]],
...                   [[[-6,  -7,  -8,  -9,  -10]]],
...                   [[[-11, -12, -13, -14, -15]]]
...                  ])
>>> s_in = torch.tensor([[1,1],[0,2],[1,3]])
>>> s_dims = [0,1]
>>> rmw_op = "kUpdate"
>>> y = sn_scatter(x, u, s_in, s_dims, rmw_op)
>>> y.data
tensor([[[  0,   1,   2,   3,   4],
         [  5,   6,   7,   8,   9],
         [ -6,  -7,  -8,  -9, -10],
         [ 15,  16,  17,  18,  19]],

        [[  0,   1,   2,   3,   4],
         [ -1,  -2,  -3,  -4,  -5],
         [ 10,  11,  12,  13,  14],
         [-11, -12, -13, -14, -15]]])

sn_zipmapreduce(func: Callable[[Dict[str, Any], Iterable[SambaTensor]], SambaTensor], inputs: List[SambaTensor], reduce_func: str | None = None, reduce_dim: int | List[int] | None = None, stoc_accum: bool | None = False) -> SambaTensor

Applies a series of Zip/Map operations, specified by the given function, to the inputs, optionally followed by a Reduce operation on the result.

Parameters:
  • func

    a Python function or lambda describing the intended operation, whose parameters are the inputs to the operation.

    The first parameter (conventionally named attrs) is a placeholder that sn_zipmapreduce() uses to pass special attributes to the operations inside the func body. The remaining parameters are assigned from inputs, in the order in which the inputs are listed.

    For example, if func is lambda attrs, x, y, z: (x + y) * z and inputs is ipt, the output is (ipt[0] + ipt[1]) * ipt[2].

    func can take an arbitrary number of SambaTensor arguments and must return a single value. See the table of supported operations below.

  • inputs – list of SambaTensor(s) to use as inputs to the operation.

  • reduce_func (optional) – reduction function to apply after applying func to the inputs. (Default = None). See the fn attribute of sn_reduce() for supported values.

  • reduce_dim (optional) – dimension(s) to reduce along (Default = None).

  • stoc_accum (optional) – enable/disable (True/False) Stochastic Accumulation (Default = False).
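The calling convention described above (attrs first, then one argument per tensor in inputs, broadcast together, followed by an optional reduction) can be mimicked in NumPy. This is a hypothetical model, not SambaFlow's implementation: zipmapreduce_reference passes an empty dict as attrs, models only reduce_func="SUM", ignores stoc_accum, and keeps reduced dimensions with size 1 to match the result shapes shown in the examples on this page.

```python
import numpy as np


def zipmapreduce_reference(func, inputs, reduce_func=None, reduce_dim=None):
    """Hypothetical NumPy model of sn_zipmapreduce's calling convention."""
    # Broadcast every input to a common shape, as PyTorch broadcasting would.
    arrays = np.broadcast_arrays(*[np.asarray(t, dtype=float) for t in inputs])
    attrs = {}  # stand-in for the special attributes sn_zipmapreduce() passes in
    result = func(attrs, *arrays)
    if reduce_func is None:
        return result
    if reduce_func != "SUM":
        raise NotImplementedError(f"only 'SUM' is modelled here, got {reduce_func!r}")
    if reduce_dim is None:
        raise ValueError("this sketch requires reduce_dim when reduce_func is set")
    dims = (reduce_dim,) if isinstance(reduce_dim, int) else tuple(reduce_dim)
    return np.sum(result, axis=dims, keepdims=True)


# The (x + y) * z example from the func description.
ipt = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([2.0, 2.0])]
out = zipmapreduce_reference(lambda attrs, x, y, z: (x + y) * z, ipt)
print(out.tolist())  # [8.0, 12.0]
```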

Supported Operations inside func

Operation                   Syntax
--------------------------  --------------------------------------------
Unary Operations
  Absolute                  torch.abs(x) (see torch.abs())
  Exponentiation            torch.exp(x) (see torch.exp())
  Logarithm                 torch.log(x) (see torch.log())
  Reciprocal                torch.reciprocal(x) (see torch.reciprocal())
  Square Root               torch.sqrt(x) (see torch.sqrt())
  Floor                     torch.floor(x) (see torch.floor())
  Ceil                      torch.ceil(x) (see torch.ceil())
  Negation                  torch.neg(x) (see torch.neg())
  Sigmoid                   torch.sigmoid(x) (see torch.sigmoid())
  Reciprocal Square Root    torch.rsqrt(x) (see torch.rsqrt())
  Tanh                      torch.tanh(x) (see torch.tanh())
Binary Operations
  Addition                  x + y (see torch.add())
  Subtraction               x - y (see torch.sub())
  Multiplication            x * y (see torch.mul())
  Division                  x / y (see torch.div())
  Equal                     x == y (see torch.eq())
  Greater Than              x > y (see torch.gt())
  Greater Than or Equal     x >= y (see torch.ge())
  Less Than                 x < y (see torch.lt())
  Less Than or Equal        x <= y (see torch.le())
  Not Equal                 x != y (see torch.ne())
  Elementwise Maximum       torch.maximum(x, y) (see torch.maximum())
  Elementwise Minimum       torch.minimum(x, y) (see torch.minimum())
Ternary Operations
  Fused Multiply Add        sn_fma(x, y, z) (see sn_fma())
  Conditional Select        sn_select(x, y, z) (see sn_select())
Miscellaneous
  Iter Index                sn_iteridx(dim) (see sn_iteridx())
  Immediate                 sn_imm(val) (see sn_imm())

Note

The sizes of the SambaTensor(s) provided in inputs must be compatible under PyTorch broadcasting semantics.

sn_zipmapreduce() expects all input tensors to have datatypes with the same bit-width.
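The broadcasting requirement above can be checked up front from the shapes alone. The following pure-Python helper, broadcast_shape (a hypothetical name), applies the standard PyTorch/NumPy rule: align the shapes from the trailing dimension and require each set of sizes to be equal or 1. It does not check the bit-width requirement.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the broadcast result of `shapes` (PyTorch/NumPy rules) or raise ValueError."""
    result = []
    # Walk dimensions from the right; a missing leading dimension counts as size 1.
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {n for n in sizes if n != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcast-compatible")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


# Input shapes from the broadcasting example below.
print(broadcast_shape((3, 2, 3), (1, 2, 1), (1, 2, 3)))  # (3, 2, 3)
```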

Examples

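>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce
>>> # Compute the elementwise squared difference between two tensors of the same shape
>>> # (summing the result would give the squared L2 distance)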
>>> samba.set_seed(1)
>>> x = samba.ones((5,5))
>>> y = samba.randn((5,5))
>>> dist_func = lambda attrs,x,y: torch.pow(x - y, 2)
>>> out = sn_zipmapreduce(dist_func, [x, y])
>>> out.data
tensor([[6.3786, 3.0633, 2.7357, 6.8094, 1.2104],
        [2.5895, 3.9195, 6.8074, 2.9314, 0.0293],
        [0.5889, 1.1998, 0.7407, 6.6487, 1.7440],
        [1.6883, 0.7732, 0.4413, 0.5251, 0.5131],
        [1.1154, 0.0080, 5.7235, 2.8530, 1.2344]])
>>> # Example with broadcasting and a reduction applied
>>> from sambaflow.samba.functional import sn_fma, sn_select
>>> # Apply do_something to the inputs, then sum the result over dimensions 0 and 2
>>> samba.set_seed(1)
>>> ipt_0 = samba.randn((3, 2, 3))
>>> ipt_1 = samba.randn((1, 2, 1))
>>> # Dimensions 0 and 2 of ipt_1 will be broadcast
>>> ipt_2 = samba.randn((1, 2, 3))
>>> # Dimension 0 of ipt_2 will be broadcast
>>> def do_something(attrs, x, y, z):
...     max_x_y = torch.maximum(x, y)
...     max_x_y_z = torch.maximum(max_x_y, z)
...     fma_out = sn_fma(x, y, z)
...     x_gt_y = x > y
...     return sn_select(x_gt_y, fma_out, max_x_y_z)
>>> out = sn_zipmapreduce(do_something, [ipt_0,ipt_1,ipt_2], reduce_func="SUM", reduce_dim=[0,2])
>>> out.data
tensor([[[ 6.8619],
         [11.1231]]])
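To see what do_something computes element by element, the second example can be mirrored in NumPy. This sketch assumes that sn_fma(x, y, z) computes x * y + z and that sn_select(c, a, b) picks a where c is true and b elsewhere, as the operation names in the table suggest; fma and select below are hypothetical stand-ins. The random inputs come from NumPy, so the values will not match samba.randn; only the output shape (1, 2, 1) is comparable.

```python
import numpy as np


def fma(x, y, z):
    # Assumed semantics of sn_fma: fused multiply-add, x * y + z.
    return x * y + z


def select(cond, a, b):
    # Assumed semantics of sn_select: elementwise a where cond is true, else b.
    return np.where(cond, a, b)


def do_something(attrs, x, y, z):
    max_x_y_z = np.maximum(np.maximum(x, y), z)
    return select(x > y, fma(x, y, z), max_x_y_z)


rng = np.random.default_rng(1)
ipt_0 = rng.standard_normal((3, 2, 3))
ipt_1 = rng.standard_normal((1, 2, 1))  # dims 0 and 2 broadcast
ipt_2 = rng.standard_normal((1, 2, 3))  # dim 0 broadcasts
# Sum over dims 0 and 2, keeping them with size 1 as in the example output.
out = np.sum(do_something({}, ipt_0, ipt_1, ipt_2), axis=(0, 2), keepdims=True)
print(out.shape)  # (1, 2, 1)
```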