Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/tile.py: 0%


import importlib
import logging
import os
from typing import Callable, List, Mapping

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_utils import IndentedBuffer

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
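
# This module implements the `tile` operator for the kunlunxin backend through
# runtime code generation: a functional wrapper that normalizes the input and
# `dims` shapes and allocates the output, a destination-passing wrapper that
# computes the launch configuration, and a Triton copy kernel specialized for
# the rank of the task space. Generated modules are written to the code cache
# and imported once per rank (see TileFunction below).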


# --------------------------- tile wrapper generation ---------------------------------
def parameter_for_wrapper() -> str:
    """Generate the parameter list for the functional wrapper.

    Example: in0, dims
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("dims")
    return ", ".join(parameters)


def parameter_for_wrapper_out() -> str:
    """Generate the parameter list for the destination-passing wrapper.

    Example: in0, out0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")

    return ", ".join(parameters)


def parameter_ref_for_wrapper() -> str:
    """Generate the argument list used to call the destination-passing wrapper.

    Example: in0, out0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")

    return ", ".join(parameters)


def output_ref_for_wrapper() -> str:
    return "out0"
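
# With the helpers above, the generated functional wrapper has the signature
# `def _wrapper(in0, dims):` and the destination-passing wrapper has the
# signature `def _wrapper_out(in0, out0):` (the actual names are supplied by
# TileFunction below).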


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import math")
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("from triton import language as tl")
    code.newline()
    code.writeline("from flag_gems.runtime import torch_device_fn")
    code.writeline("from flag_gems.utils.shape_utils import volume")
    code.writeline("from flag_gems.utils.libentry import libentry")
    code.writeline("from flag_gems.utils.type_utils import type_promotion")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.newline()
    code.newline()
    return code


def generate_functional_tile_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("in0_rank = in0.dim()")
        code.writeline("dims_rank = len(dims)")
        code.writeline("in0_shape = list(in0.shape)")
        code.writeline("dims_shape = list(dims)")
        code.newline()
        # pad the shorter of in0_shape / dims_shape with leading ones
        code.writeline("if dims_rank < in0_rank:")
        with code.indent():
            code.writeline("diff = in0_rank - dims_rank")
            code.writeline("ones = [1 for _ in range(diff)]")
            code.writeline("dims_shape = ones + dims_shape")
        code.writeline("elif dims_rank > in0_rank:")
        with code.indent():
            code.writeline("diff = dims_rank - in0_rank")
            code.writeline("ones = [1 for _ in range(diff)]")
            code.writeline("in0_shape = ones + in0_shape")
        code.newline()
        code.writeline("is_empty = False")
        code.writeline("out_shape = []")
        code.writeline("for i in range(len(in0_shape)):")
        with code.indent():
            code.writeline(
                "assert dims_shape[i] >= 0, "
                "'the number of repetitions per dimension is out of range "
                "(expected >= 0) but got {}'.format(dims_shape[i])"
            )
            code.writeline("if dims_shape[i] == 0:")
            with code.indent():
                code.writeline("is_empty = True")
            code.writeline("out_shape.append(in0_shape[i] * dims_shape[i])")
        code.newline()
        code.writeline(
            "out0 = torch.empty(out_shape, device=in0.device, dtype=in0.dtype)"
        )

        code.writeline("in0 = in0.reshape(in0_shape)")
        code.writeline("if not is_empty:")
        with code.indent():
            # call the destination-passing function
            output_names: str = output_ref_for_wrapper()
            call_str = (
                f"{output_names} = {destination_passing_func_name}"
                f"({parameter_ref_for_wrapper()})"
            )
            code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
        code.newline()
        code.newline()

    return code
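
# For illustration, the functional wrapper emitted above reads roughly as
# follows (rank-independent; the per-dimension loop is abbreviated):
#
#   def _wrapper(in0, dims):
#       in0_rank = in0.dim()
#       dims_rank = len(dims)
#       in0_shape = list(in0.shape)
#       dims_shape = list(dims)
#       # pad the shorter of in0_shape / dims_shape with leading ones
#       ...
#       out0 = torch.empty(out_shape, device=in0.device, dtype=in0.dtype)
#       in0 = in0.reshape(in0_shape)
#       if not is_empty:
#           out0 = _wrapper_out(in0, out0)
#       return out0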


def generate_destination_passing_tile_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # task space size
        if rank > 0:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = volume(shape)")

        # launch configuration
        if rank > 0:
            code.writeline("num_ctas = 12")
            code.writeline("num_warps = 1")
            code.writeline(
                "tile_size = triton.next_power_of_2(triton.cdiv(num_tasks, num_ctas))"
            )
            code.writeline(
                "tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas)"
            )
        else:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        code.writeline("grid = (num_ctas, 1, 1)")
        code.newline()

        # input strides for each input tensor w.r.t. the task index space
        if rank > 0:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            code.writeline("in0_strides = in0.stride()")
            code.writeline("in0_shape = in0.shape")
            code.writeline("out0_strides = out0.stride()")
            code.newline()

        # launch kernel
        code.writeline("# kernel launch")
        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"in0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # strides for in0")

                    s = ", ".join(f"out0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # strides for out0")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(rank))
                    code.writeline(f"{shape_args}, # task indexing space")
                    in_shape_args: str = ", ".join(f"in0_shape[{i}]" for i in range(rank))
                    code.writeline(
                        f"{in_shape_args}, # input shape, used when the input and output tensors have different shapes"
                    )
                    code.writeline("num_tasks, # num tasks")
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                    code.writeline("tile_size=tile_size,")
                    code.writeline("one_tile_per_cta=tiles_per_cta==1,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

        # return
        code.writeline("return out0")
    code.newline()
    code.newline()
    return code
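
# For rank == 2, for example, the emitted destination-passing wrapper is
# roughly:
#
#   def _wrapper_out(in0, out0):
#       shape = out0.shape
#       num_tasks = volume(shape)
#       num_ctas = 12
#       num_warps = 1
#       tile_size = triton.next_power_of_2(triton.cdiv(num_tasks, num_ctas))
#       tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas)
#       grid = (num_ctas, 1, 1)
#       in0_strides = in0.stride()
#       in0_shape = in0.shape
#       out0_strides = out0.stride()
#       with torch_device_fn.device(in0.device.index):
#           _tile_flaggems_jit_function[grid](
#               in0, out0,
#               in0_strides[0], in0_strides[1],
#               out0_strides[0], out0_strides[1],
#               shape[0], shape[1],
#               in0_shape[0], in0_shape[1],
#               num_tasks,
#               tiles_per_cta=tiles_per_cta,
#               tile_size=tile_size,
#               one_tile_per_cta=tiles_per_cta == 1,
#               num_warps=num_warps,
#           )
#       return out0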


def generate_tile_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # make the inlined function visible in the context
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # signature: input pointers & non-tensor inputs
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")

        # signature: output pointers
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        # signature: strides for each tensor argument
        # only add these arguments when rank > 0
        if rank > 0:
            # strides for inputs
            stride_args = ", ".join(f"in0_stride{j}: tl.constexpr" for j in range(rank))
            code.writeline(f"{stride_args}, # strides for in0")

            # strides for outputs
            stride_args = ", ".join(
                f"out0_stride{j}: tl.constexpr" for j in range(rank)
            )
            code.writeline(f"{stride_args}, # strides for out0")

            # task space, used to reconstruct the multi index
            task_space_args = ", ".join(f"s{i}: tl.constexpr" for i in range(rank))
            code.writeline(f"{task_space_args}, # task_space")

            task_space_args2 = ", ".join(f"in_s{i}: tl.constexpr" for i in range(rank))
            code.writeline(
                f"{task_space_args2}, # input shape, used when the input and output tensors have different shapes"
            )

            # number of tasks, used to compute the mask
            code.writeline("num_tasks: int,")

        # tile size & tiles_per_cta, grid-stride-loop (gsl) style
        if rank > 0:
            code.writeline("tiles_per_cta,")

        code.writeline("tile_size: tl.constexpr,")

        code.writeline("one_tile_per_cta: tl.constexpr,")
    code.writeline("):")

    with code.indent():
        # get pid
        code.writeline("# task id & masking")
        pid_stmt = "pid = tle.program_id(0)"
        code.writeline(pid_stmt)

        code.writeline("num_ctas = tle.num_programs(0)")

        # get tid (a.k.a. task id)
        tid_stmt = "init_tid = pid * tile_size + tl.arange(0, tile_size)"
        code.writeline(tid_stmt)

        # one-tile-per-cta, monolithic kernel style
        code.writeline("if one_tile_per_cta: # monolithic kernel style")
        with code.indent():
            tid_stmt = "tid = init_tid"
            code.writeline(tid_stmt)

            # only apply masking when rank > 0,
            # since a single value (not a block of values) is loaded when the rank is 0
            mask_stmt: str = "mask = tid < num_tasks"
            code.writeline(mask_stmt)
            code.newline()

            # reconstruct the multi index
            code.writeline("# multi index reconstruction")
            for i in reversed(range(rank)):
                if i > 0:
                    code.writeline(f"i{i} = tid % s{i}")
                    code.writeline(f"tid //= s{i}")
                else:
                    code.writeline(f"i{i} = tid")
            code.newline()

            # loads
            code.writeline("# loads")
            ptrs_expr: str = " + ".join(
                f"(i{j} % in_s{j}) * in0_stride{j}" for j in range(rank)
            )
            ptrs_expr = f"in0_ptr + {ptrs_expr}"
            load_stmt: str = f"in0 = tl.load({ptrs_expr}, mask=mask)"
            code.writeline(load_stmt)
            code.newline()

            # compute
            code.writeline("# compute")
            code.writeline("out0 = in0")
            code.newline()

            # stores
            code.writeline("# stores")
            ptrs_expr = " + ".join(f"i{j} * out0_stride{j}" for j in range(rank))
            ptrs_expr = f"out0_ptr + {ptrs_expr}"
            store_stmt: str = f"tl.store({ptrs_expr}, out0, mask=mask)"
            code.writeline(store_stmt)

        # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
        code.writeline("else: # grid-stride-loop style kernel")
        with code.indent():
            code.writeline("for j in range(0, tiles_per_cta):")
            with code.indent():
                tid_stmt = "tid = init_tid + j * tile_size * num_ctas"
                code.writeline(tid_stmt)

                # only apply masking when rank > 0,
                # since a single value (not a block of values) is loaded when the rank is 0
                mask_stmt = "mask = tid < num_tasks"
                code.writeline(mask_stmt)
                code.newline()

                # reconstruct the multi index
                code.writeline("# multi index reconstruction")
                for i in reversed(range(rank)):
                    if i > 0:
                        code.writeline(f"i{i} = tid % s{i}")
                        code.writeline(f"tid //= s{i}")
                    else:
                        code.writeline(f"i{i} = tid")
                code.newline()

                # loads
                code.writeline("# loads")
                ptrs_expr = " + ".join(
                    f"(i{j} % in_s{j}) * in0_stride{j}" for j in range(rank)
                )
                ptrs_expr = f"in0_ptr + {ptrs_expr}"
                load_stmt = f"in0 = tl.load({ptrs_expr}, mask=mask)"
                code.writeline(load_stmt)
                code.newline()

                # compute
                code.writeline("# compute")
                code.writeline("out0 = in0")
                code.newline()

                # stores
                code.writeline("# stores")
                ptrs_expr = " + ".join(
                    f"i{j} * out0_stride{j}" for j in range(rank)
                )
                ptrs_expr = f"out0_ptr + {ptrs_expr}"
                store_stmt = f"tl.store({ptrs_expr}, out0, mask=mask)"
                code.writeline(store_stmt)
    code.newline()
    return code
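
# The generated kernel treats the output as a flat task space of `num_tasks`
# elements, reconstructs the multi index (i0, ..., i{rank-1}) from the linear
# task id, and reads the input at (i0 % in_s0, ..., i{rank-1} % in_s{rank-1}).
# The modulo wrap-around is what implements tiling: for example, tiling an
# input of shape (2, 3) by dims (2, 1) yields an output of shape (4, 3), and
# the output element at index (3, 1) is read from input index
# (3 % 2, 1 % 3) = (1, 1).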


def generate_code(
    rank: int,
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # the only runtime-determined factor is the rank of the task space
    code = generate_imports(code)
    code = generate_functional_tile_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_tile_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    code = generate_tile_kernel(rank, kernel_name, code)
    return code
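
# The generated source for a given rank can be inspected directly, e.g.
# (wrapper and kernel names here are arbitrary, for illustration only):
#
#   buf = generate_code(2, "_wrapper", "_wrapper_out", "_tile_kernel", IndentedBuffer())
#   print(buf.getvalue())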


class TileFunction:
    def __init__(self):
        self.pid = os.getpid()
        # instantiated & cached overloads
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, x, dims):
        # note: kwargs should not be used in JITFunction directly
        ndim = self.arg_key(x, dims)
        key = str(ndim)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                ndim,
                "_wrapper",
                "_wrapper_out",
                "_tile_flaggems_jit_function",
                code,
            )

            file_name = f"tile_rank_{key}_pid_{self.pid}.py"

            with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, "_wrapper")
            self.overloads[key] = overload
        return overload(x, dims)

    def arg_key(self, x, dims):
        max_rank = max(x.ndim, len(dims))
        return max_rank


_tile_func = TileFunction()


def tile(inp: torch.Tensor, dims) -> torch.Tensor:
    logger.debug("GEMS TILE")

    out = _tile_func(inp, dims)
    return out
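
# Minimal usage sketch (shapes are illustrative; `device` is assumed to be the
# kunlunxin device this backend targets):
#
#   x = torch.randn(2, 3).to(device)
#   y = tile(x, (2, 1))   # same semantics as torch.tile(x, (2, 1))
#   # y.shape == torch.Size([4, 3])
#
# When `dims` has more entries than `inp.dim()`, the input is treated as if
# padded with leading singleton dimensions; when it has fewer, `dims` is
# padded with leading ones, matching torch.tile semantics.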