Coverage for src/flag_gems/runtime/backend/_mthreads/ops/tile.py: 0%

174 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as tle 

9from flag_gems.utils.libentry import libentry 

10 

# Module-scoped logger; `tile` emits a debug record on every call.
logger = logging.getLogger(name=__name__)

12 

13 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8),
    ],
    # The best block shape is re-selected whenever the output extents change.
    key=["out_shape0", "out_shape1"],
)
@triton.jit
def tile_kernel_2d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    out_stride0,
    out_stride1,
    inp_shape0,
    inp_shape1,
    out_shape0,
    out_shape1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Copy one BLOCK_M x BLOCK_N tile of a 2-D tiled (repeated) output.

    Output element ``out[m, n]`` is read from
    ``inp[m % inp_shape0, n % inp_shape1]``. Program axis 0 selects a block of
    output rows, program axis 1 a block of output columns. Pointers are
    advanced by ``index * stride``, so strides are in elements; a stride of 0
    is valid and makes every index along that dim read the same element.

    NOTE(review): offsets are built from ``program_id`` and ``tl.arange`` and
    may be 32-bit; for outputs with more than 2**31 elements the
    ``index * stride`` products could overflow. Confirm whether int64
    promotion is needed.
    """
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Only the ragged last block along each axis has lanes outside the output.
    mask_m = offs_m < out_shape0
    mask_n = offs_n < out_shape1
    mask = mask_m[:, None] & mask_n[None, :]

    # Map output indices to input indices using modulo
    inp_offs_m = offs_m % inp_shape0
    inp_offs_n = offs_n % inp_shape1

    # Load from input
    inp_ptrs = (
        inp_ptr + inp_offs_m[:, None] * inp_stride0 + inp_offs_n[None, :] * inp_stride1
    )
    # Masked-off lanes are never stored, so the `other` fill value does not
    # affect the result.
    data = tl.load(inp_ptrs, mask=mask, other=0.0)

    # Store to output
    out_ptrs = out_ptr + offs_m[:, None] * out_stride0 + offs_n[None, :] * out_stride1
    tl.store(out_ptrs, data, mask=mask)

66 

67 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
    ],
    # The best block size is re-selected whenever the output length changes.
    key=["out_shape0"],
)
@triton.jit
def tile_kernel_1d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    out_stride0,
    inp_shape0,
    out_shape0,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one BLOCK_SIZE chunk of a 1-D tiled (repeated) output.

    Output element ``out[i]`` is read from ``inp[i % inp_shape0]``. Each
    program handles ``BLOCK_SIZE`` consecutive output indices. Strides are in
    elements; an input stride of 0 is valid (every index then reads the same
    element).

    NOTE(review): ``offs * stride`` may be computed in 32-bit; confirm whether
    outputs with more than 2**31 elements need int64 offsets.
    """
    pid = tle.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Only the ragged last block has lanes past the end of the output.
    mask = offs < out_shape0

    # Map output indices to input indices
    inp_offs = offs % inp_shape0

    # Load and store
    data = tl.load(inp_ptr + inp_offs * inp_stride0, mask=mask)
    tl.store(out_ptr + offs * out_stride0, data, mask=mask)

98 

99 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4),
        triton.Config({"BLOCK_N": 64, "BLOCK_K": 64}, num_warps=8),
        triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4),
        triton.Config({"BLOCK_N": 32, "BLOCK_K": 128}, num_warps=4),
    ],
    # Block shape is tuned on the two inner output extents only.
    key=["out_shape1", "out_shape2"],
)
@triton.jit
def tile_kernel_3d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    out_stride0,
    out_stride1,
    out_stride2,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    out_shape0,
    out_shape1,
    out_shape2,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Copy one BLOCK_N x BLOCK_K tile of a single dim-0 slice of a 3-D tiled output.

    Program axis 0 is the output index ``m`` along dim 0. Program axis 1 is a
    flattened ``(n_block, k_block)`` pair with ``k_block`` varying fastest.
    Output element ``out[m, n, k]`` is read from
    ``inp[m % inp_shape0, n % inp_shape1, k % inp_shape2]``. Strides are in
    elements; 0 strides are valid.

    NOTE(review): ``index * stride`` offsets may be 32-bit; confirm behavior for
    outputs with more than 2**31 elements.
    """
    pid_m = tle.program_id(0)
    pid_nk = tle.program_id(1)

    # Unflatten axis 1 into (n block, k block).
    num_k_blocks = tl.cdiv(out_shape2, BLOCK_K)
    pid_n = pid_nk // num_k_blocks
    pid_k = pid_nk % num_k_blocks

    m_idx = pid_m
    # Defensive guard for programs beyond the last dim-0 index.
    if m_idx >= out_shape0:
        return

    inp_m = m_idx % inp_shape0

    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    mask_n = offs_n < out_shape1
    mask_k = offs_k < out_shape2
    mask = mask_n[:, None] & mask_k[None, :]

    # Wrap output indices back into the input extents.
    inp_n = offs_n % inp_shape1
    inp_k = offs_k % inp_shape2

    inp_ptrs = (
        inp_ptr
        + inp_m * inp_stride0
        + inp_n[:, None] * inp_stride1
        + inp_k[None, :] * inp_stride2
    )
    # Masked-off lanes are never stored; the fill value is irrelevant.
    data = tl.load(inp_ptrs, mask=mask, other=0.0)

    out_ptrs = (
        out_ptr
        + m_idx * out_stride0
        + offs_n[:, None] * out_stride1
        + offs_k[None, :] * out_stride2
    )
    tl.store(out_ptrs, data, mask=mask)

170 

171 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 64, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 64}, num_warps=4),
        triton.Config({"BLOCK_K": 64, "BLOCK_L": 64}, num_warps=8),
        triton.Config({"BLOCK_K": 128, "BLOCK_L": 32}, num_warps=4),
        triton.Config({"BLOCK_K": 32, "BLOCK_L": 128}, num_warps=4),
    ],
    # Block shape is tuned on the two inner output extents only.
    key=["out_shape2", "out_shape3"],
)
@triton.jit
def tile_kernel_4d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    BLOCK_K: tl.constexpr,
    BLOCK_L: tl.constexpr,
):
    """Copy one BLOCK_K x BLOCK_L tile of a single (m, n) slice of a 4-D tiled output.

    Program axis 0 is the flattened outer index ``m * out_shape1 + n``.
    Program axis 1 is a flattened ``(k_block, l_block)`` pair with ``l_block``
    varying fastest. Output element ``out[m, n, k, l]`` is read from
    ``inp[m % inp_shape0, n % inp_shape1, k % inp_shape2, l % inp_shape3]``.
    Strides are in elements; 0 strides are valid.

    NOTE(review): ``index * stride`` offsets may be 32-bit; confirm behavior for
    outputs with more than 2**31 elements.
    """
    pid_mn = tle.program_id(0)
    pid_kl = tle.program_id(1)

    # Unflatten axis 1 into (k block, l block).
    num_l_blocks = tl.cdiv(out_shape3, BLOCK_L)
    pid_k = pid_kl // num_l_blocks
    pid_l = pid_kl % num_l_blocks

    # Flatten m, n
    m_idx = pid_mn // out_shape1
    n_idx = pid_mn % out_shape1

    # Defensive guard for programs beyond the last (m, n) pair.
    if m_idx >= out_shape0:
        return

    inp_m = m_idx % inp_shape0
    inp_n = n_idx % inp_shape1

    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)

    mask_k = offs_k < out_shape2
    mask_l = offs_l < out_shape3
    mask = mask_k[:, None] & mask_l[None, :]

    # Wrap output indices back into the input extents.
    inp_k = offs_k % inp_shape2
    inp_l = offs_l % inp_shape3

    inp_ptrs = (
        inp_ptr
        + inp_m * inp_stride0
        + inp_n * inp_stride1
        + inp_k[:, None] * inp_stride2
        + inp_l[None, :] * inp_stride3
    )
    # Masked-off lanes are never stored; the fill value is irrelevant.
    data = tl.load(inp_ptrs, mask=mask, other=0.0)

    out_ptrs = (
        out_ptr
        + m_idx * out_stride0
        + n_idx * out_stride1
        + offs_k[:, None] * out_stride2
        + offs_l[None, :] * out_stride3
    )
    tl.store(out_ptrs, data, mask=mask)

252 

253 

@libentry()
@triton.jit
def tile_kernel_nd_flat(
    inp_ptr,
    out_ptr,
    num_tasks,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    inp_shape4,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    out_shape4,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    inp_stride4,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    out_stride4,
    rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Generic tile kernel for up to 5 dims using flat output indices and modulo.

    Each flat output index in ``[0, num_tasks)`` is decomposed row-major over
    ``out_shape0 .. out_shape{rank-1}`` (the highest-numbered real dim varies
    fastest). Each coordinate is wrapped into the input extent with ``%`` and
    both sides are addressed through their element strides.

    Dims at positions ``>= rank`` get coordinate 0, and ``rank > 5`` is handled
    exactly like ``rank == 5``, so any dims past the fifth are ignored.
    NOTE(review): for ``0 < rank < 5`` callers must place the real dims at
    positions ``0 .. rank-1`` (i.e. pad trailing, not leading, dims).
    """
    pid = tle.program_id(0)
    num_ctas = tle.num_programs(0)

    # Grid-stride loop: each program handles BLOCK_SIZE consecutive flat
    # indices, then advances by num_ctas * BLOCK_SIZE.
    for idx in range(pid * BLOCK_SIZE, num_tasks, num_ctas * BLOCK_SIZE):
        offs = idx + tl.arange(0, BLOCK_SIZE)
        mask = offs < num_tasks

        remaining = offs

        # Unroll for up to 5D; `rank` is constexpr so unused branches are
        # resolved at compile time.
        if rank >= 5:
            out_idx4 = remaining % out_shape4
            inp_idx4 = out_idx4 % inp_shape4
            remaining = remaining // out_shape4
        else:
            out_idx4 = tl.zeros_like(offs)
            inp_idx4 = tl.zeros_like(offs)

        if rank >= 4:
            out_idx3 = remaining % out_shape3
            inp_idx3 = out_idx3 % inp_shape3
            remaining = remaining // out_shape3
        else:
            out_idx3 = tl.zeros_like(offs)
            inp_idx3 = tl.zeros_like(offs)

        if rank >= 3:
            out_idx2 = remaining % out_shape2
            inp_idx2 = out_idx2 % inp_shape2
            remaining = remaining // out_shape2
        else:
            out_idx2 = tl.zeros_like(offs)
            inp_idx2 = tl.zeros_like(offs)

        if rank >= 2:
            out_idx1 = remaining % out_shape1
            inp_idx1 = out_idx1 % inp_shape1
            remaining = remaining // out_shape1
        else:
            out_idx1 = tl.zeros_like(offs)
            inp_idx1 = tl.zeros_like(offs)

        # Whatever is left is the outermost coordinate.
        out_idx0 = remaining
        inp_idx0 = out_idx0 % inp_shape0

        inp_offset = (
            inp_idx0 * inp_stride0
            + inp_idx1 * inp_stride1
            + inp_idx2 * inp_stride2
            + inp_idx3 * inp_stride3
            + inp_idx4 * inp_stride4
        )
        out_offset = (
            out_idx0 * out_stride0
            + out_idx1 * out_stride1
            + out_idx2 * out_stride2
            + out_idx3 * out_stride3
            + out_idx4 * out_stride4
        )

        data = tl.load(inp_ptr + inp_offset, mask=mask)
        tl.store(out_ptr + out_offset, data, mask=mask)

346 

347 

def _tile_by_expand_copy(inp, in_shape, reps, out):
    """Fill ``out`` with ``inp`` repeated ``reps[d]`` times along each dim ``d``.

    Fallback for ranks the Triton kernels in this module do not cover.
    ``inp`` must already have shape ``in_shape``, and ``out`` must be a
    contiguous tensor of shape ``[s * r for s, r in zip(in_shape, reps)]``.

    A contiguous ``out`` of shape ``(r0 * s0, r1 * s1, ...)`` can be viewed as
    ``(r0, s0, r1, s1, ...)``. In that view ``out[a0, b0, a1, b1, ...]`` must
    equal ``inp[b0, b1, ...]``. That is exactly ``inp`` reshaped to
    ``(1, s0, 1, s1, ...)`` and broadcast over the ``r`` axes, so a single
    ``copy_`` fills the result.
    """
    split_shape = [n for s, r in zip(in_shape, reps) for n in (r, s)]
    broadcast_src = inp.reshape([n for s in in_shape for n in (1, s)]).expand(
        split_shape
    )
    out.view(split_shape).copy_(broadcast_src)
    return out


def tile(inp: torch.Tensor, dims) -> torch.Tensor:
    """Repeat ``inp`` along each dimension, following ``torch.tile`` semantics.

    Args:
        inp: Input tensor (any strides, including 0 strides from ``expand``).
        dims: Per-dimension repetition counts. If it has fewer entries than
            ``inp.dim()`` it is left-padded with 1s. If it has more, ``inp`` is
            treated as having extra leading size-1 dims.

    Returns:
        A new contiguous tensor whose size along dim ``d`` is
        ``shape[d] * dims[d]`` (after padding). It satisfies
        ``out[i0, ..., ik] == inp[i0 % s0, ..., ik % sk]``.

    Raises:
        AssertionError: if any repetition count is negative.
    """
    logger.debug("GEMS TILE")

    # Left-pad the shorter of (shape, dims) with 1s so both have the same rank.
    ndim = max(inp.dim(), len(dims))
    in0_shape = [1] * (ndim - inp.dim()) + list(inp.shape)
    dims_shape = [1] * (ndim - len(dims)) + list(dims)

    for rep in dims_shape:
        assert (
            rep >= 0
        ), f"the number of repetitions per dimension out of range (expected to >= 0) but got {rep}"
    out_shape = [s * r for s, r in zip(in0_shape, dims_shape)]

    out = torch.empty(out_shape, device=inp.device, dtype=inp.dtype)

    # Covers both a zero repetition count and an empty input. There is nothing
    # to copy, and the kernels wrap indices with `% inp_shape`, which would be
    # `% 0` for an empty input.
    if out.numel() == 0:
        return out

    inp = inp.reshape(in0_shape)
    rank = len(out_shape)

    # The Triton kernels below handle at most 5 dims; anything larger would be
    # silently truncated by tile_kernel_nd_flat, so use an exact strided copy.
    if rank > 5:
        return _tile_by_expand_copy(inp, in0_shape, dims_shape, out)

    # Element strides are passed verbatim. A 0 stride (e.g. from `expand`) is
    # correct as-is: every index along that dim maps to the same element.
    inp_strides = list(inp.stride())
    out_strides = list(out.stride())

    with torch_device_fn.device(inp.device.index):
        if rank == 1:
            # 1D case with autotune
            grid = lambda META: (triton.cdiv(out_shape[0], META["BLOCK_SIZE"]),)
            tile_kernel_1d[grid](
                inp,
                out,
                inp_strides[0],
                out_strides[0],
                in0_shape[0],
                out_shape[0],
            )
        elif rank == 2:
            # 2D case - use 2D blocking with autotune
            grid = lambda META: (
                triton.cdiv(out_shape[0], META["BLOCK_M"]),
                triton.cdiv(out_shape[1], META["BLOCK_N"]),
            )
            tile_kernel_2d[grid](
                inp,
                out,
                inp_strides[0],
                inp_strides[1],
                out_strides[0],
                out_strides[1],
                in0_shape[0],
                in0_shape[1],
                out_shape[0],
                out_shape[1],
            )
        elif rank == 3:
            # 3D case: axis 0 = dim-0 index, axis 1 = flattened (n, k) blocks.
            grid = lambda META: (
                out_shape[0],
                triton.cdiv(out_shape[1], META["BLOCK_N"])
                * triton.cdiv(out_shape[2], META["BLOCK_K"]),
            )
            tile_kernel_3d[grid](
                inp,
                out,
                inp_strides[0],
                inp_strides[1],
                inp_strides[2],
                out_strides[0],
                out_strides[1],
                out_strides[2],
                in0_shape[0],
                in0_shape[1],
                in0_shape[2],
                out_shape[0],
                out_shape[1],
                out_shape[2],
            )
        elif rank == 4:
            # 4D case: axis 0 = flattened (m, n), axis 1 = flattened (k, l) blocks.
            num_mn = out_shape[0] * out_shape[1]
            grid = lambda META: (
                num_mn,
                triton.cdiv(out_shape[2], META["BLOCK_K"])
                * triton.cdiv(out_shape[3], META["BLOCK_L"]),
            )
            tile_kernel_4d[grid](
                inp,
                out,
                inp_strides[0],
                inp_strides[1],
                inp_strides[2],
                inp_strides[3],
                out_strides[0],
                out_strides[1],
                out_strides[2],
                out_strides[3],
                in0_shape[0],
                in0_shape[1],
                in0_shape[2],
                in0_shape[3],
                out_shape[0],
                out_shape[1],
                out_shape[2],
                out_shape[3],
            )
        else:
            # rank 0 or rank 5 - use the generic flat-index kernel.
            num_tasks = out.numel()
            BLOCK_SIZE = 1024
            grid = (min(65535, triton.cdiv(num_tasks, BLOCK_SIZE)),)

            # tile_kernel_nd_flat decomposes indices over dims 0..rank-1, so
            # pad *trailing* dims (size 1, stride 0) up to 5.
            pad = 5 - rank
            in0_shape = in0_shape + [1] * pad
            out_shape = out_shape + [1] * pad
            inp_strides = inp_strides + [0] * pad
            out_strides = out_strides + [0] * pad

            tile_kernel_nd_flat[grid](
                inp,
                out,
                num_tasks,
                in0_shape[0],
                in0_shape[1],
                in0_shape[2],
                in0_shape[3],
                in0_shape[4],
                out_shape[0],
                out_shape[1],
                out_shape[2],
                out_shape[3],
                out_shape[4],
                inp_strides[0],
                inp_strides[1],
                inp_strides[2],
                inp_strides[3],
                inp_strides[4],
                out_strides[0],
                out_strides[1],
                out_strides[2],
                out_strides[3],
                out_strides[4],
                rank=rank,
                BLOCK_SIZE=BLOCK_SIZE,
            )

    return out