samba.parallel_patterns
Parallel Patterns
- sn_fma(a: SambaTensor, b: SambaTensor, c: SambaTensor) -> SambaTensor
New in version 1.18.
Performs a fused multiply-add operation, equivalent to
\[a \times b + c\]
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_fma
>>> compute_fma = lambda attrs, x, y, z: sn_fma(x, y, z)
>>> samba.set_seed(1)
>>> x = samba.randn((2, 5), dtype=torch.bfloat16)
>>> y = samba.randn((2, 5), dtype=torch.bfloat16)
>>> z = samba.randn((2, 5), dtype=torch.bfloat16)
>>> out = sn_zipmapreduce(compute_fma, [x, y, z])
>>> out.data
tensor([[ 3.0000, -1.8438,  1.5859, -1.8750,  0.1357],
        [-0.5234,  2.0312,  2.3438,  2.0312, -0.9180]], dtype=torch.bfloat16)
Note
This operation is only supported when used inside a Python function/lambda that is passed to sn_zipmapreduce().
- sn_gather(input_tensor: SambaTensor, start_indices: SambaTensor, gather_dims: List[int], gather_lengths: List[int]) -> SambaTensor
Performs a general gather operation. Supports multiple gather dimensions, different start indices, and different gather lengths.
New in version 1.18.
- Parameters:
input_tensor (SambaTensor) – input tensor to gather from
start_indices (SambaTensor) – tensor of indices along each dimension to gather from, of shape MxNx…xD where D is the number of specified dimensions
gather_dims (List[int]) – list of D dimensions to gather along. Dimensions not listed are gathered in full.
gather_lengths (List[int]) – list of D lengths, one for each dimension in gather_dims, giving how many values to gather along that dimension, starting from the corresponding start index in start_indices. Dimensions not listed are gathered in full.
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_gather
>>> # In this example, the input tensor x has a shape of (5, 5).
>>> # sn_gather gathers over dimensions 0 and 1 with lengths 2 and 3, respectively,
>>> # starting at each row of start indices in s_in.
>>> # The result y holds three (2x3) blocks, giving a shape of (3, 2, 3).
>>> x = samba.SambaTensor(torch.tensor([[0, 1, 2, 3, 4],
...                                     [5, 6, 7, 8, 9],
...                                     [10, 11, 12, 13, 14],
...                                     [15, 16, 17, 18, 19],
...                                     [20, 21, 22, 23, 24]]))
>>> s_in = samba.SambaTensor(torch.tensor([[0, 0], [1, 1], [3, 2]]))
>>> g_dims = [0, 1]
>>> g_len = [2, 3]
>>> y = sn_gather(x, s_in, g_dims, g_len)
>>> y.data
tensor([[[ 0,  1,  2],
         [ 5,  6,  7]],
        [[ 6,  7,  8],
         [11, 12, 13]],
        [[17, 18, 19],
         [22, 23, 24]]])
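To make the indexing concrete, here is a minimal pure-Python reference sketch of the gather semantics described above, working on nested lists instead of SambaTensor. It assumes start_indices is a flat list of D-element rows (the (M, D) case), that dimensions missing from gather_dims start at 0 and are taken in full, and it does not model out-of-range start indices. The helper names (reference_gather, shape_of, get, build) are hypothetical and not part of sambaflow.

```python
from typing import Any, Callable, List, Sequence, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


def get(x: Any, idx: Sequence[int]) -> Any:
    """Read the element of nested list x at multi-index idx."""
    for i in idx:
        x = x[i]
    return x


def build(shape: Sequence[int], fill: Callable[[Tuple[int, ...]], Any],
          prefix: Tuple[int, ...] = ()) -> Any:
    """Build a nested list of the given shape, calling fill(index) for each element."""
    if not shape:
        return fill(prefix)
    return [build(shape[1:], fill, prefix + (i,)) for i in range(shape[0])]


def reference_gather(x: Any, start_indices: List[List[int]],
                     gather_dims: List[int], gather_lengths: List[int]) -> List[Any]:
    """Hypothetical reference for sn_gather on nested lists (not the sambaflow API)."""
    shape = shape_of(x)
    blocks = []
    for starts in start_indices:
        # Dimensions missing from gather_dims start at 0 and are gathered in full.
        begin = [0] * len(shape)
        length = list(shape)
        for d, s, n in zip(gather_dims, starts, gather_lengths):
            begin[d] = s
            length[d] = n
        # Copy the block that starts at `begin` and spans `length` elements per dim.
        blocks.append(build(length, lambda idx: get(x, [b + i for b, i in zip(begin, idx)])))
    return blocks


# Same input and start indices as the doctest above.
x = [[r * 5 + c for c in range(5)] for r in range(5)]
y = reference_gather(x, [[0, 0], [1, 1], [3, 2]], gather_dims=[0, 1], gather_lengths=[2, 3])
print(y)  # [[[0, 1, 2], [5, 6, 7]], [[6, 7, 8], [11, 12, 13]], [[17, 18, 19], [22, 23, 24]]]
```

Passing gather_dims=[0] and gather_lengths=[1] with a start row of [2] returns the whole third row of x, because dimension 1 is gathered in full.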
- sn_imm(input: float | int, dtype: torch.dtype | SNType) -> SambaTensor
New in version 1.18.
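Creates an immediate tensor, i.e. a constant scalar value. It is one of the operations supported inside a sn_zipmapreduce() func.
- Parameters:
input (float or int) – the scalar value of the immediate
dtype (torch.dtype or SNType) – datatype of the resulting immediate tensor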
- sn_iteridx(attrs: dict, dim: int, dtype: SNType | None) -> SambaTensor
New in version 1.18.
Creates an iterator within the body of a sn_zipmapreduce() func. dim specifies which dimension of the broadcast shape is iterated over. For example, if sn_zipmapreduce() has two inputs of shapes [A,1] and [1,B] and reduce_dim is set to 1, the shape after broadcasting is [A,B] and the shape after the reduction is [A,1]. Two iterators are then available within the func body:
sn_iteridx(attrs, dim=0) takes the values 0, 1, ..., A-1 (step 1) along dimension 0
sn_iteridx(attrs, dim=1) takes the values 0, 1, ..., B-1 (step 1) along dimension 1
sn_iteridx() can be treated as an additional input to the sn_zipmapreduce() func and used anywhere in the func body.
- Parameters:
attrs (dict) – the attrs dictionary passed to the enclosing sn_zipmapreduce() func
dim (int) – the dimension of the broadcast shape to iterate over
dtype (Optional[SNType]) – datatype of the iterator. dtype must be a signed or unsigned integer datatype, and its bit-width must match the bit-width of the inputs to the sn_zipmapreduce() that the operation belongs to. If unspecified, it is inferred from the inputs to the func passed to that sn_zipmapreduce(). The signedness is inferred from the inputs as well: for signed inputs the dtype is int, and for unsigned inputs it is uint. Thus, for int16 inputs the iterator can only count up to \(2^{15} - 1\), so the iterated dimension must have at most \(2^{15}\) elements. The following types are supported:
SNType.UINT16
SNType.INT16
SNType.INT32
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_iteridx, sn_select, sn_imm
>>> from sambaflow.samba.utils import SNType
>>> # x only supplies the (3, 3) shape; its values are not used by upper_tri
>>> x = samba.SambaTensor(shape=(3, 3), dtype=torch.int32)
>>> # Example: creating an upper triangular matrix
>>> def upper_tri(attrs, x):
...     # if specified, dtype must match the bit-width of the other dtypes in the function
...     dim_r = sn_iteridx(attrs=attrs, dim=0, dtype=SNType.INT32)
...     dim_c = sn_iteridx(attrs=attrs, dim=1, dtype=SNType.INT32)
...     mask = dim_r <= dim_c
...     one_imm = sn_imm(1, dtype=torch.int32)
...     zero_imm = sn_imm(0, dtype=torch.int32)
...     return sn_select(mask, one_imm, zero_imm)
>>> upper_matrix = sn_zipmapreduce(upper_tri, [x])
>>> upper_matrix.data
tensor([[1, 1, 1],
        [0, 1, 1],
        [0, 0, 1]], dtype=torch.int32)
Note
This operation is only supported when used inside a Python function/lambda that is passed to sn_zipmapreduce().
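The iterator semantics can be emulated in plain Python to make the expected values concrete. The sketch below uses nested lists and hypothetical helper names (iteridx, upper_triangular, max_signed_extent); it assumes, following the [A,B] example above, that sn_iteridx(attrs, dim) evaluates at each position of the broadcast shape to that position's index along dim. It reproduces the upper triangular example and the int16 range limit.

```python
from typing import Any, Callable, List, Sequence, Tuple


def build(shape: Sequence[int], fill: Callable[[Tuple[int, ...]], Any],
          prefix: Tuple[int, ...] = ()) -> Any:
    """Build a nested list of the given shape, calling fill(index) for each element."""
    if not shape:
        return fill(prefix)
    return [build(shape[1:], fill, prefix + (i,)) for i in range(shape[0])]


def iteridx(broadcast_shape: Sequence[int], dim: int) -> Any:
    """Hypothetical stand-in for sn_iteridx: each element's index along `dim`."""
    return build(broadcast_shape, lambda idx: idx[dim])


def upper_triangular(n: int) -> List[List[int]]:
    """Mirror of the upper_tri example: select 1 where row <= col, else 0."""
    rows = iteridx((n, n), 0)
    cols = iteridx((n, n), 1)
    return [[1 if r <= c else 0 for r, c in zip(r_row, c_row)]
            for r_row, c_row in zip(rows, cols)]


def max_signed_extent(bits: int) -> int:
    """Largest dimension a signed `bits`-wide iterator can cover (0 .. 2**(bits-1) - 1)."""
    return 2 ** (bits - 1)


# Inputs [A,1] and [1,B] broadcast to [A,B]; here A = 2 and B = 3.
print(iteridx((2, 3), 0))     # [[0, 0, 0], [1, 1, 1]]
print(iteridx((2, 3), 1))     # [[0, 1, 2], [0, 1, 2]]
print(upper_triangular(3))    # [[1, 1, 1], [0, 1, 1], [0, 0, 1]]
print(max_signed_extent(16))  # 32768
```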
- sn_reduce(input: SambaTensor, dim: Tuple[int, ...], fn: str) -> SambaTensor | Tuple[SambaTensor, ...]
New in version 1.18.
Reduces a tensor along one or more dimensions (axes) by applying the reduction function given by fn. Several reduction functions are supported.
- Parameters:
input – the tensor to reduce.
dim – a tuple specifying the dimensions over which to perform the reduction.
fn – a string that specifies the reduction function to apply. Supported functions include:
"SUM": see torch.sum()
"MEAN": see torch.mean()
"MIN": see torch.min()
"MAX": see torch.max()
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_reduce
>>> # In this example, the tensor x has a shape of (2, 2, 2).
>>> # sn_reduce sums over dimensions 1 and 2, the last two dimensions.
>>> # The result y holds the sum of each 2x2 block of x. The reduced dimensions
>>> # are kept with size 1, giving a shape of (2, 1, 1).
>>> x = samba.SambaTensor(torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
...                                     [[5.0, 6.0], [7.0, 8.0]]]))
>>> y = sn_reduce(x, (1, 2), "SUM")
>>> y.data
tensor([[[10.]],
        [[26.]]])
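The reduction semantics, including the way reduced dimensions are kept with size 1 (as the (2, 1, 1) result above shows), can be illustrated with a small pure-Python reference on nested lists. This is a sketch with a hypothetical reference_reduce helper: "MEAN" is assumed to be the arithmetic mean, and the sketch always returns a single nested list, so it does not model any tuple return for "MIN"/"MAX".

```python
import itertools
from typing import Any, Callable, Dict, List, Sequence, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


def get(x: Any, idx: Sequence[int]) -> Any:
    """Read the element of nested list x at multi-index idx."""
    for i in idx:
        x = x[i]
    return x


def build(shape: Sequence[int], fill: Callable[[Tuple[int, ...]], Any],
          prefix: Tuple[int, ...] = ()) -> Any:
    """Build a nested list of the given shape, calling fill(index) for each element."""
    if not shape:
        return fill(prefix)
    return [build(shape[1:], fill, prefix + (i,)) for i in range(shape[0])]


# Reduction functions keyed by the fn strings listed above.
REDUCERS: Dict[str, Callable[[List[float]], float]] = {
    "SUM": sum,
    "MEAN": lambda vals: sum(vals) / len(vals),
    "MIN": min,
    "MAX": max,
}


def reference_reduce(x: Any, dim: Tuple[int, ...], fn: str) -> Any:
    """Hypothetical reference for sn_reduce on nested lists (not the sambaflow API)."""
    shape = shape_of(x)
    # Reduced dimensions are kept with size 1.
    out_shape = tuple(1 if d in dim else n for d, n in enumerate(shape))

    def fill(out_idx: Tuple[int, ...]) -> float:
        # Fix the non-reduced coordinates and sweep every reduced one.
        ranges = [range(n) if d in dim else (out_idx[d],) for d, n in enumerate(shape)]
        values = [get(x, idx) for idx in itertools.product(*ranges)]
        return REDUCERS[fn](values)

    return build(out_shape, fill)


x = [[[1.0, 2.0], [3.0, 4.0]],
     [[5.0, 6.0], [7.0, 8.0]]]
print(reference_reduce(x, (1, 2), "SUM"))   # [[[10.0]], [[26.0]]]
print(reference_reduce(x, (1, 2), "MEAN"))  # [[[2.5]], [[6.5]]]
print(reference_reduce(x, (0,), "MAX"))     # [[[5.0, 6.0], [7.0, 8.0]]]
```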
- sn_scatter(operand: SambaTensor, update: SambaTensor, start_indices: SambaTensor, scatter_dims: List[int], rmw_op: str = 'kUpdate', unique_indices: bool = False) -> SambaTensor
New in version 1.18.
Performs a scatter operation into the operand tensor, given an update tensor and the indices and dimensions to scatter along. Depending on rmw_op, each update value replaces, is added to, is multiplied with, or is combined by minimum or maximum with the corresponding operand value.
update is an (L, A, ..., D) tensor that is scattered into operand. L is the batch dimension and gives the number of scatters that take place. (A, ..., D) matches the dimensions of operand, except that each scatter dimension may be less than or equal to the corresponding dimension of operand. start_indices has shape (L, len(scatter_dims)) and gives, for each of the L slices of update, the start index along each scatter dimension.
- Parameters:
operand (SambaTensor) – The tensor to scatter into
update (SambaTensor) – The tensor of values to scatter
start_indices (SambaTensor) – The start indices along each dimension to scatter one part of update
scatter_dims (List[int]) – The dimensions of operand to scatter along
rmw_op (str) –
The operation to perform between operand and update. Defaults to “kUpdate”. Supported operations include:
"kUpdate"
"kAdd"
"kMin"
"kMax"
"kMul"
unique_indices (bool) – whether start_indices is guaranteed to contain no repeated indices. Defaults to False, meaning repeated indices may occur or it is unknown whether they do. Set it to True only when start_indices is known to contain no repeated indices.
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_scatter
>>> # In this example, the operand tensor x has a shape of (2, 4, 5).
>>> # sn_scatter scatters the update tensor u into x along dimensions 0 and 1,
>>> # at the start indices in s_in, combining values according to rmw_op.
>>> # The result y has the same shape as the operand tensor.
>>> x = samba.SambaTensor(torch.tensor([
...     [[0, 1, 2, 3, 4],
...      [5, 6, 7, 8, 9],
...      [10, 11, 12, 13, 14],
...      [15, 16, 17, 18, 19]],
...     [[0, 1, 2, 3, 4],
...      [5, 6, 7, 8, 9],
...      [10, 11, 12, 13, 14],
...      [15, 16, 17, 18, 19]]
... ]))
>>> # u has shape (3, 1, 1, 5): three (1, 1, 5) slices, one per row of s_in
>>> u = samba.SambaTensor(torch.tensor([
...     [[[-1, -2, -3, -4, -5]]],
...     [[[-6, -7, -8, -9, -10]]],
...     [[[-11, -12, -13, -14, -15]]]
... ]))
>>> s_in = samba.SambaTensor(torch.tensor([[1, 1], [0, 2], [1, 3]]))
>>> s_dims = [0, 1]
>>> rmw_op = "kUpdate"
>>> y = sn_scatter(x, u, s_in, s_dims, rmw_op)
>>> y.data
tensor([[[  0,   1,   2,   3,   4],
         [  5,   6,   7,   8,   9],
         [ -6,  -7,  -8,  -9, -10],
         [ 15,  16,  17,  18,  19]],
        [[  0,   1,   2,   3,   4],
         [ -1,  -2,  -3,  -4,  -5],
         [ 10,  11,  12,  13,  14],
         [-11, -12, -13, -14, -15]]])
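The following pure-Python sketch illustrates the scatter semantics on nested lists: each slice update[l] is placed at the offsets given by start_indices[l] along scatter_dims (all other dimensions start at 0) and combined with the existing operand values via rmw_op. The helper name reference_scatter is hypothetical; the rmw_op keys mirror the supported strings above. When indices repeat, the sketch simply applies the slices in batch order, which is an assumption and not documented sn_scatter behavior.

```python
import copy
import itertools
from typing import Any, Callable, Dict, Tuple

# Combine functions keyed by the rmw_op strings above: (old, new) -> stored value.
RMW_OPS: Dict[str, Callable[[Any, Any], Any]] = {
    "kUpdate": lambda old, new: new,
    "kAdd": lambda old, new: old + new,
    "kMin": min,
    "kMax": max,
    "kMul": lambda old, new: old * new,
}


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


def reference_scatter(operand, update, start_indices, scatter_dims, rmw_op="kUpdate"):
    """Hypothetical reference for sn_scatter on nested lists (not the sambaflow API)."""
    out = copy.deepcopy(operand)  # leave the caller's operand untouched
    combine = RMW_OPS[rmw_op]
    rank = len(shape_of(operand))
    for l, starts in enumerate(start_indices):
        block = update[l]  # one (A, ..., D) slice per row of start_indices
        offset = [0] * rank
        for d, s in zip(scatter_dims, starts):
            offset[d] = s
        for idx in itertools.product(*(range(n) for n in shape_of(block))):
            # Walk to the innermost list that holds the target element.
            target = out
            for d in range(rank - 1):
                target = target[idx[d] + offset[d]]
            value = block
            for i in idx:
                value = value[i]
            last = idx[-1] + offset[-1]
            target[last] = combine(target[last], value)
    return out


# Same operand, update and start indices as the doctest above.
base = [[r * 5 + c for c in range(5)] for r in range(4)]
x = [copy.deepcopy(base), copy.deepcopy(base)]
u = [[[[-1, -2, -3, -4, -5]]], [[[-6, -7, -8, -9, -10]]], [[[-11, -12, -13, -14, -15]]]]
y = reference_scatter(x, u, [[1, 1], [0, 2], [1, 3]], [0, 1], "kUpdate")
print(y[0][2])  # [-6, -7, -8, -9, -10]
print(y[1][1])  # [-1, -2, -3, -4, -5]
# With "kAdd", two slices aimed at the same row accumulate.
print(reference_scatter([[0, 0], [0, 0]], [[[1, 1]], [[2, 2]]], [[0], [0]], [0], "kAdd"))  # [[3, 3], [0, 0]]
```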
- sn_select(cond: SambaTensor, true_val: SambaTensor, false_val: SambaTensor) -> SambaTensor
New in version 1.18.
Performs a select operation on tensors, similar to the torch.where() function. It selects elements from true_val or false_val based on the condition specified in cond.
- Parameters:
cond – a mask tensor where each element is a condition. If the condition is true, the corresponding element from true_val is selected; otherwise, the element from false_val is chosen.
true_val – the tensor from which elements are selected when the corresponding condition in cond is true.
false_val – the tensor from which elements are selected when the corresponding condition in cond is false.
Example
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce, sn_select
>>> # Example showing the use of sn_select with tensors.
>>> mask = samba.SambaTensor(torch.tensor([1, 0, 1], dtype=torch.int32))
>>> x = samba.SambaTensor(torch.tensor([1, 2, 3], dtype=torch.float))
>>> y = samba.SambaTensor(torch.tensor([4, 5, 6], dtype=torch.float))
>>> f = lambda attrs, mask, x, y: sn_select(mask, x, y)
>>> result = sn_zipmapreduce(f, [mask, x, y])
>>> result.data
tensor([1., 5., 3.])
>>> # result contains elements from x or y based on the mask.
>>> # Here 1 and 3 are selected from x, and 5 from y.
Note
This operation is only supported when used inside a Python function/lambda that is passed to sn_zipmapreduce().
See also
For more details, see torch.where().
- sn_zipmapreduce(func: Callable[[Dict[str, Any], Iterable[SambaTensor]], SambaTensor], inputs: List[SambaTensor], reduce_func: str | None = None, reduce_dim: int | List[int] | None = None, stoc_accum: bool | None = False) -> SambaTensor
New in version 1.18.
Performs a series of Zip/Map operations on the inputs, as specified by the given function, optionally followed by a Reduce operation on the result.
- Parameters:
func – A Python function or lambda describing the intended operation, whose parameters are the inputs to the function. The first parameter (conventionally named attrs) is a placeholder that sn_zipmapreduce() uses to pass special attributes to the operations inside the func body. The remaining parameters are assigned from inputs, in the order in which they are listed. For example, if func is set to func = lambda attrs, x, y, z: (x + y) * z and the inputs parameter is ipt, the output will be output = (ipt[0] + ipt[1]) * ipt[2]. func can take an arbitrary number of SambaTensor arguments and has a single return value. See the table below for the supported operations.
inputs – List of SambaTensor(s) to use as inputs to the operation.
reduce_func (optional) – Reduction function to apply after applying func to the inputs (default: None). See the fn parameter of sn_reduce() for supported values.
reduce_dim (optional) – Dimension(s) to reduce along (default: None).
stoc_accum (optional) – Enable (True) or disable (False) stochastic accumulation (default: False).
Operation – Syntax
Unary Operations
Absolute – torch.abs(x) (see torch.abs())
Exponentiation – torch.exp(x) (see torch.exp())
Logarithm – torch.log(x) (see torch.log())
Reciprocal – torch.reciprocal(x) (see torch.reciprocal())
Square Root – torch.sqrt(x) (see torch.sqrt())
Floor – torch.floor(x) (see torch.floor())
Ceil – torch.ceil(x) (see torch.ceil())
Negation – torch.neg(x) (see torch.neg())
Sigmoid – torch.sigmoid(x) (see torch.sigmoid())
Reciprocal Square Root – torch.rsqrt(x) (see torch.rsqrt())
Tanh – torch.tanh(x) (see torch.tanh())
Binary Operations
Addition – x + y (see torch.add())
Subtraction – x - y (see torch.sub())
Multiplication – x * y (see torch.mul())
Division – x / y (see torch.div())
Equal – x == y (see torch.eq())
Greater Than – x > y (see torch.gt())
Greater Than or Equal – x >= y (see torch.ge())
Less Than – x < y (see torch.lt())
Less Than or Equal – x <= y (see torch.le())
Not Equal – x != y (see torch.ne())
Elementwise Maximum – torch.maximum(x, y) (see torch.maximum())
Elementwise Minimum – torch.minimum(x, y) (see torch.minimum())
Ternary Operations
Fused Multiply Add – sn_fma(x, y, z) (see sn_fma())
Conditional Select – sn_select(x, y, z) (see sn_select())
Miscellaneous
Iter Index – sn_iteridx(attrs, dim) (see sn_iteridx())
Immediate – sn_imm(val, dtype) (see sn_imm())
Note
The sizes of the SambaTensor(s) in inputs must be compatible with PyTorch broadcasting semantics. sn_zipmapreduce() also expects all input tensors to have datatypes with the same bit-width.
Examples
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce
>>> # Compute the elementwise squared difference between two tensors of the same size
>>> # (the per-element terms of their squared L2 distance)
>>> samba.set_seed(1)
>>> x = samba.ones((5, 5))
>>> y = samba.randn((5, 5))
>>> dist_func = lambda attrs, x, y: torch.pow(x - y, 2)
>>> out = sn_zipmapreduce(dist_func, [x, y])
>>> out.data
tensor([[6.3786, 3.0633, 2.7357, 6.8094, 1.2104],
        [2.5895, 3.9195, 6.8074, 2.9314, 0.0293],
        [0.5889, 1.1998, 0.7407, 6.6487, 1.7440],
        [1.6883, 0.7732, 0.4413, 0.5251, 0.5131],
        [1.1154, 0.0080, 5.7235, 2.8530, 1.2344]])
>>> # Example with broadcasting and a reduction applied
>>> from sambaflow.samba.functional import sn_fma, sn_select
>>> # Apply do_something to the inputs, then sum dimensions 0 and 2 of the result
>>> samba.set_seed(1)
>>> ipt_0 = samba.randn((3, 2, 3))
>>> ipt_1 = samba.randn((1, 2, 1))
>>> # Dimensions 0 and 2 of ipt_1 will be broadcast
>>> ipt_2 = samba.randn((1, 2, 3))
>>> # Dimension 0 of ipt_2 will be broadcast
>>> def do_something(attrs, x, y, z):
...     max_x_y = torch.maximum(x, y)
...     max_x_y_z = torch.maximum(max_x_y, z)
...     fma_out = sn_fma(x, y, z)
...     x_gt_y = x > y
...     return sn_select(x_gt_y, fma_out, max_x_y_z)
>>> out = sn_zipmapreduce(do_something, [ipt_0, ipt_1, ipt_2], reduce_func="SUM", reduce_dim=[0, 2])
>>> out.data
tensor([[[ 6.8619],
         [11.1231]]])
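A pure-Python model of the zip/map/reduce flow can help when reasoning about broadcasting and reduce_dim. The sketch below broadcasts nested-list inputs following the PyTorch/NumPy rules the Note refers to, applies func elementwise, and, if reduce_func is given, reduces over reduce_dim while keeping the reduced dimensions with size 1 (as in the (1, 2, 1) result above). It is a hypothetical reference (reference_zipmapreduce is not a sambaflow function): it passes an empty dict in place of attrs, so it cannot model sn_iteridx(), and it ignores stoc_accum and bit-width rules.

```python
import itertools
from typing import Any, Callable, Dict, List, Sequence, Tuple

# Reduction functions keyed by the sn_reduce() fn strings.
REDUCERS: Dict[str, Callable[[List[Any]], Any]] = {
    "SUM": sum,
    "MEAN": lambda vals: sum(vals) / len(vals),
    "MIN": min,
    "MAX": max,
}


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list; a bare scalar has shape ()."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


def broadcast_shape(shapes: Sequence[Tuple[int, ...]]) -> Tuple[int, ...]:
    """Align shapes on the right; a size-1 dim stretches, other sizes must agree."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + s for s in shapes]
    result = []
    for sizes in zip(*padded):
        non_one = {n for n in sizes if n != 1}
        if len(non_one) > 1:
            raise ValueError(f"sizes {sizes} are not broadcast-compatible")
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)


def read(x: Any, idx: Tuple[int, ...], shape: Tuple[int, ...]) -> Any:
    """Read x at a broadcast index: leading missing dims are skipped, size-1 dims use 0."""
    for i, n in zip(idx[len(idx) - len(shape):], shape):
        x = x[i if n != 1 else 0]
    return x


def build(shape: Sequence[int], fill: Callable[[Tuple[int, ...]], Any],
          prefix: Tuple[int, ...] = ()) -> Any:
    """Build a nested list of the given shape, calling fill(index) for each element."""
    if not shape:
        return fill(prefix)
    return [build(shape[1:], fill, prefix + (i,)) for i in range(shape[0])]


def reference_zipmapreduce(func, inputs, reduce_func=None, reduce_dim=None):
    """Hypothetical reference for sn_zipmapreduce on nested lists."""
    shapes = [shape_of(t) for t in inputs]
    bshape = broadcast_shape(shapes)
    attrs: Dict[str, Any] = {}  # placeholder for the special attributes sn_zipmapreduce() passes

    def mapped(idx: Tuple[int, ...]) -> Any:
        # Zip the broadcast inputs at idx, then Map them through func.
        return func(attrs, *(read(t, idx, s) for t, s in zip(inputs, shapes)))

    if reduce_func is None:
        return build(bshape, mapped)
    dims = {reduce_dim} if isinstance(reduce_dim, int) else set(reduce_dim)
    out_shape = tuple(1 if d in dims else n for d, n in enumerate(bshape))

    def reduced(out_idx: Tuple[int, ...]) -> Any:
        ranges = [range(n) if d in dims else (out_idx[d],) for d, n in enumerate(bshape)]
        return REDUCERS[reduce_func]([mapped(idx) for idx in itertools.product(*ranges)])

    return build(out_shape, reduced)


x = [[1.0, 2.0], [3.0, 4.0]]  # shape (2, 2)
y = [[10.0], [20.0]]          # shape (2, 1): dim 1 is broadcast
z = [2.0, 3.0]                # shape (2,): a leading dim is added and broadcast
f = lambda attrs, x, y, z: (x + y) * z
print(reference_zipmapreduce(f, [x, y, z]))                 # [[22.0, 36.0], [46.0, 72.0]]
print(reference_zipmapreduce(f, [x, y, z], "SUM", 1))       # [[58.0], [118.0]]
print(reference_zipmapreduce(f, [x, y, z], "SUM", [0, 1]))  # [[176.0]]
select = lambda attrs, m, a, b: a if m else b  # plays the role of sn_select
print(reference_zipmapreduce(select, [[1, 0, 1], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))  # [1.0, 5.0, 3.0]
```

The last call mirrors the sn_select() example: the mask picks 1 and 3 from the second input and 5 from the third.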