samba.parallel_patterns
Parallel Patterns
- sn_gather(input_tensor: Tensor | SambaTensor, start_indices: Tensor | SambaTensor, gather_dims: List[int], gather_lengths: List[int]) → Tensor | SambaTensor
Performs a general gather operation. Supports multiple gather dimensions, different start indices, and different gather lengths.
- Parameters:
input_tensor – input tensor to gather from, also called the Look Up Table (LUT). Can be any shape.
start_indices – tensor of indices giving the start index to gather from in each gather dimension. Has shape MxNx…xD, where D is the number of specified gather dimensions and MxNx… is the number of sets of start indices, i.e. the number of gathers performed. The sets of start indices can be arranged in any number of leading dimensions.
gather_dims – List of D dimensions to gather along. Dimensions not listed are gathered in full.
gather_lengths – List of D lengths, one for each dimension in gather_dims, giving how many values to gather along that dimension, starting from the corresponding start index in start_indices.
Example
>>> import torch
>>> from sambaflow.samba.functional import sn_gather
>>> # In this example, the input_tensor (LUT), x, has a shape of (5, 5).
>>> # sn_gather gathers along dimensions 0 and 1 with lengths 2 and 3 respectively,
>>> # starting at each row of start indices.
>>> # The result y contains three 2x3 blocks, giving a shape of (3, 2, 3).
>>> x = torch.tensor([[0, 1, 2, 3, 4],
...                   [5, 6, 7, 8, 9],
...                   [10, 11, 12, 13, 14],
...                   [15, 16, 17, 18, 19],
...                   [20, 21, 22, 23, 24]])
>>> s_in = torch.tensor([[0, 0], [1, 1], [3, 2]])
>>> g_dims = [0, 1]
>>> g_len = [2, 3]
>>> y = sn_gather(x, s_in, g_dims, g_len)
>>> y.data
tensor([[[ 0,  1,  2],
         [ 5,  6,  7]],
        [[ 6,  7,  8],
         [11, 12, 13]],
        [[17, 18, 19],
         [22, 23, 24]]])
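To make the indexing above concrete, here is a minimal NumPy sketch of the gather semantics as this page describes them. It is a hypothetical reference (the helper name reference_gather is made up for illustration), not SambaFlow's implementation, and it assumes every gather window lies inside input_tensor. Run on the same inputs as the sn_gather example, it produces the same (3, 2, 3) blocks.

```python
import numpy as np


def reference_gather(lut, start_indices, gather_dims, gather_lengths):
    """Hypothetical NumPy model of sn_gather (not SambaFlow's implementation).

    Dimensions of `lut` that are not listed in `gather_dims` are taken in full.
    `start_indices` has shape (M, N, ..., D); each length-D row picks one block.
    Assumes every window lies inside `lut` (no bounds checking).
    """
    lut = np.asarray(lut)
    starts = np.asarray(start_indices)
    lead_shape = starts.shape[:-1]                 # M x N x ... = number of gathers
    flat_starts = starts.reshape(-1, starts.shape[-1])
    blocks = []
    for row in flat_starts:
        window = [slice(None)] * lut.ndim          # start with full slices
        for dim, start, length in zip(gather_dims, row, gather_lengths):
            window[dim] = slice(int(start), int(start) + length)
        blocks.append(lut[tuple(window)])
    stacked = np.stack(blocks)                     # (num_gathers, *block_shape)
    return stacked.reshape(lead_shape + stacked.shape[1:])


# Same inputs as the sn_gather example: a 5x5 LUT holding 0..24.
x = np.arange(25).reshape(5, 5)
s_in = np.array([[0, 0], [1, 1], [3, 2]])
y = reference_gather(x, s_in, gather_dims=[0, 1], gather_lengths=[2, 3])
print(y.shape)  # (3, 2, 3)
```

The output shape is the leading shape of start_indices followed by one axis per input dimension, sized by its gather length or, for dimensions not in gather_dims, by its full size.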
- sn_reduce(input: Tensor, dim: Tuple[int, ...], fn: str) → Tensor | Tuple[Tensor, ...]
Performs a reduce operation, which reduces the dimensions of a tensor by applying a specific function across one or more of its axes (dimensions). Supports several reduction functions.
- Parameters:
input – the tensor to reduce.
dim – a tuple specifying the dimensions over which to perform the reduction.
fn – a string that specifies the reduction function to apply. Supported functions:
"SUM": see torch.sum()
"MEAN": see torch.mean()
"MIN": see torch.min()
"MAX": see torch.max()
Example
>>> import torch
>>> from sambaflow.samba.functional import sn_reduce
>>> # In this example, the tensor x has a shape of (2, 2, 2).
>>> # sn_reduce sums over dimensions 1 and 2, which are the last two dimensions.
>>> # The result y holds the sum of each 2x2 block in x. As the output shows,
>>> # the reduced dimensions are kept with size 1, giving a shape of (2, 1, 1).
>>> x = torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
...                   [[5.0, 6.0], [7.0, 8.0]]])
>>> y = sn_reduce(x, (1, 2), "SUM")
>>> y.data
tensor([[[10.]],
        [[26.]]])
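The reduction semantics can be modelled in a few lines of NumPy. In this sketch, reference_reduce and its string-to-function table are hypothetical, and keepdims=True is an assumption chosen to match the (2, 1, 1) output shown in the example. It returns values only and does not model the tuple return type that the signature allows.

```python
import numpy as np

# Hypothetical mapping from the fn strings listed above to NumPy reductions.
_REDUCERS = {"SUM": np.sum, "MEAN": np.mean, "MIN": np.min, "MAX": np.max}


def reference_reduce(x, dim, fn):
    """Reduce `x` over the axes in `dim`, keeping each as a size-1 axis.

    keepdims=True is an assumption that matches the (2, 1, 1) example output.
    Only values are returned; the tuple return type is not modelled.
    """
    return _REDUCERS[fn](np.asarray(x), axis=tuple(dim), keepdims=True)


x = np.array([[[1.0, 2.0], [3.0, 4.0]],
              [[5.0, 6.0], [7.0, 8.0]]])
y = reference_reduce(x, (1, 2), "SUM")
print(y.shape)             # (2, 1, 1)
print(y.ravel().tolist())  # [10.0, 26.0]
```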
- sn_scatter(operand: Tensor | SambaTensor, update: Tensor | SambaTensor, start_indices: Tensor | SambaTensor, scatter_dims: List[int], rmw_op: str = 'kUpdate', unique_indices: bool = False) → Tensor | SambaTensor
Performs a scatter operation into the operand tensor, given an update tensor and the indices and dimensions to scatter along. Each update value can overwrite, be added to, or be multiplied with the operand value at its destination, or be combined with it by minimum or maximum (see rmw_op).
update is a (G, ..., A, ...) tensor to be scattered into the operand tensor. G represents the group dimension(s), which give the number of scatters that take place. (A, ...) matches the dimensions of operand, except that each scatter dimension has a size less than the corresponding dimension of operand. start_indices is a (G, ..., len(scatter_dims)) tensor that gives, for each scatter, the index along each scatter dimension at which to begin.
- Parameters:
operand – The tensor to scatter into. This can be any shape.
update – The tensor of values to scatter. Has a shape of Nx…Gx…AxBx…, where G represents the "Group Dimensions", which allow multiple scatters in a single scatter operation. AxBx… correspond one to one with the dimensions of operand. If a dimension IS a scatter dimension, its size is less than the size of the corresponding dimension in operand (it is the number of elements scattered along that scatter dimension). If it is NOT a scatter dimension, its size equals that of the corresponding dimension in operand.
start_indices – The start indices indicating where each scatter begins. Has a shape of NxGx…D, where G represents the same "Group Dimensions" as in update and D is equal to the number of scatter dimensions, so there is one start index per scatter dimension.
scatter_dims – The dimensions of operand to scatter along. Has length D.
rmw_op – Read-Modify-Write operation. Selects the operation performed between operand and update: a value is read from operand, the specified operation is applied to this value and the new value from update, and the result is written back to operand. The default, kUpdate, skips the read and modify steps and performs a plain scatter.
Supported operations:
"kUpdate": skip the read and modify steps and perform a plain scatter (default)
"kAdd": add the update value to the value at the destination index in operand
"kMin": compare the old value from operand with the new value from update; scatter only if the update value is less than the value in operand
"kMax": compare the old value from operand with the new value from update; scatter only if the update value is greater than the value in operand
"kMul": multiply the update value with the value at the destination index in operand
unique_indices – Denotes whether start_indices is expected to contain only unique entries. Defaults to False, meaning the user either does not know whether start_indices contains repeated indices or expects it to. Set it to True only when start_indices is known to contain no repeated indices.
Example
>>> import torch
>>> from sambaflow.samba.functional import sn_scatter
>>> # In this example, the operand tensor x has a shape of (2, 4, 5).
>>> # sn_scatter scatters the update tensor u into x along dimensions 0 and 1,
>>> # based on the start indices and rmw_op.
>>> # The result y has the same shape as the operand tensor.
>>> x = torch.tensor([
...     [[0, 1, 2, 3, 4],
...      [5, 6, 7, 8, 9],
...      [10, 11, 12, 13, 14],
...      [15, 16, 17, 18, 19]],
...     [[0, 1, 2, 3, 4],
...      [5, 6, 7, 8, 9],
...      [10, 11, 12, 13, 14],
...      [15, 16, 17, 18, 19]]
... ])
>>> # u has shape (3, 1, 1, 5): three groups, each a 1x1x5 block.
>>> u = torch.tensor([
...     [[[-1, -2, -3, -4, -5]]],
...     [[[-6, -7, -8, -9, -10]]],
...     [[[-11, -12, -13, -14, -15]]]
... ])
>>> s_in = torch.tensor([[1, 1], [0, 2], [1, 3]])
>>> s_dims = [0, 1]
>>> rmw_op = "kUpdate"
>>> y = sn_scatter(x, u, s_in, s_dims, rmw_op)
>>> y.data
tensor([[[  0,   1,   2,   3,   4],
         [  5,   6,   7,   8,   9],
         [ -6,  -7,  -8,  -9, -10],
         [ 15,  16,  17,  18,  19]],
        [[  0,   1,   2,   3,   4],
         [ -1,  -2,  -3,  -4,  -5],
         [ 10,  11,  12,  13,  14],
         [-11, -12, -13, -14, -15]]])
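The following NumPy sketch models the windowed read-modify-write that sn_scatter performs, as described above. The name reference_scatter is hypothetical, the leading N dimension mentioned in the parameter descriptions is not modelled (update is assumed to be (G, ..., *block) with one block axis per operand dimension), and groups are applied one at a time so that the effect of repeated start indices is easy to see. How sn_scatter itself combines repeated indices is not specified on this page.

```python
import numpy as np

# Hypothetical table mirroring the rmw_op strings: f(old_value, update_value).
_RMW = {
    "kUpdate": lambda old, new: new,        # plain overwrite
    "kAdd": lambda old, new: old + new,
    "kMul": lambda old, new: old * new,
    "kMin": np.minimum,                     # keep the smaller value
    "kMax": np.maximum,                     # keep the larger value
}


def reference_scatter(operand, update, start_indices, scatter_dims, rmw_op="kUpdate"):
    """Hypothetical sequential NumPy model of sn_scatter.

    Assumes update has shape (G..., *block), where block has one axis per
    operand dimension, and start_indices has shape (G..., len(scatter_dims)).
    """
    out = np.array(operand, copy=True)
    update = np.asarray(update)
    starts = np.asarray(start_indices)
    block_shape = update.shape[-out.ndim:]
    flat_updates = update.reshape((-1,) + block_shape)
    flat_starts = starts.reshape(-1, starts.shape[-1])
    op = _RMW[rmw_op]
    for block, row in zip(flat_updates, flat_starts):
        window = [slice(None)] * out.ndim   # non-scatter dims are covered in full
        for dim, start in zip(scatter_dims, row):
            window[dim] = slice(int(start), int(start) + block_shape[dim])
        region = tuple(window)
        # Read, modify, write; groups are applied one after another.
        out[region] = op(out[region], block)
    return out


# Same operand, update and indices as the sn_scatter example.
x = np.tile(np.arange(20).reshape(4, 5), (2, 1, 1))   # shape (2, 4, 5)
u = -np.arange(1, 16).reshape(3, 1, 1, 5)              # 3 groups of 1x1x5
y = reference_scatter(x, u, [[1, 1], [0, 2], [1, 3]], [0, 1])
print(y[0, 2].tolist())  # [-6, -7, -8, -9, -10]
```

The explicit loop is a deliberate choice: with NumPy fancy indexing, `out[idx] += vals` keeps only one contribution per duplicated index, whereas the loop lets "kAdd" accumulate every group that targets the same location.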
- sn_zipmapreduce(func: Callable[[Dict[str, Any], Iterable[SambaTensor]], SambaTensor], inputs: List[SambaTensor], reduce_func: str | None = None, reduce_dim: int | List[int] | None = None, stoc_accum: bool | None = False) → SambaTensor
Applies a series of Zip/Map operations, specified by the given function, to the inputs, optionally followed by a Reduce operation on the result.
- Parameters:
func – a Python function or lambda describing the intended operation, where the parameters are the inputs to the function.
The first parameter of func (conventionally named attrs) is a placeholder that sn_zipmapreduce() uses to pass special attributes to the operations inside the body of func. The remaining parameters are assigned from the list of inputs, in the order in which they are specified. For example, if func is set to func = lambda attrs, x, y, z: (x + y) * z and the parameter inputs is ipt, the output will be output = (ipt[0] + ipt[1]) * ipt[2]. func can take an arbitrary number of SambaTensor arguments and has a single return value. See the table below for supported operations.
inputs – list of SambaTensor(s) to use as inputs to the operation.
reduce_func (optional) – reduction function to apply after applying func to the inputs (Default = None). See the fn parameter of sn_reduce() for supported values.
reduce_dim (optional) – dimension(s) to reduce along (Default = None).
stoc_accum (optional) – enable/disable (True/False) Stochastic Accumulation (Default = False).
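The calling convention for func (an attrs placeholder first, then one parameter per entry of inputs, then an optional reduction) can be mimicked with NumPy arrays standing in for SambaTensors. This is a hypothetical model: reference_zipmapreduce is not a SambaFlow function, an empty dict stands in for the attributes that sn_zipmapreduce() really passes, keeping reduced dimensions as size 1 is an assumption based on the (1, 2, 1)-shaped output of the broadcast example, and reducing over every axis when reduce_dim is None is likewise assumed.

```python
import numpy as np

_REDUCERS = {"SUM": np.sum, "MEAN": np.mean, "MIN": np.min, "MAX": np.max}


def reference_zipmapreduce(func, inputs, reduce_func=None, reduce_dim=None):
    """Hypothetical NumPy model of the func / inputs / reduce contract.

    An empty dict stands in for the attrs placeholder; the attributes that
    sn_zipmapreduce() actually passes are not modelled.
    """
    attrs = {}
    arrays = [np.asarray(a) for a in inputs]
    out = func(attrs, *arrays)      # inputs bind to func's parameters in list order
    if reduce_func is None:
        return out
    if reduce_dim is None:
        axis = None                 # assumption: reduce over every axis
    elif isinstance(reduce_dim, int):
        axis = (reduce_dim,)
    else:
        axis = tuple(reduce_dim)
    # keepdims=True is assumed, matching the (1, 2, 1)-shaped example output.
    return _REDUCERS[reduce_func](out, axis=axis, keepdims=True)


ipt = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([2.0, 10.0])]
out = reference_zipmapreduce(lambda attrs, x, y, z: (x + y) * z, ipt)
print(out.tolist())  # [8.0, 60.0], i.e. (ipt[0] + ipt[1]) * ipt[2]
```

Because NumPy broadcasts the same way as PyTorch, passing inputs shaped (3, 2, 3), (1, 2, 1) and (1, 2, 3) with reduce_func="SUM" and reduce_dim=[0, 2] gives a (1, 2, 1) result, the same shape as in the broadcast example.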
Operation                | Syntax              | See
-------------------------+---------------------+------------------
Unary Operations
Absolute                 | torch.abs(x)        | torch.abs()
Exponentiation           | torch.exp(x)        | torch.exp()
Logarithm                | torch.log(x)        | torch.log()
Reciprocal               | torch.reciprocal(x) | torch.reciprocal()
Square Root              | torch.sqrt(x)       | torch.sqrt()
Floor                    | torch.floor(x)      | torch.floor()
Ceil                     | torch.ceil(x)       | torch.ceil()
Negation                 | torch.neg(x)        | torch.neg()
Sigmoid                  | torch.sigmoid(x)    | torch.sigmoid()
Reciprocal Square Root   | torch.rsqrt(x)      | torch.rsqrt()
Tanh                     | torch.tanh(x)       | torch.tanh()
Binary Operations
Addition                 | x + y               | torch.add()
Subtraction              | x - y               | torch.sub()
Multiplication           | x * y               | torch.mul()
Division                 | x / y               | torch.div()
Equal                    | x == y              | torch.eq()
Greater Than             | x > y               | torch.gt()
Greater Than or Equal    | x >= y              | torch.ge()
Less Than                | x < y               | torch.lt()
Less Than or Equal       | x <= y              | torch.le()
Not Equal                | x != y              | torch.ne()
Elementwise Maximum      | torch.maximum(x, y) | torch.maximum()
Elementwise Minimum      | torch.minimum(x, y) | torch.minimum()
Ternary Operations
Fused Multiply Add       | sn_fma(x, y, z)     | sn_fma()
Conditional Select       | sn_select(x, y, z)  | sn_select()
Miscellaneous
Iter Index               | sn_iteridx(dim)     | sn_iteridx()
Immediate                | sn_imm(val)         | sn_imm()
Note
The sizes of the SambaTensor(s) provided in inputs must be compatible with PyTorch Broadcasting Semantics. sn_zipmapreduce() also expects all input tensors to have datatypes with the same bit-width.
Examples
>>> import torch
>>> import sambaflow.samba as samba
>>> from sambaflow.samba.functional import sn_zipmapreduce
>>> # Calculate the element-wise squared difference between two tensors of the same size
>>> # (the per-element terms of a squared L2 distance; no reduction is applied here)
>>> samba.set_seed(1)
>>> x = samba.ones((5, 5))
>>> y = samba.randn((5, 5))
>>> dist_func = lambda attrs, x, y: torch.pow(x - y, 2)
>>> out = sn_zipmapreduce(dist_func, [x, y])
>>> out.data
tensor([[6.3786, 3.0633, 2.7357, 6.8094, 1.2104],
        [2.5895, 3.9195, 6.8074, 2.9314, 0.0293],
        [0.5889, 1.1998, 0.7407, 6.6487, 1.7440],
        [1.6883, 0.7732, 0.4413, 0.5251, 0.5131],
        [1.1154, 0.0080, 5.7235, 2.8530, 1.2344]])
>>> # Example with broadcasting and a reduction applied
>>> from sambaflow.samba.functional import sn_fma, sn_select
>>> # Apply do_something to the inputs, then sum dimensions 0 and 2 of the result
>>> samba.set_seed(1)
>>> ipt_0 = samba.randn((3, 2, 3))
>>> ipt_1 = samba.randn((1, 2, 1))
>>> # Dimensions 0 and 2 of ipt_1 will be broadcast
>>> ipt_2 = samba.randn((1, 2, 3))
>>> # Dimension 0 of ipt_2 will be broadcast
>>> def do_something(attrs, x, y, z):
...     max_x_y = torch.maximum(x, y)
...     max_x_y_z = torch.maximum(max_x_y, z)
...     fma_out = sn_fma(x, y, z)
...     x_gt_y = x > y
...     return sn_select(x_gt_y, fma_out, max_x_y_z)
...
>>> out = sn_zipmapreduce(do_something, [ipt_0, ipt_1, ipt_2], reduce_func="SUM", reduce_dim=[0, 2])
>>> # The reduced dimensions are kept with size 1, so out has shape (1, 2, 1)
>>> out.data
tensor([[[ 6.8619],
         [11.1231]]])
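The two input requirements in the Note on sn_zipmapreduce() (broadcast-compatible shapes and a shared bit-width) can be checked before the call. The sketch below does this with NumPy arrays; check_zipmapreduce_inputs is a hypothetical helper, not part of SambaFlow, and it applies the same constraints to NumPy dtypes rather than to SambaTensor datatypes. np.broadcast_shapes requires NumPy 1.20 or later.

```python
import numpy as np


def check_zipmapreduce_inputs(*arrays):
    """Hypothetical pre-flight check for the requirements in the Note.

    Returns the broadcast result shape. Raises ValueError if the shapes do
    not broadcast or if the dtypes do not all share one bit-width.
    """
    # np.broadcast_shapes raises ValueError for incompatible shapes.
    shape = np.broadcast_shapes(*(a.shape for a in arrays))
    widths = {a.dtype.itemsize * 8 for a in arrays}
    if len(widths) != 1:
        raise ValueError(f"inputs mix bit-widths: {sorted(widths)}")
    return shape


# Shapes from the broadcast-and-reduce example, all float32.
f32 = [np.zeros(s, dtype=np.float32) for s in [(3, 2, 3), (1, 2, 1), (1, 2, 3)]]
print(check_zipmapreduce_inputs(*f32))  # (3, 2, 3)
```

Because the Note asks for equal bit-width rather than identical datatypes, float32 and int32 inputs pass this check, while float32 mixed with float16 is rejected.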