Coverage for src/flag_gems/runtime/backend/_mthreads/ops/repeat.py: 0%

176 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as tle 

9from flag_gems.utils.libentry import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8),
    ],
    key=["out_shape0", "out_shape1"],
)
@triton.jit
def repeat_kernel_2d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    out_stride0,
    out_stride1,
    inp_shape0,
    inp_shape1,
    out_shape0,
    out_shape1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """2-D repeat: each program copies one BLOCK_M x BLOCK_N output tile.

    Output element (m, n) is read from input element
    (m % inp_shape0, n % inp_shape1), so the input is tiled end-to-end
    across the (larger) output. Strides are passed explicitly, so
    non-contiguous inputs/outputs are handled.
    """
    # 2-D launch grid: axis 0 covers rows, axis 1 covers columns.
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Guard the ragged last tile in each dimension.
    mask_m = offs_m < out_shape0
    mask_n = offs_n < out_shape1
    mask = mask_m[:, None] & mask_n[None, :]

    # Map output indices to input indices using modulo
    inp_offs_m = offs_m % inp_shape0
    inp_offs_n = offs_n % inp_shape1

    # Load from input
    inp_ptrs = (
        inp_ptr + inp_offs_m[:, None] * inp_stride0 + inp_offs_n[None, :] * inp_stride1
    )
    data = tl.load(inp_ptrs, mask=mask, other=0.0)

    # Store to output
    out_ptrs = out_ptr + offs_m[:, None] * out_stride0 + offs_n[None, :] * out_stride1
    tl.store(out_ptrs, data, mask=mask)

66 

67 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
    ],
    key=["out_shape0"],
)
@triton.jit
def repeat_kernel_1d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    out_stride0,
    inp_shape0,
    out_shape0,
    BLOCK_SIZE: tl.constexpr,
):
    """1-D repeat: each program copies one BLOCK_SIZE chunk of the output.

    Output element i is read from input element i % inp_shape0, i.e. the
    input is tiled end-to-end across the output.
    """
    pid = tle.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard the ragged final block.
    mask = offs < out_shape0

    # Map output indices to input indices
    inp_offs = offs % inp_shape0

    # Load and store
    data = tl.load(inp_ptr + inp_offs * inp_stride0, mask=mask)
    tl.store(out_ptr + offs * out_stride0, data, mask=mask)

98 

99 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4),
        triton.Config({"BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8),
        triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4),
    ],
    key=["out_shape1", "out_shape2"],
)
@triton.jit
def repeat_kernel_3d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    out_stride0,
    out_stride1,
    out_stride2,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    out_shape0,
    out_shape1,
    out_shape2,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Process 3D repeat: one program handles one (m, n_block, k_block).

    Grid axis 0 enumerates output index m directly; grid axis 1 is a
    flattened (n_block, k_block) pair that is decomposed below. Each
    output index maps to the input via modulo with the input shape.
    """
    pid_m = tle.program_id(0)
    pid_nk = tle.program_id(1)

    # Decompose the flattened second grid axis into (n_block, k_block).
    num_k_blocks = tl.cdiv(out_shape2, BLOCK_K)
    pid_n = pid_nk // num_k_blocks
    pid_k = pid_nk % num_k_blocks

    m_idx = pid_m
    if m_idx >= out_shape0:
        return

    inp_m = m_idx % inp_shape0

    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    # Guard ragged tiles at the n/k boundaries.
    mask_n = offs_n < out_shape1
    mask_k = offs_k < out_shape2
    mask = mask_n[:, None] & mask_k[None, :]

    # Modulo maps each output coordinate back into the input.
    inp_n = offs_n % inp_shape1
    inp_k = offs_k % inp_shape2

    inp_ptrs = (
        inp_ptr
        + inp_m * inp_stride0
        + inp_n[:, None] * inp_stride1
        + inp_k[None, :] * inp_stride2
    )
    data = tl.load(inp_ptrs, mask=mask, other=0.0)

    out_ptrs = (
        out_ptr
        + m_idx * out_stride0
        + offs_n[:, None] * out_stride1
        + offs_k[None, :] * out_stride2
    )
    tl.store(out_ptrs, data, mask=mask)

170 

171 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 64, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 64}, num_warps=4),
        triton.Config({"BLOCK_K": 64, "BLOCK_L": 64}, num_warps=8),
        triton.Config({"BLOCK_K": 128, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 128}, num_warps=4),
        triton.Config({"BLOCK_K": 128, "BLOCK_L": 64}, num_warps=8),
        triton.Config({"BLOCK_K": 64, "BLOCK_L": 128}, num_warps=8),
    ],
    key=["out_shape2", "out_shape3"],
)
@triton.jit
def repeat_kernel_4d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    BLOCK_K: tl.constexpr,
    BLOCK_L: tl.constexpr,
):
    """Process 4D repeat: one program handles one (m, n, k_block, l_block).

    Grid axis 0 is a flattened (m, n) pair; grid axis 1 is a flattened
    (k_block, l_block) pair. Both are decomposed below. Each output
    coordinate maps to the input via modulo with the input shape.
    """
    pid_mn = tle.program_id(0)
    pid_kl = tle.program_id(1)

    # Decompose the flattened (k_block, l_block) grid axis.
    num_l_blocks = tl.cdiv(out_shape3, BLOCK_L)
    pid_k = pid_kl // num_l_blocks
    pid_l = pid_kl % num_l_blocks

    # Flatten m, n
    m_idx = pid_mn // out_shape1
    n_idx = pid_mn % out_shape1

    if m_idx >= out_shape0:
        return

    inp_m = m_idx % inp_shape0
    inp_n = n_idx % inp_shape1

    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)

    # Guard ragged tiles at the k/l boundaries.
    mask_k = offs_k < out_shape2
    mask_l = offs_l < out_shape3
    mask = mask_k[:, None] & mask_l[None, :]

    inp_k = offs_k % inp_shape2
    inp_l = offs_l % inp_shape3

    inp_ptrs = (
        inp_ptr
        + inp_m * inp_stride0
        + inp_n * inp_stride1
        + inp_k[:, None] * inp_stride2
        + inp_l[None, :] * inp_stride3
    )
    data = tl.load(inp_ptrs, mask=mask, other=0.0)

    out_ptrs = (
        out_ptr
        + m_idx * out_stride0
        + n_idx * out_stride1
        + offs_k[:, None] * out_stride2
        + offs_l[None, :] * out_stride3
    )
    tl.store(out_ptrs, data, mask=mask)

254 

255 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=16),
    ],
    key=["num_tasks"],
)
@triton.jit
def repeat_kernel_nd_flat(
    inp_ptr,
    out_ptr,
    num_tasks,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    inp_shape4,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    out_shape4,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    inp_stride4,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    out_stride4,
    rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Generic N-D repeat kernel (up to 5D) using flat indexing with modulo.

    Each program walks flat output offsets in a grid-stride loop (step =
    num_programs * BLOCK_SIZE), decomposes each flat offset into per-dim
    output coordinates (shape index 4 is the fastest-varying, i.e. the
    last dimension), maps them into the input with modulo, and copies.
    Only 5 dimensions are unrolled: ranks below 5 zero out the unused
    leading coordinates; ranks above 5 are NOT supported by this kernel.
    """
    pid = tle.program_id(0)
    num_ctas = tle.num_programs(0)

    # Grid-stride loop over BLOCK_SIZE-sized chunks of the flat output.
    for idx in range(pid * BLOCK_SIZE, num_tasks, num_ctas * BLOCK_SIZE):
        offs = idx + tl.arange(0, BLOCK_SIZE)
        mask = offs < num_tasks

        remaining = offs

        # Unroll for up to 5D
        if rank >= 5:
            out_idx4 = remaining % out_shape4
            inp_idx4 = out_idx4 % inp_shape4
            remaining = remaining // out_shape4
        else:
            out_idx4 = tl.zeros_like(offs)
            inp_idx4 = tl.zeros_like(offs)

        if rank >= 4:
            out_idx3 = remaining % out_shape3
            inp_idx3 = out_idx3 % inp_shape3
            remaining = remaining // out_shape3
        else:
            out_idx3 = tl.zeros_like(offs)
            inp_idx3 = tl.zeros_like(offs)

        if rank >= 3:
            out_idx2 = remaining % out_shape2
            inp_idx2 = out_idx2 % inp_shape2
            remaining = remaining // out_shape2
        else:
            out_idx2 = tl.zeros_like(offs)
            inp_idx2 = tl.zeros_like(offs)

        if rank >= 2:
            out_idx1 = remaining % out_shape1
            inp_idx1 = out_idx1 % inp_shape1
            remaining = remaining // out_shape1
        else:
            out_idx1 = tl.zeros_like(offs)
            inp_idx1 = tl.zeros_like(offs)

        # Whatever is left is the outermost coordinate.
        out_idx0 = remaining
        inp_idx0 = out_idx0 % inp_shape0

        inp_offset = (
            inp_idx0 * inp_stride0
            + inp_idx1 * inp_stride1
            + inp_idx2 * inp_stride2
            + inp_idx3 * inp_stride3
            + inp_idx4 * inp_stride4
        )
        out_offset = (
            out_idx0 * out_stride0
            + out_idx1 * out_stride1
            + out_idx2 * out_stride2
            + out_idx3 * out_stride3
            + out_idx4 * out_stride4
        )

        data = tl.load(inp_ptr + inp_offset, mask=mask)
        tl.store(out_ptr + out_offset, data, mask=mask)

358 

359 

def repeat(inp: torch.Tensor, sizes) -> torch.Tensor:
    """Repeat *inp* along each dimension, like ``torch.Tensor.repeat``.

    Args:
        inp: input tensor.
        sizes: number of repetitions per dimension. ``len(sizes)`` must be
            >= ``inp.dim()``; extra leading entries act on implicit size-1
            leading dimensions. Each entry must be >= 0; a 0 entry yields
            an empty output.

    Returns:
        A newly allocated tensor whose shape is ``inp_shape[i] * sizes[i]``
        per (left-padded) dimension.
    """
    logger.debug("GEMS_MTHREADS REPEAT")

    in0_rank = inp.dim()
    sizes_rank = len(sizes)
    in0_shape = list(inp.shape)
    sizes_shape = list(sizes)

    # Normalize shapes - for repeat, sizes_rank must be >= in0_rank
    assert (
        sizes_rank >= in0_rank
    ), "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"

    if sizes_rank > in0_rank:
        # Left-pad the input shape with 1s to match len(sizes).
        diff = sizes_rank - in0_rank
        in0_shape = [1] * diff + in0_shape

    # Validate repeat counts and compute the output shape; any 0 count
    # makes the output empty (no kernel launch needed).
    is_empty = False
    out_shape = []
    for i in range(len(in0_shape)):
        assert (
            sizes_shape[i] >= 0
        ), f"the number of repetitions per dimension out of range (expected to >= 0) but got {sizes_shape[i]}"
        if sizes_shape[i] == 0:
            is_empty = True
        out_shape.append(in0_shape[i] * sizes_shape[i])

    out = torch.empty(out_shape, device=inp.device, dtype=inp.dtype)

    if is_empty:
        return out

    # Align the input's rank with the output's (no-op when ranks match).
    inp = inp.reshape(in0_shape)
    rank = len(out_shape)
    num_tasks = out.numel()

    # Strides are forwarded to the kernels so non-contiguous layouts work.
    # NOTE: a zero stride (e.g. from an expanded/broadcast input) is valid
    # and must be preserved -- the kernel then re-reads the same element,
    # which is exactly the broadcast semantics. Replacing 0 with 1 (as the
    # old 1-D path did) would read unrelated memory when shape > 1.
    inp_strides = list(inp.stride())
    out_strides = list(out.stride())

    with torch_device_fn.device(inp.device.index):
        if rank == 1:
            # 1D case with autotune
            grid = lambda META: (triton.cdiv(out_shape[0], META["BLOCK_SIZE"]),)
            repeat_kernel_1d[grid](
                inp,
                out,
                inp_strides[0],
                out_strides[0],
                in0_shape[0],
                out_shape[0],
            )
        elif rank == 2:
            # 2D case - use 2D blocking with autotune
            grid = lambda META: (
                triton.cdiv(out_shape[0], META["BLOCK_M"]),
                triton.cdiv(out_shape[1], META["BLOCK_N"]),
            )
            repeat_kernel_2d[grid](
                inp,
                out,
                inp_strides[0],
                inp_strides[1],
                out_strides[0],
                out_strides[1],
                in0_shape[0],
                in0_shape[1],
                out_shape[0],
                out_shape[1],
            )
        elif rank == 3:
            # 3D case: axis 0 of the grid is the outermost output index,
            # axis 1 flattens the (n_block, k_block) tiling.
            grid = lambda META: (
                out_shape[0],
                triton.cdiv(out_shape[1], META["BLOCK_N"])
                * triton.cdiv(out_shape[2], META["BLOCK_K"]),
            )
            repeat_kernel_3d[grid](
                inp,
                out,
                inp_strides[0],
                inp_strides[1],
                inp_strides[2],
                out_strides[0],
                out_strides[1],
                out_strides[2],
                in0_shape[0],
                in0_shape[1],
                in0_shape[2],
                out_shape[0],
                out_shape[1],
                out_shape[2],
            )
        elif rank == 4:
            # 4D case: axis 0 flattens (m, n), axis 1 flattens the
            # (k_block, l_block) tiling.
            num_mn = out_shape[0] * out_shape[1]
            grid = lambda META: (
                num_mn,
                triton.cdiv(out_shape[2], META["BLOCK_K"])
                * triton.cdiv(out_shape[3], META["BLOCK_L"]),
            )
            repeat_kernel_4d[grid](
                inp,
                out,
                inp_strides[0],
                inp_strides[1],
                inp_strides[2],
                inp_strides[3],
                out_strides[0],
                out_strides[1],
                out_strides[2],
                out_strides[3],
                in0_shape[0],
                in0_shape[1],
                in0_shape[2],
                in0_shape[3],
                out_shape[0],
                out_shape[1],
                out_shape[2],
                out_shape[3],
            )
        elif rank == 5:
            # 5D case - use generic flat-index kernel with autotune.
            grid = lambda META: (
                min(65535, triton.cdiv(num_tasks, META["BLOCK_SIZE"])),
            )
            repeat_kernel_nd_flat[grid](
                inp,
                out,
                num_tasks,
                in0_shape[0],
                in0_shape[1],
                in0_shape[2],
                in0_shape[3],
                in0_shape[4],
                out_shape[0],
                out_shape[1],
                out_shape[2],
                out_shape[3],
                out_shape[4],
                inp_strides[0],
                inp_strides[1],
                inp_strides[2],
                inp_strides[3],
                inp_strides[4],
                out_strides[0],
                out_strides[1],
                out_strides[2],
                out_strides[3],
                out_strides[4],
                rank=rank,
            )
        else:
            # rank > 5: the flat kernel only unrolls 5 dimensions. The old
            # code passed just the first five shape/stride entries here,
            # silently dropping trailing dims and producing wrong results.
            # Fall back to a broadcast copy: view each input dim s as
            # (1, s), expand to (repeat, s), then materialize into out.
            expanded = inp.reshape(
                [d for s in in0_shape for d in (1, s)]
            ).expand(
                [d for r, s in zip(sizes_shape, in0_shape) for d in (r, s)]
            )
            out.copy_(expanded.reshape(out_shape))

    return out