Coverage for src/flag_gems/runtime/backend/_cambricon/ops/tile.py: 0%

306 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import importlib 

2import logging 

3import os 

4from typing import Callable, List, Mapping 

5 

6import torch 

7import triton 

8import triton.language as tl 

9 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils.code_cache import code_cache_dir 

12from flag_gems.utils.code_utils import IndentedBuffer 

13 

14from ..utils import TOTAL_CORE_NUM 

15 

# Module logger nested under the package-wide "flag_gems" logger.
# lstrip(".") presumably guards against a leading dot in __name__ — TODO confirm.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

17 

18 

# --------------------------- tile wrapper generation -----------------------------------

def parameter_for_wrapper() -> str:
    """Return the parameter list of the generated functional wrapper.

    The functional entry point takes the input tensor and the repeat
    counts, so this is always ``"in0, dims"``.
    """
    names: List[str] = ["in0", "dims"]
    return ", ".join(names)

29 

30 

def parameter_for_wrapper_out() -> str:
    """Return the parameter list of the generated destination-passing wrapper.

    The destination-passing wrapper takes the input tensor and the
    pre-allocated output, so this is always ``"in0, out0"``.
    """
    return ", ".join(("in0", "out0"))

41 

42 

def parameter_ref_for_wrapper() -> str:
    """Return the argument list used when calling the destination-passing wrapper.

    Always ``"in0, out0"`` — matches :func:`parameter_for_wrapper_out`.
    """
    return ", ".join(("in0", "out0"))

53 

54 

def output_ref_for_wrapper() -> str:
    """Return the name bound to the output tensor in generated code."""
    return "out0"

57 

58 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Emit the fixed import header shared by every generated tile module.

    Writes third-party imports first, then the flag_gems runtime imports
    (including the ``MAX_GRID_SIZE_X`` binding pulled off the vendor
    module), followed by two blank lines. Returns ``code`` for chaining.
    """
    third_party = (
        "import math",
        "import torch",
        "import triton",
        "from triton import language as tl",
    )
    for line in third_party:
        code.writeline(line)
    code.newline()

    flag_gems_lines = (
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import libentry, libtuner",
        "from flag_gems.runtime.backend import vendor_module",
        "MAX_GRID_SIZE_X = vendor_module.MAX_GRID_SIZE_X",
        "from flag_gems.utils.type_utils import type_promotion",
        "from flag_gems.utils import triton_lang_extension as tle",
    )
    for line in flag_gems_lines:
        code.writeline(line)
    code.newline()
    code.newline()
    return code

75 

76 

def generate_functional_tile_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the functional (output-allocating) wrapper ``wrapper_name``.

    The generated function takes ``(in0, dims)``, left-pads the shorter of
    the two ranks with 1s, validates the repeat counts, derives the output
    shape, allocates ``out0`` and forwards to the destination-passing
    wrapper unless some repeat count is zero (empty output).

    Args:
        wrapper_name: name of the generated functional entry point.
        destination_passing_func_name: name of the generated wrapper that
            launches the kernel.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # align input rank and repeat-spec rank by left-padding with 1s,
        # mirroring torch.tile semantics
        code.writeline("in0_rank = in0.dim()")
        code.writeline("dims_rank = len(dims)")
        code.writeline("in0_shape = list(in0.shape)")
        code.writeline("dims_shape = list(dims)")
        code.newline()
        code.writeline("if (dims_rank < in0_rank): ")
        with code.indent():
            code.writeline("diff = in0_rank - dims_rank")
            code.writeline("ones = [1 for _ in range(diff)]")
            code.writeline("dims_shape = ones + dims_shape")
        code.writeline("elif (dims_rank > in0_rank): ")
        with code.indent():
            code.writeline("diff = dims_rank - in0_rank")
            code.writeline("ones = [1 for _ in range(diff)]")
            code.writeline("in0_shape = ones + in0_shape")
        code.newline()
        # derive the output shape; a zero repeat yields an empty output and
        # skips the kernel launch entirely
        code.writeline("is_empty = False")
        code.writeline("out_shape = []")
        code.writeline("for i in range(len(in0_shape)): ")
        with code.indent():
            # NOTE(review): the backslash continuation below is *inside* the
            # string literal, so the generated assert stays one logical line
            code.writeline(
                "assert(dims_shape[i] >= 0), 'the number of repetitions per dimension out of range (expected to >= 0) \
but got {}'.format(dims_shape[i])"
            )
            code.writeline("if dims_shape[i] == 0: ")
            with code.indent():
                code.writeline("is_empty = True")
            code.writeline("out_shape.append(in0_shape[i] * dims_shape[i])")
        code.newline()
        code.writeline(
            "out0 = torch.empty(out_shape, device=in0.device, dtype=in0.dtype)"
        )

        # reshape so the kernel sees the rank-aligned view of the input
        code.writeline("in0 = in0.reshape(in0_shape)")
        code.writeline("if not is_empty: ")
        with code.indent():
            # call destination_passing_func
            output_names: str = output_ref_for_wrapper()
            call_str = (
                f"{output_names} = {destination_passing_func_name}"
                f"({parameter_ref_for_wrapper()})"
            )
            code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
        code.newline()
        code.newline()

    return code

138 

139 

def generate_destination_passing_tile_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the destination-passing wrapper ``wrapper_name``.

    The generated function takes ``(in0, out0)``, derives the launch grid
    from the output volume and launches ``kernel_name`` on ``in0``'s device.

    Args:
        rank: rank of the task (output) space.
        wrapper_name: name of the generated wrapper function.
        kernel_name: name of the generated Triton kernel it launches.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.

    NOTE(review): for rank == 0 the generated wrapper passes ``num_tasks``,
    ``tiles_per_cta`` and ``tile_size`` without ever defining them — the
    rank-0 path looks unused/broken; confirm before relying on it.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # docstring
        if rank > 0:
            # the task space is the output volume
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = volume(shape)")

        if rank > 0:
            # cap the tile at 512 lanes and the grid at MAX_GRID_SIZE_X;
            # tiles_per_cta covers the remainder via a grid-stride loop
            code.writeline("tile_size = min(512, triton.next_power_of_2(num_tasks))")
            code.writeline("num_warps = 1")
            code.writeline(
                "num_ctas = min(MAX_GRID_SIZE_X//num_warps, triton.cdiv(num_tasks, tile_size))"
            )
            code.writeline(
                "tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas)"
            )
        else:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        code.writeline("grid = (num_ctas, 1, 1)")
        code.newline()

        # input strides for each input tensor w.r.t. the task index space
        if rank > 0:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            code.writeline("in0_strides = in0.stride()")
            code.writeline("in0_shape = in0.shape")
            code.writeline("out0_strides = out0.stride()")
            code.newline()

        # grid
        code.writeline("# kernel launch")

        # launch kernel
        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"in0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for in0")

                    s = ", ".join(f"out0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out0")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(rank))
                    code.writeline(f"{shape_args}, # task indexing space")
                    in_shape_args: str = ", ".join(f"in0_shape[{i}]" for i in range(rank))
                    code.writeline(
                        f"{in_shape_args}, # task indexing space used when input and ouput tensor has different shape"
                    )
                code.writeline("num_tasks, # num tasks")
                code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                code.writeline("tile_size=tile_size,")
                code.writeline("one_tile_per_cta=tiles_per_cta==1,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

        # return
        code.writeline("return out0")
        code.newline()
        code.newline()
    return code

218 

219 

def _emit_tile_body(code: IndentedBuffer, rank: int, tid_stmt: str) -> None:
    """Emit the shared per-tile body: index reconstruction, load, copy, store.

    Written at the buffer's *current* indent level so the same text serves
    both the monolithic branch and the grid-stride-loop branch.
    """
    code.writeline(tid_stmt)
    # mask off the tail of the last (partial) tile
    code.writeline("mask = tid < num_tasks")
    code.newline()

    # reconstruct the multi index from the flat task id (row-major order)
    code.writeline("# multi index recontruction")
    for i in reversed(range(rank)):
        if i > 0:
            code.writeline(f"i{i} = tid % s{i}")
            code.writeline(f"tid //= s{i}")
        else:
            code.writeline(f"i{i} = tid")
    code.newline()

    # loads: each output index is wrapped back into the input extent
    # (tile semantics). Fixed to spell in0_stride{j} directly — the
    # original relied on the leaked loop variable `i` (always 0 after the
    # loop above) to produce the same name by accident.
    code.writeline("# loads")
    ptrs_expr = " + ".join(
        f"(i{j} % in_s{j}) * in0_stride{j}" for j in range(rank)
    )
    code.writeline(f"in0 = tl.load(in0_ptr + {ptrs_expr}, mask=mask)")
    code.newline()

    # compute: tile is a pure copy
    code.writeline("# compute")
    code.writeline("out0 = in0")
    code.newline()

    # stores
    code.writeline("# stores")
    ptrs_expr = " + ".join(f"i{j} * out0_stride{j}" for j in range(rank))
    code.writeline(f"tl.store(out0_ptr + {ptrs_expr}, out0, mask=mask)")


def generate_tile_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the rank-specialized Triton kernel used by the tile wrappers.

    The kernel maps a flat task id onto a ``rank``-dimensional output index,
    loads from the (smaller) input with a per-dimension modulo, and stores
    the value. Two styles are generated: a monolithic one-tile-per-CTA path
    and a grid-stride-loop path.

    Args:
        rank: rank of the output/task space. NOTE(review): for rank == 0
            the emitted pointer expression is empty ("in0_ptr + "), which
            would not parse — the rank-0 path appears unused; confirm.
        kernel_name: name given to the generated @triton.jit function.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    # make the inlined function visible in the context
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # signature: input & output pointers
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        # strides & task-space extents only exist when rank > 0
        if rank > 0:
            stride_args = ", ".join(f"in0_stride{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # strides for in0")

            stride_args = ", ".join(f"out0_stride{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # strides for out0")

            # output task space, used to reconstruct the multi index
            task_space_args = ", ".join(f"s{i}: int" for i in range(rank))
            code.writeline(f"{task_space_args}, # task_space")

            # input extents, used to wrap indices back into the input
            task_space_args2 = ", ".join(f"in_s{i}: int" for i in range(rank))
            code.writeline(
                f"{task_space_args2}, # task_space2 used when input and output tensor has different shape"
            )

        # number of tasks, used to compute mask
        code.writeline("num_tasks: int,")

        # tile size & tiles_per_cta, gsl style
        if rank > 0:
            code.writeline("tiles_per_cta,")

        code.writeline("tile_size: tl.constexpr,")
        code.writeline("one_tile_per_cta: tl.constexpr,")
    code.writeline("):")

    with code.indent():
        # task id & masking
        code.writeline("# task id & masking")
        code.writeline("pid = tl.program_id(0)")
        code.writeline("num_ctas = tl.num_programs(0)")
        code.writeline("init_tid = pid * tile_size + tl.arange(0, tile_size)")

        # one-tile-per-cta, monolithic kernel style
        code.writeline("if one_tile_per_cta: # monolitic kernel style")
        with code.indent():
            _emit_tile_body(code, rank, "tid = init_tid")

        # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
        code.writeline("else: # grid-stride-loop style kernel")
        with code.indent():
            code.writeline("for j in range(0, tiles_per_cta):")
            with code.indent():
                _emit_tile_body(
                    code, rank, "tid = init_tid + j * tile_size * num_ctas"
                )
    code.newline()
    return code

378 

379 

def generate_code(
    rank: int,
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble one complete generated tile module into ``code``.

    Chains the four emitters in order: the import header, the functional
    wrapper, the destination-passing wrapper, and the Triton kernel. The
    rank of the task space is the only runtime-determined factor.
    """
    emitters = (
        generate_imports,
        lambda buf: generate_functional_tile_wrapper(
            wrapper_name, destination_passing_func_name, buf
        ),
        lambda buf: generate_destination_passing_tile_wrapper(
            rank, destination_passing_func_name, kernel_name, buf
        ),
        lambda buf: generate_tile_kernel(rank, kernel_name, buf),
    )
    for emit in emitters:
        code = emit(code)
    return code

397 

398 

class TileFunction:
    """JIT-style dispatcher for ``tile``.

    Generates, caches on disk and imports a rank-specialized module on
    first use, then dispatches subsequent calls of the same rank to the
    cached overload.
    """

    def __init__(self):
        # The pid is baked into generated file/module names so concurrent
        # processes do not clobber each other's cache entries.
        self.pid = os.getpid()
        # instantiated & cached overloads, keyed by task-space rank (as str)
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, x, dims):
        """Tile ``x`` by ``dims`` via the rank-specialized generated wrapper."""
        # note: kwargs should not be used in JITFunction directly
        ndim = self.arg_key(x, dims)
        key = str(ndim)
        overload = self.overloads.get(key)
        if overload is None:
            # generate the module source for this rank
            code = IndentedBuffer()
            code = generate_code(
                ndim,
                "_wrapper",
                "_wrapper_out",
                "_tile_flaggems_jit_function",
                code,
            )

            file_name = f"tile_rank_{key}_pid_{self.pid}.py"
            file_path = code_cache_dir() / file_name
            with open(file_path, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # import the generated file; requires `import importlib.util`
            # (plain `import importlib` does not reliably expose `util`)
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )
            m = importlib.util.module_from_spec(spec)
            # deliberately not registered in sys.modules
            spec.loader.exec_module(m)
            overload = m._wrapper
            self.overloads[key] = overload
        return overload(x, dims)

    def arg_key(self, x, dims):
        """Cache key: the rank of the (broadcast) task space."""
        return max(x.ndim, len(dims))

444 

445 

# Process-wide dispatcher instance; generated overloads are cached per rank.
_tile_func = TileFunction()

447 

448 

# Fast-path kernel for tile() on a contiguous 2D input:
#   out[rn * N + b, rc * C + c] = inp[b, c]
# for rn in [0, repeat_N), rc in [0, repeat_C). Each program strides over
# the batch (row) dimension; BLOCK_C is autotuned over the row width C.
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_C": 2**n}, num_stages=3) for n in range(10, 17, 2)
    ],
    key=["C"],
    strategy=["log"],
)
@triton.jit
def tile_2d_kernel(
    inp_ptr,
    out_ptr,
    N,
    C: tl.constexpr,
    repeat_N: tl.constexpr,
    repeat_C: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    job_id = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    for batch_idx in range(job_id, N, num_jobs):
        if C <= BLOCK_C:
            # Whole row fits in one block: load it once, widen to repeat_C
            # copies, then store the widened row once per row repeat.
            # NOTE(review): tl.arange(0, C) with arbitrary C relies on the
            # Cambricon Triton backend accepting non-power-of-two ranges —
            # confirm.
            offset_c = tl.arange(0, C)
            inp_ptrs = inp_ptr + batch_idx * C + offset_c
            inp = tl.load(inp_ptrs).reshape(1, C)
            repeat_inp = inp.broadcast_to(repeat_C, C).reshape(repeat_C * C)
            out_offset_c = tl.arange(0, repeat_C * C)
            for n_idx in tl.static_range(0, repeat_N):
                # row (n_idx * N + batch_idx) of the output; the output row
                # stride is repeat_C * C
                out_ptrs = (
                    out_ptr
                    + N * n_idx * repeat_C * C
                    + batch_idx * repeat_C * C
                    + out_offset_c
                )
                tl.store(out_ptrs, repeat_inp)
        else:
            # Row wider than BLOCK_C: process it chunk by chunk and write
            # each chunk into every (row, column) repetition.
            for off in range(0, C, BLOCK_C):
                offset_c = off + tl.arange(0, BLOCK_C)
                inp_ptrs = inp_ptr + batch_idx * C + offset_c
                inp_mask = offset_c < C
                inp = tl.load(inp_ptrs, mask=inp_mask, other=0)
                for c_idx in tl.static_range(0, repeat_C):
                    for n_idx in tl.static_range(0, repeat_N):
                        out_ptrs = (
                            out_ptr
                            + N * n_idx * repeat_C * C
                            + batch_idx * repeat_C * C
                            + c_idx * C
                            + offset_c
                        )
                        tl.store(out_ptrs, inp, mask=inp_mask)

500 

501 

def tile(inp: torch.Tensor, dims) -> torch.Tensor:
    """Repeat ``inp`` along each dimension according to ``dims`` (torch.tile).

    A 2D input with a 2D repeat spec goes through the hand-written
    ``tile_2d_kernel`` fast path; everything else is dispatched to the
    rank-specialized generated wrapper.
    """
    logger.debug("GEMS_CAMBRICON TILE")

    if inp.dim() == 2 and len(dims) == 2:
        N, C = inp.shape[0], inp.shape[1]
        repeat_N, repeat_C = dims[0], dims[1]

        out = torch.empty(
            [N * repeat_N, C * repeat_C], device=inp.device, dtype=inp.dtype
        )
        # a zero repeat count produces an empty tensor — nothing to launch
        if repeat_N == 0 or repeat_C == 0:
            return out

        tile_2d_kernel[(TOTAL_CORE_NUM,)](
            inp.contiguous(), out, N, C, repeat_N, repeat_C
        )
        return out

    # general path: rank-specialized generated kernel
    return _tile_func(inp, dims)