Coverage for src/flag_gems/utils/shape_utils.py: 70%
245 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
import enum
import functools
import math
import operator
from typing import Iterable, Sequence, Tuple

import torch
import triton
import triton.language as tl

from flag_gems.utils import triton_lang_extension as tle
from flag_gems.utils.codegen_config_utils import get_heuristics_for_num_warps
# Lightweight type aliases used throughout this module.
# NOTE: `Tuple[int]` denotes a tuple of EXACTLY one int; these shapes/strides
# are variable-length, so the correct spelling is `Tuple[int, ...]`.
Shape = Tuple[int, ...]
Stride = Tuple[int, ...]
MultiIndex = Tuple[int, ...]
Perm = Tuple[int, ...]
def bracket_next_power_of_2(N, lower, upper):
    """Round N up to the next power of two, then clamp the result into [lower, upper]."""
    p = triton.next_power_of_2(N)
    clamped = max(p, lower)
    return min(clamped, upper)
def broadcast(s1: Tuple[int, ...], s2: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return the broadcast of shapes `s1` and `s2` (NumPy trailing-dim rule).

    A rank-0 shape broadcasts to the other shape unchanged. Otherwise the
    shorter shape is aligned to the trailing dimensions of the longer one, and
    each aligned pair must be equal or contain a 1.

    Raises:
        ValueError: if the shapes are not broadcastable.
    """
    if len(s1) == 0:
        return s2
    if len(s2) == 0:
        return s1

    # Align the shorter shape against the trailing dims of the longer one.
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    offset = len(longer) - len(shorter)

    out = list(longer)
    for i, size in enumerate(shorter):
        j = offset + i
        if out[j] == 1:
            out[j] = size
        elif size != 1 and size != out[j]:
            # Original inputs (not the swapped copies) go into the message.
            raise ValueError(f"Unbroadcastable {s1} and {s2}")
        # else: size == 1 or size == out[j] -> out[j] already correct.
    return tuple(out)
def broadcastable(s1: Tuple[int, ...], s2: Tuple[int, ...]) -> bool:
    """Whether shapes `s1` and `s2` can broadcast together.

    A rank-0 shape broadcasts with anything; otherwise every trailing pair of
    dims must be equal or contain a 1.
    """
    if not s1 or not s2:
        return True
    # zip stops at the shorter shape, which aligns trailing dimensions.
    return all(a == 1 or b == 1 or a == b for a, b in zip(reversed(s1), reversed(s2)))
def broadcastable_to(s1: Tuple[int, ...], s2: Tuple[int, ...]) -> bool:
    """Whether a tensor of shape `s1` can be broadcast to exactly shape `s2`."""
    if not s1:
        return True  # rank-0 broadcasts to any shape
    if not s2:
        return False  # non-scalar cannot collapse to a scalar
    if len(s1) > len(s2):
        return False
    # Align trailing dims; every s1 dim must equal the target or be 1.
    return all(a == 1 or a == b for a, b in zip(reversed(s1), reversed(s2)))
def broadcast_shapes(shapes: Iterable[Tuple[int, ...]]) -> Tuple[int, ...]:
    """Fold `broadcast` over any iterable of shapes; () for an empty iterable.

    The previous implementation called len() and sliced, so it rejected
    generic iterables despite its annotation; reduce accepts any iterable.
    `broadcast((), s) == s`, so the empty-tuple seed is a true identity.
    """
    return functools.reduce(broadcast, shapes, ())
def broadcasted_stride(shape: Shape, stride: Stride, new_shape: Shape) -> Stride:
    """Strides for viewing a tensor of (shape, stride) as `new_shape`.

    Dimensions that are expanded by broadcasting (size 1 -> size > 1) get
    stride 0; leading new dimensions also get stride 0.
    """
    assert broadcastable_to(shape, new_shape)
    offset = len(new_shape) - len(shape)
    result = [0] * len(new_shape)
    for i, (size, st) in enumerate(zip(shape, stride)):
        j = offset + i
        # Keep the original stride unless this dim is actually expanded.
        result[j] = 0 if (size == 1 and new_shape[j] > 1) else st
    return tuple(result)
def volume(shape: Tuple[int, ...]) -> int:
    """Number of elements in a tensor of `shape` (1 for the empty shape).

    math.prod is the C-implemented stdlib equivalent of
    functools.reduce(operator.mul, shape, 1).
    """
    return math.prod(shape)
def is_valid_perm(perm: Tuple[int, ...]) -> bool:
    """True iff `perm` is a permutation of 0..len(perm)-1."""
    # Sorting and comparing against the identity is the idiomatic form of the
    # original element-by-element check.
    return sorted(perm) == list(range(len(perm)))
def unravel_index(linear_offset: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Convert a flat (C-contiguous) index into a multi-index for `shape`.

    Fixes the original's reuse of the loop variable `i` for both the dimension
    counter and the per-dim index (harmless in Python but confusing to read).
    """
    multi_index = []
    # Walk dims from innermost (last) to outermost, peeling one digit at a time.
    for size in reversed(shape):
        linear_offset, rem = divmod(linear_offset, size)
        multi_index.append(rem)
    return tuple(reversed(multi_index))
def c_contiguous_stride(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Row-major (C-contiguous) strides for `shape`.

    Annotated with the precise variadic tuple type (the `Shape` alias was
    declared as `Tuple[int]`, i.e. a 1-tuple).
    """
    strides = []
    acc = 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= max(size, 1)  # treat size 0 as size 1
    return tuple(reversed(strides))
def f_contiguous_stride(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Column-major (Fortran-contiguous) strides for `shape`.

    Annotated with the precise variadic tuple type (the `Shape` alias was
    declared as `Tuple[int]`, i.e. a 1-tuple).
    """
    strides = []
    acc = 1
    for size in shape:
        strides.append(acc)
        acc *= max(size, 1)  # treat size 0 as size 1
    return tuple(strides)
def ordered_stride(shape: Tuple[int, ...], order: Tuple[int, ...]) -> Tuple[int, ...]:
    """Strides for `shape` with dims laid out in `order` (order[0] innermost).

    `ordered_stride(shape, (ndim-1, ..., 0))` equals `c_contiguous_stride(shape)`.
    Annotated with precise variadic tuple types (the aliases were 1-tuples).
    """
    strides = [0] * len(shape)
    acc = 1
    for axis in order:
        strides[axis] = acc
        acc *= max(shape[axis], 1)  # treat size 0 as size 1
    return tuple(strides)
def stride_order(strides):
    """Return dimension indices ordered by ascending |stride|.

    Absolute values are used so negative strides sort by magnitude too.
    """
    axes = list(range(len(strides)))
    axes.sort(key=lambda d: abs(strides[d]))
    return axes
def all_the_same_shape(tensors: Sequence[torch.Tensor]) -> bool:
    """True when every tensor has the first one's shape (vacuously true when empty)."""
    if not tensors:
        return True
    reference = tensors[0].shape
    for t in tensors[1:]:
        if t.shape != reference:
            return False
    return True
def all_the_same_stride(tensors: Sequence[torch.Tensor]) -> bool:
    """True when every tensor has the first one's strides (vacuously true when empty)."""
    if not tensors:
        return True
    reference = tensors[0].stride()
    for t in tensors[1:]:
        if t.stride() != reference:
            return False
    return True
def all_c_contiguous(tensors: Sequence[torch.Tensor]) -> bool:
    """True when every tensor is C-contiguous (vacuously true when empty)."""
    for t in tensors:
        if not t.is_contiguous():
            return False
    return True
def heuristics_for_tile_size(max_tile_size, *sizes):
    """Split a power-of-two tile budget across dims, innermost (last) dim first.

    Each dim gets min(remaining budget, next_power_of_2(size)); the budget is
    divided down (never below 1) as dims are assigned.
    """
    ndim = len(sizes)
    tiles = [0] * ndim
    for pos in reversed(range(ndim)):
        t = min(max_tile_size, triton.next_power_of_2(sizes[pos]))
        tiles[pos] = t
        max_tile_size = max(1, max_tile_size // t)
    return tuple(tiles)
# This should be part of CodeGenConfig
def heuristics_for_num_warps(tile_size):
    """Pick num_warps for a given tile size; delegates to codegen_config_utils."""
    return get_heuristics_for_num_warps(tile_size)
def dim_compress(inp, dims):
    """Move the reduction dims of `inp` to the end and return a contiguous copy.

    Non-reduction dims keep their relative order; reduction dims are ordered by
    descending stride so the innermost reduction dim ends up last.
    """
    if isinstance(dims, int):
        dims = [dims]
    strides = inp.stride()
    kept = [d for d in range(inp.ndim) if d not in dims]
    reduced = sorted(dims, key=lambda d: strides[d], reverse=True)
    return inp.permute(kept + reduced).contiguous()
def size_in_bytes(a):
    """Total storage of tensor `a` in bytes (element count times element size)."""
    return a.element_size() * a.numel()
def can_use_int32_index(a):
    """Heuristic: True when offsets into `a` fit in a signed 32-bit integer.

    NOTE(review): the contiguous path bounds the BYTE size while the
    non-contiguous path bounds ELEMENT offsets (sum of size * stride) —
    confirm this asymmetry is intended by the kernels that consume it.
    """
    int32_max = torch.iinfo(torch.int32).max
    if a.is_contiguous():
        return size_in_bytes(a) <= int32_max
    # sizes and torch strides are non-negative, so summing then comparing is
    # equivalent to the early-exit accumulation.
    max_offset = sum(size * stride for size, stride in zip(a.shape, a.stride()))
    return max_offset <= int32_max
class MemOverlap(enum.Enum):
    """Tri-state answer for whether a tensor's elements may alias in memory."""

    No = 0  # provably no internal overlap
    Yes = 1  # provably overlapping (some dim has size > 1 with stride 0)
    TooHard = 2  # analysis inconclusive
def has_internal_overlapping(x: torch.Tensor):
    """Classify whether `x` may have internally overlapping elements.

    Contiguous or non-overlapping-and-dense layouts are definitely safe; a dim
    with size > 1 and stride 0 definitely overlaps; everything else is TooHard.
    """
    if x.is_contiguous() or torch.ops.aten.is_non_overlapping_and_dense(x):
        return MemOverlap.No
    if any(size > 1 and stride == 0 for size, stride in zip(x.size(), x.stride())):
        return MemOverlap.Yes
    return MemOverlap.TooHard
def restride_dim(src, dim, shape, step=0, storage_offset=None):
    """View `src` as `shape` with the stride along `dim` scaled by `step`.

    The default step=0 zeroes that dim's stride, so the view repeats the same
    slice along `dim`.
    """
    new_strides = list(src.stride())
    new_strides[dim] = new_strides[dim] * step
    return src.as_strided(shape, new_strides, storage_offset)
def cfggen():
    """Autotune configs: cross product of BLOCK_M x BLOCK_N, num_warps fixed at 4."""
    return [
        triton.Config({"BLOCK_M": bm, "BLOCK_N": bn}, num_warps=4)
        for bm in (1, 2, 4)
        for bn in (256, 1024, 2048, 4096)
    ]
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def add_on_kernel(
    idx,
    add_on,
    cur_shape,
    cur_strides,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Elementwise (idx % cur_shape) * cur_strides over an (M, N) matrix, written
    # into add_on. Each program instance handles one BLOCK_M x BLOCK_N tile.
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    rows_offset = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offset < M

    # BUG FIX: the column base must be scaled by BLOCK_N (the launch grid is
    # cdiv(N, BLOCK_N) in axis 1). Previously `pid_y + arange(BLOCK_N)` made
    # tiles overlap and left columns beyond cdiv(N, BLOCK_N) + BLOCK_N - 2
    # unwritten.
    cols_offset = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    cols_mask = cols_offset < N
    # `&` is the explicit elementwise combine; Python `and` is not elementwise.
    block_mask = rows_mask & cols_mask

    offsets = rows_offset * N + cols_offset
    cur_idx = tl.load(idx + offsets, mask=block_mask, other=1)
    mod = cur_idx % cur_shape
    res = mod * cur_strides
    tl.store(add_on + offsets, res, mask=block_mask)
def check_tensor_attributes(data_list, is_tensor_list):
    """Validate that the tensor-ness of each element matches its boolean flag.

    Args:
        data_list: mixed list of tensors and non-tensor objects.
        is_tensor_list: parallel list of booleans; True means the matching
            element is expected to be a torch.Tensor.

    Returns:
        True when every element's type matches its flag.

    Raises:
        ValueError: if the lists differ in length, or if any element's type
            disagrees with its flag (the message names the offending index).
    """
    if len(data_list) != len(is_tensor_list):
        raise ValueError(
            "Error: The lists of inputs and is_tensor must have the same length."
        )

    for position, (item, expected) in enumerate(zip(data_list, is_tensor_list)):
        actual = isinstance(item, torch.Tensor)
        if actual != expected:
            raise ValueError(
                f"Element at index {position} is incorrect. Expected {expected}, but got {actual}."
            )

    return True
# Sentinel object for detecting "no argument supplied" via identity comparison
# (distinct from an explicit None). NOTE(review): no uses are visible in this
# chunk of the file — confirm it is still referenced elsewhere.
_initial_missing = object()
def offset_calculator(inp, idx, strides, dim, isInp):
    """
    Calculate flat offsets (into ``inp``'s storage) for the ravel indices in
    ``idx``, running the per-dimension mod/multiply step on device via
    ``add_on_kernel``. See the NumPy memory-layout docs for the formula:
    - https://numpy.org/doc/stable/reference/arrays.ndarray.html#internal-memory-layout-of-an-ndarray
    - https://numpy.org/devdocs/user/basics.indexing.html#single-element-indexing

    Parameters:
        inp (tensor): multi-dimensional array whose shape/strides drive the
            offset computation.
        idx (tensor): linear indices to convert into storage offsets.
        strides (list of int): stride of each dimension of ``inp``.
        dim (int): the dimension whose contribution may need replacing later.
        isInp (bool): True when ``inp`` is the 'self' tensor of a
            scatter/gather/index_* op. In that case the caller later applies
                inp_offset = origin_offset - stride[dim] * n_dim + stride[dim] * index,
            so this function returns the fixed part
                origin_offset - stride[dim] * n_dim.
            Otherwise the full origin offset is returned.

    Returns:
        The offsets tensor; when isInp is True, ``dim``'s contribution has been
        subtracted out (see above).

    Note:
        A fast div/mod (invariant-integer multiplication) could speed up the
        repeated '%' and '//' if this is called frequently. See:
        - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
        - Division by Invariant Integers Using Multiplication,
          Torbjörn Granlund and Peter L. Montgomery, 1994.
    """
    ndim = inp.ndim
    shape = list(inp.shape)
    # Running sum of per-dim contributions, and the saved contribution of `dim`.
    offsets = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)
    idx_dim = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)
    for d in range(0, ndim):
        add_on = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)
        # Treat idx as an (M, N) matrix: N = last-dim size, M = everything else.
        N = idx.size(idx.ndim - 1)
        M = idx.numel() // N
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )
        # add_on <- (idx % shape[d]) * strides[d], elementwise on device.
        add_on_kernel[grid](idx, add_on, shape[d], strides[d], M, N)

        offsets = torch.add(offsets, add_on)
        if d == dim:
            idx_dim = add_on
        # Peel off dimension d. NOTE(review): dims are consumed from d=0
        # upward (idx % shape[0] first), matching the pure-PyTorch
        # offsetCalculator in this module — confirm callers encode idx that way.
        idx = idx // shape[d]
    return offsets if not isInp else (offsets - idx_dim)
def offsetCalculator(inp, idx, strides, dim, isInp):
    """Pure-tensor variant of offset_calculator (no kernel launch).

    Converts the ravel indices in `idx` into storage offsets using `inp`'s
    shape and the given `strides`, consuming dimensions from d=0 upward.
    When isInp is True, the contribution of `dim` is subtracted from the
    result so the caller can substitute its own index along that dim.
    """
    offsets = 0
    dim_component = 0
    remaining = idx
    for d, extent in enumerate(inp.shape):
        contribution = (remaining % extent) * strides[d]
        offsets = offsets + contribution
        if d == dim:
            dim_component = contribution
        remaining = remaining // extent
    # FIXME: a fast div/mod (invariant-integer multiplication) could speed up
    # the repeated '%' and '//' since they may run many times. See:
    # - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
    # - Division by Invariant Integers Using Multiplication,
    #   Torbjörn Granlund and Peter L. Montgomery, 1994.
    return offsets if not isInp else (offsets - dim_component)