Coverage for src/flag_gems/experimental_ops/expand.py: 0%
68 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import torch
2import triton
3import triton.language as tl
# Triton kernel backing the Python-level ``expand`` wrapper defined below.
#
# Each program instance handles BLOCK_SIZE consecutive output elements.  For
# every output linear index it reconstructs the multi-dimensional index (using
# the right-cumulative products of the output shape) and maps it to an input
# linear offset via the effective input strides, which are 0 on broadcast
# dimensions so every broadcast copy reads the same input element.
@triton.jit
def expand(
    x_ptr,  # pointer to input tensor data
    out_ptr,  # pointer to (contiguous) output tensor data
    n_elements,  # total number of output elements
    ndims,  # logical output rank (currently unused; the loop runs over MAX_DIMS)
    out_shape_ptr,  # int64 array: output shape, padded to MAX_DIMS with 1s
    out_cumprod_ptr,  # int64 array: per-dim product of output sizes to the right, padded with 1s
    in_stride_ptr,  # int64 array: effective input strides (0 on broadcast dims), padded with 0s
    BLOCK_SIZE: tl.constexpr,  # output elements per program instance
    MAX_DIMS: tl.constexpr,  # static rank bound; the dim loop is unrolled over it
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Tail guard: lanes past n_elements neither load nor store.
    mask = offsets < n_elements

    # Compute input offsets corresponding to each output linear index
    in_offsets = tl.zeros([BLOCK_SIZE], dtype=tl.int64)

    # Accumulate contributions per dimension.  Dims padded by the host wrapper
    # (shape 1, cumprod 1, stride 0) yield idx_d == 0 and contribute nothing,
    # so iterating all MAX_DIMS is harmless.
    for d in range(MAX_DIMS):
        # Load scalars defining the output decomposition and input strides
        s = tl.load(out_shape_ptr + d)
        stride_right = tl.load(out_cumprod_ptr + d)
        in_stride = tl.load(in_stride_ptr + d)
        # idx along dimension d for each linear offset
        idx_d = (offsets // stride_right) % s
        # contribution to input linear offset
        in_offsets += idx_d * in_stride

    # Load from input using computed offsets and store to output
    x = tl.load(x_ptr + in_offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)
42_expand_kernel = expand
def expand(*args, **kwargs):
    """Materialized equivalent of ``Tensor.expand`` backed by a Triton kernel.

    Positional args:
        x: input tensor.
        size: target shape (list/tuple/torch.Size of ints). An entry of ``-1``
            keeps the input's size at that dimension; ``-1`` is not allowed in
            a dimension the input does not have (matches PyTorch semantics).

    Keyword args:
        implicit: accepted for ATen signature compatibility; ignored.

    Returns:
        A newly allocated contiguous tensor of the requested shape holding the
        broadcast data (unlike ``Tensor.expand``, which returns a view).

    Raises:
        TypeError: if ``size`` is not a list/tuple/torch.Size.
        RuntimeError: if ``size`` has fewer dims than the input, is not
            broadcast-compatible with it, uses ``-1`` in a new leading
            dimension, or contains a negative size other than ``-1``.
    """
    x = args[0]
    size = args[1]
    kwargs.get("implicit", False)  # not used; accepted for signature compatibility

    if not isinstance(size, (list, tuple, torch.Size)):
        raise TypeError("expand size must be a list/tuple/torch.Size of ints")

    size = list(size)
    in_shape = list(x.shape)
    in_strides = list(x.stride())

    out_ndim = len(size)
    in_ndim = len(in_shape)

    if in_ndim > out_ndim:
        raise RuntimeError(
            f"expand: requested size has fewer dimensions ({out_ndim}) than input ({in_ndim})"
        )

    # Pad input shape/strides on the left to match the output rank.  New
    # leading dims are broadcast, so their effective stride is 0.
    pad = out_ndim - in_ndim
    if pad:
        in_shape = [1] * pad + in_shape
        in_strides = [0] * pad + in_strides

    # Resolve -1 entries and validate broadcastability.
    out_shape = []
    for d in range(out_ndim):
        req = size[d]
        src = in_shape[d]
        if req == -1:
            # PyTorch rejects -1 for dimensions the input does not have;
            # previously this silently resolved to 1.
            if d < pad:
                raise RuntimeError(
                    f"The expanded size of the tensor (-1) isn't allowed in a leading, "
                    f"non-existing dimension {d}"
                )
            target = src
        else:
            # Reject negative sizes early with a clear message instead of
            # letting torch.empty fail later.
            if req < 0:
                raise RuntimeError(
                    f"Trying to create tensor with negative dimension {req}"
                )
            target = req
        if src != target and src != 1:
            raise RuntimeError(
                f"The expanded size of the tensor ({target}) must match the existing size ({src}) at non-singleton "
                f"dimension {d}. Target sizes must be the same, or -1, or the size of dimension in the original tensor must be 1."  # noqa: E501
            )
        out_shape.append(int(target))

    # Effective input strides: 0 for broadcast dims, original stride otherwise.
    in_stride_eff = [
        int(in_strides[d]) if in_shape[d] != 1 else 0 for d in range(out_ndim)
    ]

    # Decomposition multipliers: product of output sizes to the right of each
    # dim, used by the kernel to recover per-dim indices from a linear index.
    out_cumprod_right = [0] * out_ndim
    prod = 1
    for d in range(out_ndim - 1, -1, -1):
        out_cumprod_right[d] = prod
        prod *= out_shape[d]

    # Allocate output
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)

    n_elements = out.numel()
    if n_elements == 0:
        return out

    # Triton launch parameters.  The rank bound is static (>= 16) so the
    # kernel loop can be unrolled; device-side arrays below are padded to
    # exactly this length.
    BLOCK_SIZE = 1024
    STATIC_MAX = max(16, out_ndim, 1)

    # Create device arrays for shapes/strides, padded out to STATIC_MAX.
    pad_len = STATIC_MAX - out_ndim
    out_shape_arr = torch.tensor(
        out_shape + [1] * pad_len, dtype=torch.int64, device=x.device
    )
    out_cumprod_arr = torch.tensor(
        out_cumprod_right + [1] * pad_len, dtype=torch.int64, device=x.device
    )
    in_stride_arr = torch.tensor(
        in_stride_eff + [0] * pad_len, dtype=torch.int64, device=x.device
    )

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _expand_kernel[grid](
        x,
        out,
        n_elements,
        out_ndim,
        out_shape_arr,
        out_cumprod_arr,
        in_stride_arr,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_DIMS=STATIC_MAX,
    )
    return out