Coverage for src/flag_gems/utils/shape_utils.py: 70%

245 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import enum 

2import functools 

3import operator 

4from typing import Iterable, Sequence, Tuple 

5 

6import torch 

7import triton 

8import triton.language as tl 

9 

10from flag_gems.utils import triton_lang_extension as tle 

11from flag_gems.utils.codegen_config_utils import get_heuristics_for_num_warps 

12 

13Shape = Tuple[int] 

14Stride = Tuple[int] 

15MultiIndex = Tuple[int] 

16Perm = Tuple[int] 

17 

18 

def bracket_next_power_of_2(N, lower, upper):
    """Round N up to the next power of two, then clamp into [lower, upper]."""
    bracketed = triton.next_power_of_2(N)
    if bracketed < lower:
        bracketed = lower
    if bracketed > upper:
        bracketed = upper
    return bracketed

21 

22 

def broadcast(s1: Shape, s2: Shape) -> Shape:
    """Compute the broadcast shape of two shapes.

    Scalar (rank-0) shapes broadcast to the other shape unchanged.

    Raises:
        ValueError: if the shapes cannot be broadcast together.
    """
    orig1, orig2 = s1, s2
    if not s1:
        return s2
    if not s2:
        return s1

    # Align the lower-rank shape against the trailing dims of the higher-rank one.
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    offset = len(longer) - len(shorter)

    merged = list(longer)
    for i, dim in enumerate(shorter):
        j = offset + i
        if merged[j] == 1:
            merged[j] = dim
        elif dim != 1 and dim != merged[j]:
            raise ValueError(f"Unbroadcastable {orig1} and {orig2}")
        # else: dim == 1 or dim == merged[j] -> keep merged[j] as-is
    return tuple(merged)

49 

50 

def broadcastable(s1: Shape, s2: Shape) -> bool:
    """Return True if the two shapes can be broadcast against each other."""
    if not s1 or not s2:
        # A scalar shape broadcasts with anything.
        return True
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    offset = len(longer) - len(shorter)
    # Trailing-aligned dims must be equal or one of them must be 1.
    return all(
        a == 1 or b == 1 or a == b for a, b in zip(longer[offset:], shorter)
    )

68 

69 

def broadcastable_to(s1: Shape, s2: Shape) -> bool:
    """Return True if shape s1 can be broadcast to exactly shape s2."""
    if not s1:
        # Scalars broadcast to any shape.
        return True
    if not s2:
        # s1 has rank > 0 here; a non-scalar cannot broadcast down to a scalar.
        return False
    if len(s1) > len(s2):
        return False
    offset = len(s2) - len(s1)
    # Each dim of s1, aligned to the trailing dims of s2, must match or be 1.
    return all(a == 1 or a == b for a, b in zip(s1, s2[offset:]))

87 

88 

def broadcast_shapes(shapes: Iterable[Shape]) -> Shape:
    """Broadcast an iterable of shapes into one common shape.

    Generalized to accept any iterable, not just indexable sequences
    (the previous implementation used `shapes[0]` and slicing).  An empty
    iterable broadcasts to the scalar shape (), which `broadcast` treats
    as the identity element.

    Raises:
        ValueError: if any pair of shapes is not broadcastable.
    """
    return functools.reduce(broadcast, shapes, ())

96 

97 

def broadcasted_stride(shape: Shape, stride: Stride, new_shape: Shape) -> Stride:
    """Derive strides for viewing a (shape, stride) tensor as `new_shape`.

    Dimensions that are broadcast (size 1 expanded to > 1) and dims newly
    prepended by the broadcast get a stride of 0; all others keep theirs.
    """
    assert broadcastable_to(shape, new_shape)
    offset = len(new_shape) - len(shape)
    new_stride = [0] * len(new_shape)
    for i, (size, s) in enumerate(zip(shape, stride)):
        j = offset + i
        new_stride[j] = 0 if (size == 1 and new_shape[j] > 1) else s
    return tuple(new_stride)

107 

108 

def volume(shape: Shape) -> int:
    """Return the number of elements in a tensor of the given shape."""
    total = 1
    for extent in shape:
        total *= extent
    return total

111 

112 

def is_valid_perm(perm: Perm) -> bool:
    """Return True if `perm` is a permutation of 0..len(perm)-1.

    Simplified from a manual index-by-index check to a single comparison:
    a valid permutation sorts to exactly [0, 1, ..., r-1].
    """
    return sorted(perm) == list(range(len(perm)))

120 

121 

def unravel_index(linear_offset: int, shape: Shape) -> MultiIndex:
    """Convert a flat (linear) index into a multi-index for `shape`.

    Fixes a readability hazard in the previous version, which reused the
    `for` loop variable both as the dimension counter and as the extracted
    coordinate (it worked only because `for` rebinds the variable each
    iteration).  Coordinates are peeled off from the last (fastest-varying)
    dimension to the first.
    """
    coords = []
    for size in reversed(shape):
        coords.append(linear_offset % size)
        linear_offset //= size
    return tuple(reversed(coords))

131 

132 

def c_contiguous_stride(shape: Shape) -> Stride:
    """Strides of a C-contiguous (row-major) tensor with the given shape."""
    result = []
    running = 1
    for extent in reversed(shape):
        result.append(running)
        running *= max(extent, 1)  # a size-0 dim contributes as if it were size 1
    result.reverse()
    return tuple(result)

140 

141 

def f_contiguous_stride(shape: Shape) -> Stride:
    """Strides of an F-contiguous (column-major) tensor with the given shape."""
    strides = [1] * len(shape)
    for i in range(1, len(shape)):
        # Each stride is the previous stride scaled by the previous extent;
        # size-0 dims are treated as size 1.
        strides[i] = strides[i - 1] * max(shape[i - 1], 1)
    return tuple(strides)

149 

150 

def ordered_stride(shape: Shape, order: Perm) -> Stride:
    """Strides for `shape` where `order` lists dims from fastest to slowest."""
    strides = [0] * len(shape)
    running = 1
    for axis in order:
        strides[axis] = running
        running *= max(shape[axis], 1)  # size-0 dims are treated as size 1
    return tuple(strides)

158 

159 

def stride_order(strides):
    """Return dim indices ordered from smallest to largest stride magnitude.

    Absolute values are compared so negative strides are ranked correctly.
    """

    def magnitude(dim):
        return abs(strides[dim])

    return sorted(range(len(strides)), key=magnitude)

163 

164 

def all_the_same_shape(tensors: Sequence[torch.Tensor]) -> bool:
    """True if every tensor in the sequence has the same shape (vacuously
    True for an empty sequence)."""
    if not tensors:
        return True
    reference = tensors[0].shape
    for t in tensors[1:]:
        if t.shape != reference:
            return False
    return True

170 

171 

def all_the_same_stride(tensors: Sequence[torch.Tensor]) -> bool:
    """True if every tensor in the sequence has identical strides (vacuously
    True for an empty sequence)."""
    if not tensors:
        return True
    reference = tensors[0].stride()
    return not any(t.stride() != reference for t in tensors[1:])

177 

178 

def all_c_contiguous(tensors: Sequence[torch.Tensor]) -> bool:
    """True if every tensor in the sequence is C-contiguous.

    The explicit empty-sequence guard was removed: `all` over an empty
    iterable is already True, so behavior is unchanged.
    """
    return all(tensor.is_contiguous() for tensor in tensors)

183 

184 

def heuristics_for_tile_size(max_tile_size, *sizes):
    """Split a per-block element budget across dims, innermost dim first.

    Each dim receives min(budget, next_power_of_2(size)); the remaining
    budget is then divided by that tile (never dropping below 1), so outer
    dims get whatever is left over.
    """
    ndim = len(sizes)
    tile_sizes = [0] * ndim
    budget = max_tile_size
    for axis in reversed(range(ndim)):
        tile = min(budget, triton.next_power_of_2(sizes[axis]))
        tile_sizes[axis] = tile
        budget = max(1, budget // tile)
    return tuple(tile_sizes)

194 

195 

# This should be part of CodeGenConfig
def heuristics_for_num_warps(tile_size):
    # Thin delegation to the codegen-config helper imported at the top of
    # this module (flag_gems.utils.codegen_config_utils); the num_warps
    # policy lives there, not here.
    return get_heuristics_for_num_warps(tile_size)

199 

200 

def dim_compress(inp, dims):
    """Permute `inp` so the dims in `dims` become the trailing axes, then
    materialize the result contiguously.

    The reduction dims are ordered by decreasing stride, so among them the
    last (innermost) axis is the one with the smallest stride.
    """
    if isinstance(dims, int):
        dims = [dims]
    strides = inp.stride()
    kept = [axis for axis in range(inp.ndim) if axis not in dims]
    reduced = sorted(dims, key=lambda axis: strides[axis], reverse=True)
    return inp.permute(kept + reduced).contiguous()

210 

211 

def size_in_bytes(a):
    """Total storage, in bytes, needed for the elements of tensor `a`."""
    return a.element_size() * a.numel()

214 

215 

def can_use_int32_index(a):
    # Return True if indexing `a` is safe with signed 32-bit offsets.
    INT32_MAX = torch.iinfo(torch.int32).max
    if a.is_contiguous():
        # Contiguous fast path: compare the total *byte* size against INT32_MAX.
        return size_in_bytes(a) <= INT32_MAX

    # Non-contiguous path: bound the reachable offset by sum(size * stride),
    # bailing out early as soon as the running total overflows int32.
    # NOTE(review): this sums size * stride — an over-estimate of the true
    # maximum element offset sum((size - 1) * stride) — and compares element
    # offsets here vs. bytes in the contiguous branch.  Presumably this is a
    # deliberately conservative check; confirm before tightening.
    max_offset = 0
    for size, stride in zip(a.shape, a.stride()):
        max_offset += size * stride
        if max_offset > INT32_MAX:
            return False
    return True

227 

228 

class MemOverlap(enum.Enum):
    # Tri-state answer for internal-overlap queries (see has_internal_overlapping):
    No = 0
    Yes = 1
    TooHard = 2  # overlap could not be decided cheaply

233 

234 

def has_internal_overlapping(x: torch.Tensor):
    """Check whether distinct indices of `x` can alias the same memory.

    Returns MemOverlap.No or MemOverlap.Yes when cheaply decidable, and
    MemOverlap.TooHard otherwise.
    """
    # Cheap definite negatives first: contiguous tensors and tensors that
    # aten reports as non-overlapping-and-dense cannot self-overlap.
    if x.is_contiguous() or torch.ops.aten.is_non_overlapping_and_dense(x):
        return MemOverlap.No
    # A stride of 0 on a dim with extent > 1 definitely aliases.
    if any(size > 1 and stride == 0 for size, stride in zip(x.size(), x.stride())):
        return MemOverlap.Yes
    return MemOverlap.TooHard

244 

245 

def restride_dim(src, dim, shape, step=0, storage_offset=None):
    """Return a strided view of `src` with the stride along `dim` scaled by
    `step` (the default step of 0 zeroes that stride), via as_strided with
    the given `shape` and optional `storage_offset`."""
    new_strides = list(src.stride())
    new_strides[dim] = new_strides[dim] * step
    return src.as_strided(shape, new_strides, storage_offset)

250 

251 

def cfggen():
    """Autotune configs for add_on_kernel: the cross product of BLOCK_M in
    {1, 2, 4} and BLOCK_N in {256, 1024, 2048, 4096}, all at 4 warps."""
    return [
        triton.Config({"BLOCK_M": bm, "BLOCK_N": bn}, num_warps=4)
        for bm in (1, 2, 4)
        for bn in (256, 1024, 2048, 4096)
    ]

261 

262 

@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def add_on_kernel(
    idx,
    add_on,
    cur_shape,
    cur_strides,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Compute (idx % cur_shape) * cur_strides elementwise over an (M, N)
    array of linear indices, storing each element's per-dim offset
    contribution into add_on (same layout as idx)."""
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    rows_offset = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offset < M

    # BUG FIX: the column base must advance by a full tile per program
    # (pid_y * BLOCK_N).  The previous bare `pid_y` made adjacent programs
    # overlap by BLOCK_N - 1 columns and left every column index >=
    # grid_y + BLOCK_N - 1 unprocessed whenever N > BLOCK_N, since the
    # launch grid is cdiv(N, BLOCK_N) along axis 1.
    cols_offset = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    cols_mask = cols_offset < N
    block_mask = rows_mask and cols_mask

    offsets = rows_offset * N + cols_offset
    cur_idx = tl.load(idx + offsets, mask=block_mask, other=1)
    mod = cur_idx % cur_shape
    res = mod * cur_strides
    tl.store(add_on + offsets, res, mask=block_mask)

289 

290 

def check_tensor_attributes(data_list, is_tensor_list):
    """Validate that each element's tensor-ness matches its expected flag.

    Parameters:
    - data_list: a list containing tensor and non-tensor objects.
    - is_tensor_list: booleans, one per element of data_list, saying whether
      that element is expected to be a torch.Tensor.

    Returns True when every element matches.  Raises ValueError if the lists
    differ in length, or on the first element whose actual type does not
    match its flag (the message names the offending index).
    """
    if len(data_list) != len(is_tensor_list):
        raise ValueError(
            "Error: The lists of inputs and is_tensor must have the same length."
        )

    for i, (data, expected) in enumerate(zip(data_list, is_tensor_list)):
        actual = isinstance(data, torch.Tensor)
        if actual != expected:
            raise ValueError(
                f"Element at index {i} is incorrect. Expected {expected}, but got {actual}."
            )

    return True

317 

318 

# Sentinel distinguishing "argument not supplied" from an explicit None.
# NOTE(review): no use of this sentinel is visible in this module; it may be
# consumed elsewhere or be dead code — confirm before removing.
_initial_missing = object()

320 

321 

def offset_calculator(inp, idx, strides, dim, isInp):
    """
    Compute, on device, the memory offset for every linear index in `idx`.

    Each offset follows the usual strided-layout formula
    (https://numpy.org/doc/stable/reference/arrays.ndarray.html#internal-memory-layout-of-an-ndarray,
    https://numpy.org/devdocs/user/basics.indexing.html#single-element-indexing):
    per dimension d, the contribution is (idx % shape[d]) * strides[d], after
    which idx is integer-divided by shape[d].

    Parameters:
        inp (tensor): multi-dimensional array whose shape defines the layout.
        idx (tensor): linear indices to convert into offsets.
        strides (list of int): stride per dimension of `inp`.
        dim (int): the dimension whose contribution may need fixing up later.
        isInp (bool): True when `inp` is the 'self' tensor of a
            scatter/gather/index_* operator.

    In scatter/gather/index_* operators, when `inp` is the 'self' tensor the
    caller later recombines the offset as

        inp_offset = origin_offset - stride[dim] * n_dim + stride[dim] * index,

    so for isInp=True this function returns only the fixed part

        origin_offset - stride[dim] * n_dim,

    to facilitate that substitution.  For other inputs the complete
    origin_offset is returned.

    Note:
        The repeated '//' and '%' could be accelerated with invariant-integer
        division tricks if this path becomes hot.  See also:
        - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
        - Division by Invariant Integers Using Multiplication,
          Torbjörn Granlund and Peter L. Montgomery, 1994.
    """
    ndim = inp.ndim
    shape = list(inp.shape)
    # Accumulators, allocated with inp's shape/device; int32 offsets.
    offsets = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)
    idx_dim = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)
    for d in range(0, ndim):
        # Per-dim contribution (idx % shape[d]) * strides[d], computed by the
        # triton kernel into `add_on`.
        add_on = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)
        N = idx.size(idx.ndim - 1)  # innermost extent of the index tensor
        M = idx.numel() // N  # remaining extents flattened into rows
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )
        add_on_kernel[grid](idx, add_on, shape[d], strides[d], M, N)

        offsets = torch.add(offsets, add_on)
        if d == dim:
            # Remember the target dim's contribution so the caller can swap in
            # its own index along that dim (see docstring).
            idx_dim = add_on
        idx = idx // shape[d]  # peel this dimension off the linear index
    return offsets if not isInp else (offsets - idx_dim)

384 

385 

def offsetCalculator(inp, idx, strides, dim, isInp):
    """Tensor-arithmetic twin of offset_calculator: decompose the linear
    indices in `idx` into per-dim contributions (idx % shape[d]) * strides[d]
    and accumulate them into offsets.

    When isInp is True (i.e. `inp` is the 'self' tensor of a
    scatter/gather/index_* op), the contribution of `dim` is subtracted so
    the caller can substitute its own index along that dimension.
    """
    shape = list(inp.shape)
    offsets = 0
    dim_contrib = 0
    remaining = idx
    for d in range(inp.ndim):
        contrib = (remaining % shape[d]) * strides[d]
        offsets = offsets + contrib
        if d == dim:
            dim_contrib = contrib
        remaining = remaining // shape[d]
    # FIXME: Should we write a fast div/mod
    # to boost the '%' and '//'? (Since they may be run many times)
    # See also:
    # - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
    # - Division by Invariant Integers Using Multiplication,
    #   Torbjörn Granlund and Peter L. Montgomery, 1994.
    return (offsets) if not isInp else (offsets - dim_contrib)