Coverage for src/flag_gems/runtime/backend/_mthreads/ops/tile.py: 0%

177 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as tle 

9from flag_gems.utils.libentry import libentry 

10 

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Dispatch key set holding only CompositeImplicitAutograd; `tile` passes it to
# `redispatch` when it hands the call back to the aten implementation.
_COMPOSITE_KEY = torch._C.DispatchKey.CompositeImplicitAutograd
_FALLBACK_KEYSET = torch._C.DispatchKeySet(_COMPOSITE_KEY)

16 

17 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": block_m, "BLOCK_N": block_n}, num_warps=warps)
        for block_m, block_n, warps in (
            (32, 32, 4),
            (64, 32, 4),
            (32, 64, 4),
            (64, 64, 8),
            (128, 32, 4),
            (32, 128, 4),
            (128, 64, 8),
            (64, 128, 8),
        )
    ],
    key=["out_shape0", "out_shape1"],
)
@triton.jit
def tile_kernel_2d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    out_stride0,
    out_stride1,
    inp_shape0,
    inp_shape1,
    out_shape0,
    out_shape1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Fill one BLOCK_M x BLOCK_N block of a 2D tiled output.

    Every output coordinate (m, n) is read from the input coordinate
    (m % inp_shape0, n % inp_shape1). Coordinates at or beyond
    (out_shape0, out_shape1) are masked out of both the load and the store.
    """
    block_row = tle.program_id(0)
    block_col = tle.program_id(1)

    # Output coordinates covered by this program.
    out_rows = block_row * BLOCK_M + tl.arange(0, BLOCK_M)
    out_cols = block_col * BLOCK_N + tl.arange(0, BLOCK_N)

    rows_ok = out_rows < out_shape0
    cols_ok = out_cols < out_shape1
    in_bounds = rows_ok[:, None] & cols_ok[None, :]

    # Wrap the output coordinates back into the input extents.
    src_rows = out_rows % inp_shape0
    src_cols = out_cols % inp_shape1

    src = inp_ptr + src_rows[:, None] * inp_stride0 + src_cols[None, :] * inp_stride1
    dst = out_ptr + out_rows[:, None] * out_stride0 + out_cols[None, :] * out_stride1

    values = tl.load(src, mask=in_bounds, other=0.0)
    tl.store(dst, values, mask=in_bounds)

70 

71 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": block_size}, num_warps=warps)
        for block_size, warps in ((256, 4), (512, 4), (1024, 8), (2048, 8))
    ],
    key=["out_shape0"],
)
@triton.jit
def tile_kernel_1d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    out_stride0,
    inp_shape0,
    out_shape0,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill BLOCK_SIZE consecutive elements of a 1D tiled output.

    Output element i is read from input element i % inp_shape0. Indices at
    or beyond out_shape0 are masked out of both the load and the store.
    """
    block_id = tle.program_id(0)
    out_idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = out_idx < out_shape0

    # Wrap the output index back into the input extent.
    src_idx = out_idx % inp_shape0

    src = inp_ptr + src_idx * inp_stride0
    dst = out_ptr + out_idx * out_stride0

    values = tl.load(src, mask=in_bounds)
    tl.store(dst, values, mask=in_bounds)

102 

103 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": block_n, "BLOCK_K": block_k}, num_warps=warps)
        for block_n, block_k, warps in (
            (32, 32, 4),
            (64, 32, 4),
            (32, 64, 4),
            (64, 64, 8),
            (128, 32, 4),
            (32, 128, 4),
        )
    ],
    key=["out_shape1", "out_shape2"],
)
@triton.jit
def tile_kernel_3d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    out_stride0,
    out_stride1,
    out_stride2,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    out_shape0,
    out_shape1,
    out_shape2,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Fill one BLOCK_N x BLOCK_K block of a single leading-dim slice of a 3D output.

    Program axis 0 selects the index along output dimension 0. Program axis 1
    is a flattened (n_block, k_block) pair, with k_block varying fastest.
    Each output coordinate is read from the input at the same coordinate
    taken modulo the input shape.
    """
    out_m = tle.program_id(0)
    flat_block = tle.program_id(1)

    # Split the flattened second grid axis into its (n, k) block coordinates.
    k_block_count = tl.cdiv(out_shape2, BLOCK_K)
    block_n = flat_block // k_block_count
    block_k = flat_block % k_block_count

    # Nothing to do for programs past the leading output extent.
    if out_m >= out_shape0:
        return

    src_m = out_m % inp_shape0

    out_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)
    out_k = block_k * BLOCK_K + tl.arange(0, BLOCK_K)

    n_ok = out_n < out_shape1
    k_ok = out_k < out_shape2
    in_bounds = n_ok[:, None] & k_ok[None, :]

    # Wrap the output coordinates back into the input extents.
    src_n = out_n % inp_shape1
    src_k = out_k % inp_shape2

    src = (
        inp_ptr
        + src_m * inp_stride0
        + src_n[:, None] * inp_stride1
        + src_k[None, :] * inp_stride2
    )
    dst = (
        out_ptr
        + out_m * out_stride0
        + out_n[:, None] * out_stride1
        + out_k[None, :] * out_stride2
    )

    values = tl.load(src, mask=in_bounds, other=0.0)
    tl.store(dst, values, mask=in_bounds)

174 

175 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K": block_k, "BLOCK_L": block_l}, num_warps=warps)
        for block_k, block_l, warps in (
            (32, 32, 4),
            (64, 32, 4),
            (32, 64, 4),
            (64, 64, 8),
            (128, 32, 4),
            (32, 128, 4),
        )
    ],
    key=["out_shape2", "out_shape3"],
)
@triton.jit
def tile_kernel_4d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    BLOCK_K: tl.constexpr,
    BLOCK_L: tl.constexpr,
):
    """Fill one BLOCK_K x BLOCK_L block of a single (m, n) slice of a 4D output.

    Program axis 0 is a flattened (m, n) pair with n varying fastest; program
    axis 1 is a flattened (k_block, l_block) pair with l_block varying
    fastest. Each output coordinate is read from the input at the same
    coordinate taken modulo the input shape.
    """
    flat_mn = tle.program_id(0)
    flat_kl = tle.program_id(1)

    # Recover (m, n) from the flattened first grid axis.
    out_m = flat_mn // out_shape1
    out_n = flat_mn % out_shape1

    # Nothing to do for programs past the leading output extent.
    if out_m >= out_shape0:
        return

    # Recover (k_block, l_block) from the flattened second grid axis.
    l_block_count = tl.cdiv(out_shape3, BLOCK_L)
    block_k = flat_kl // l_block_count
    block_l = flat_kl % l_block_count

    src_m = out_m % inp_shape0
    src_n = out_n % inp_shape1

    out_k = block_k * BLOCK_K + tl.arange(0, BLOCK_K)
    out_l = block_l * BLOCK_L + tl.arange(0, BLOCK_L)

    k_ok = out_k < out_shape2
    l_ok = out_l < out_shape3
    in_bounds = k_ok[:, None] & l_ok[None, :]

    # Wrap the output coordinates back into the input extents.
    src_k = out_k % inp_shape2
    src_l = out_l % inp_shape3

    src = (
        inp_ptr
        + src_m * inp_stride0
        + src_n * inp_stride1
        + src_k[:, None] * inp_stride2
        + src_l[None, :] * inp_stride3
    )
    dst = (
        out_ptr
        + out_m * out_stride0
        + out_n * out_stride1
        + out_k[:, None] * out_stride2
        + out_l[None, :] * out_stride3
    )

    values = tl.load(src, mask=in_bounds, other=0.0)
    tl.store(dst, values, mask=in_bounds)

256 

257 

@libentry()
@triton.jit
def tile_kernel_nd_flat(
    inp_ptr,
    out_ptr,
    num_tasks,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    inp_shape4,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    out_shape4,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    inp_stride4,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    out_stride4,
    rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Tile through a flat output index that is decomposed into up to 5 coordinates.

    Each program starts at pid * BLOCK_SIZE and advances by
    num_programs * BLOCK_SIZE until num_tasks is reached. Within a block the
    flat index is peeled from slot 4 down to slot 1 (modulo, then integer
    division by the output extent); whatever is left becomes the slot-0
    coordinate. A slot d >= 1 takes part only when rank >= d + 1; otherwise
    its coordinate is fixed at 0. Every input coordinate is the output
    coordinate modulo the input extent of that slot.
    """
    prog = tle.program_id(0)
    prog_count = tle.num_programs(0)

    for start in range(prog * BLOCK_SIZE, num_tasks, prog_count * BLOCK_SIZE):
        flat = start + tl.arange(0, BLOCK_SIZE)
        live = flat < num_tasks

        rest = flat

        # The steps below must stay in this order: each one consumes the
        # quotient left behind by the previous one.
        if rank >= 5:
            o4 = rest % out_shape4
            i4 = o4 % inp_shape4
            rest = rest // out_shape4
        else:
            o4 = tl.zeros_like(flat)
            i4 = tl.zeros_like(flat)

        if rank >= 4:
            o3 = rest % out_shape3
            i3 = o3 % inp_shape3
            rest = rest // out_shape3
        else:
            o3 = tl.zeros_like(flat)
            i3 = tl.zeros_like(flat)

        if rank >= 3:
            o2 = rest % out_shape2
            i2 = o2 % inp_shape2
            rest = rest // out_shape2
        else:
            o2 = tl.zeros_like(flat)
            i2 = tl.zeros_like(flat)

        if rank >= 2:
            o1 = rest % out_shape1
            i1 = o1 % inp_shape1
            rest = rest // out_shape1
        else:
            o1 = tl.zeros_like(flat)
            i1 = tl.zeros_like(flat)

        # Slot 0 takes whatever is left of the flat index.
        o0 = rest
        i0 = o0 % inp_shape0

        src_off = (
            i0 * inp_stride0
            + i1 * inp_stride1
            + i2 * inp_stride2
            + i3 * inp_stride3
            + i4 * inp_stride4
        )
        dst_off = (
            o0 * out_stride0
            + o1 * out_stride1
            + o2 * out_stride2
            + o3 * out_stride3
            + o4 * out_stride4
        )

        values = tl.load(inp_ptr + src_off, mask=live)
        tl.store(out_ptr + dst_off, values, mask=live)

350 

351 

def tile(inp: torch.Tensor, dims) -> torch.Tensor:
    """Build a tensor by repeating ``inp`` ``dims[i]`` times along each dimension.

    The shorter of ``inp.shape`` and ``dims`` is left-padded with 1s so both
    have the same length; the output extent of every dimension is then
    ``input_extent * repetitions``.

    The call is handed back to the aten implementation (via ``redispatch``
    with ``_FALLBACK_KEYSET``) in two situations:

    * gradient recording is enabled, and
    * the normalized rank exceeds 5, because the generic kernel only receives
      five shape/stride slots and would otherwise index the output wrongly.

    Args:
        inp: Tensor to repeat.
        dims: Sequence of non-negative repetition counts.

    Returns:
        A newly allocated tensor with the tiled contents, or the fallback
        implementation's result in the situations listed above.

    Raises:
        AssertionError: If any repetition count is negative.
    """
    logger.debug("GEMS TILE")
    if torch.is_grad_enabled():
        return torch.ops.aten.tile.default.redispatch(_FALLBACK_KEYSET, inp, dims)

    # Number of shape/stride slots accepted by tile_kernel_nd_flat.
    max_kernel_rank = 5

    in0_shape = list(inp.shape)
    dims_shape = list(dims)

    # Left-pad whichever of the two is shorter with 1s.
    rank_gap = len(in0_shape) - len(dims_shape)
    if rank_gap > 0:
        dims_shape = [1] * rank_gap + dims_shape
    elif rank_gap < 0:
        in0_shape = [1] * (-rank_gap) + in0_shape

    for rep in dims_shape:
        assert (
            rep >= 0
        ), f"the number of repetitions per dimension out of range (expected to >= 0) but got {rep}"

    out_shape = [size * rep for size, rep in zip(in0_shape, dims_shape)]
    rank = len(out_shape)

    if rank > max_kernel_rank:
        # The flat kernel decomposes indices over at most five dimensions;
        # feeding it a higher-rank problem would write outside the output.
        return torch.ops.aten.tile.default.redispatch(_FALLBACK_KEYSET, inp, dims)

    out = torch.empty(out_shape, device=inp.device, dtype=inp.dtype)
    num_tasks = out.numel()

    # Covers zero repetitions as well as zero-sized input dimensions, so no
    # kernel is ever launched on an empty grid or with a zero modulo divisor.
    if num_tasks == 0:
        return out

    inp = inp.reshape(in0_shape)
    inp_strides = list(inp.stride())
    out_strides = list(out.stride())

    with torch_device_fn.device(inp.device.index):
        if rank == 1:
            grid = lambda META: (triton.cdiv(out_shape[0], META["BLOCK_SIZE"]),)
            # Strides are passed unmodified: a stride of 0 (expanded input)
            # must keep every read on the same element.
            tile_kernel_1d[grid](
                inp,
                out,
                inp_strides[0],
                out_strides[0],
                in0_shape[0],
                out_shape[0],
            )
        elif rank == 2:
            grid = lambda META: (
                triton.cdiv(out_shape[0], META["BLOCK_M"]),
                triton.cdiv(out_shape[1], META["BLOCK_N"]),
            )
            tile_kernel_2d[grid](
                inp, out, *inp_strides, *out_strides, *in0_shape, *out_shape
            )
        elif rank == 3:
            # Axis 0: one program per leading index; axis 1: flattened (n, k) blocks.
            grid = lambda META: (
                out_shape[0],
                triton.cdiv(out_shape[1], META["BLOCK_N"])
                * triton.cdiv(out_shape[2], META["BLOCK_K"]),
            )
            tile_kernel_3d[grid](
                inp, out, *inp_strides, *out_strides, *in0_shape, *out_shape
            )
        elif rank == 4:
            # Axis 0: flattened (m, n); axis 1: flattened (k, l) blocks.
            num_mn = out_shape[0] * out_shape[1]
            grid = lambda META: (
                num_mn,
                triton.cdiv(out_shape[2], META["BLOCK_K"])
                * triton.cdiv(out_shape[3], META["BLOCK_L"]),
            )
            tile_kernel_4d[grid](
                inp, out, *inp_strides, *out_strides, *in0_shape, *out_shape
            )
        else:
            # Only rank 0 and rank 5 reach this branch. Rank 0 is padded to
            # five size-1 / stride-0 slots, which leaves a single element.
            BLOCK_SIZE = 1024
            grid = (min(65535, triton.cdiv(num_tasks, BLOCK_SIZE)),)

            pad = max_kernel_rank - rank
            padded_in_shape = [1] * pad + in0_shape
            padded_out_shape = [1] * pad + out_shape
            padded_in_strides = [0] * pad + inp_strides
            padded_out_strides = [0] * pad + out_strides

            tile_kernel_nd_flat[grid](
                inp,
                out,
                num_tasks,
                *padded_in_shape,
                *padded_out_shape,
                *padded_in_strides,
                *padded_out_strides,
                rank=rank,
                BLOCK_SIZE=BLOCK_SIZE,
            )

    return out