Coverage for src/flag_gems/ops/repeat.py: 99%

251 statements  

coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

import importlib
import logging
import os
from typing import Callable, List, Mapping

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic

logger = logging.getLogger(__name__)


# --------------------------- repeat wrapper generation -----------------------------------
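# Overview: generate_code() below assembles, for a given task-space rank, an import
# prologue, a functional wrapper ("_wrapper(in0, sizes)") that computes the output
# shape and allocates out0, a destination-passing wrapper ("_wrapper_out(in0, out0)")
# that sizes the grid and launches the Triton kernel, and the kernel itself.
# RepeatFunction writes one such module per rank into the code cache and imports it.
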

def parameter_for_wrapper() -> str:
    """Generate the parameter list for the functional wrapper.
    Example: in0, sizes
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("sizes")
    return ", ".join(parameters)


def parameter_for_wrapper_out() -> str:
    """Generate the parameter list for the destination-passing wrapper.
    Example: in0, out0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")

    return ", ".join(parameters)


def parameter_ref_for_wrapper() -> str:
    """Generate the argument list used to call the destination-passing wrapper.
    Example: in0, out0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")

    return ", ".join(parameters)


def output_ref_for_wrapper() -> str:
    return "out0"

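# For reference, the helpers above expand to signatures and call sites of this
# shape in the generated module (illustrative; the concrete names are supplied by
# RepeatFunction below):
#
#     def _wrapper(in0, sizes): ...        # parameter_for_wrapper()
#     def _wrapper_out(in0, out0): ...     # parameter_for_wrapper_out()
#     out0 = _wrapper_out(in0, out0)       # output_ref_for_wrapper() + parameter_ref_for_wrapper()
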

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import math")
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("from triton import language as tl")
    code.newline()
    code.writeline("from flag_gems.runtime import torch_device_fn")
    code.writeline("from flag_gems.utils.shape_utils import volume")
    code.writeline("from flag_gems.utils.libentry import libentry")
    code.writeline("from flag_gems.utils.type_utils import type_promotion")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.newline()
    code.newline()
    return code


def generate_functional_repeat_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("in0_rank = in0.dim()")
        code.writeline("sizes_rank = len(sizes)")
        code.writeline("in0_shape = list(in0.shape)")
        code.writeline("sizes_shape = list(sizes)")
        code.newline()

        code.writeline(
            "assert(sizes_rank >= in0_rank), "
            "'Number of dimensions of repeat dims cannot be smaller "
            "than number of dimensions of tensor'"
        )
        code.writeline("if (sizes_rank > in0_rank): ")
        with code.indent():
            code.writeline("diff = sizes_rank - in0_rank")
            code.writeline("ones = [1 for _ in range(diff)]")
            code.writeline("in0_shape = ones + in0_shape")
        code.newline()
        code.writeline("is_empty = False")
        code.writeline("out_shape = []")
        code.writeline("for i in range(len(in0_shape)): ")
        with code.indent():
            code.writeline(
                "assert(sizes_shape[i] >= 0), "
                "'the number of repetitions per dimension is out of range "
                "(expected to be >= 0) but got {}'.format(sizes_shape[i])"
            )
            code.writeline("if in0_shape[i] * sizes_shape[i] == 0: ")
            with code.indent():
                code.writeline("is_empty = True")
            code.writeline("out_shape.append(in0_shape[i] * sizes_shape[i])")
        code.newline()
        code.writeline(
            "out0 = torch.empty(out_shape, device=in0.device, dtype=in0.dtype)"
        )

        code.writeline("in0 = in0.reshape(in0_shape)")
        code.writeline("if not is_empty: ")
        with code.indent():
            # call destination_passing_func
            output_names: str = output_ref_for_wrapper()
            call_str = (
                f"{output_names} = {destination_passing_func_name}"
                f"({parameter_ref_for_wrapper()})"
            )
            code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
        code.newline()
        code.newline()

    return code

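# Worked example of the shape logic emitted above (illustrative): for
# in0.shape == (2, 3) and sizes == (2, 1, 4), sizes_rank (3) > in0_rank (2), so
# in0_shape is left-padded with ones to [1, 2, 3]; the element-wise products give
# out_shape == [2, 2, 12], matching torch.Tensor.repeat semantics. A zero in any
# product marks the result as empty and the kernel launch is skipped.
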

def generate_destination_passing_repeat_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # docstring
        if rank > 0:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = volume(shape)")

        if rank > 0:
            code.writeline("tile_size = min(512, triton.next_power_of_2(num_tasks))")
            code.writeline("num_warps = 4")
            code.writeline("num_ctas = min(65535, triton.cdiv(num_tasks, tile_size))")
            code.writeline(
                "tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas)"
            )
        else:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        code.writeline("grid = (num_ctas, 1, 1)")
        code.newline()

        # input strides for each input tensor w.r.t. the task index space
        if rank > 0:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            code.writeline("in0_strides = in0.stride()")
            code.writeline("in0_shape = in0.shape")
            code.writeline("out0_strides = out0.stride()")
            code.newline()

        # grid
        code.writeline("# kernel launch")

        # launch kernel
        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"in0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for in0")

                    s = ", ".join(f"out0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out0")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(rank))
                    code.writeline(f"{shape_args}, # task indexing space")
                    in_shape_args: str = ", ".join(f"in0_shape[{i}]" for i in range(rank))
                    code.writeline(
                        f"{in_shape_args}, # task indexing space used when input and output tensors have different shapes"
                    )
                code.writeline("num_tasks, # num tasks")
                code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                code.writeline("tile_size=tile_size,")
                code.writeline("one_tile_per_cta=tiles_per_cta==1,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

        # return
        code.writeline("return out0")
        code.newline()
        code.newline()
    return code

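# Launch-configuration example (illustrative arithmetic): for num_tasks == 10_000
# the generated wrapper picks tile_size == 512, num_ctas == cdiv(10_000, 512) == 20
# and tiles_per_cta == 1, so the kernel takes its one_tile_per_cta branch. For
# num_tasks == 100_000_000 the CTA count saturates at 65535 and tiles_per_cta == 3,
# which selects the grid-stride-loop branch instead.
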

def generate_repeat_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # make the inlined function visible in the context
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # signature: input ptrs & non-tensor inputs
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")

        # signature: output ptrs
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        # signature: strides for each tensor argument
        # only add these arguments when rank > 0
        if rank > 0:
            # strides for inputs
            stride_args = ", ".join(f"in0_stride{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # strides for in0")

            # strides for outputs
            stride_args = ", ".join(f"out0_stride{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # strides for out0")

            # task space, used to reconstruct multi index
            task_space_args = ", ".join(f"s{i}: int" for i in range(rank))
            code.writeline(f"{task_space_args}, # task_space")

            task_space_args2 = ", ".join(f"in_s{i}: int" for i in range(rank))
            code.writeline(
                f"{task_space_args2}, # task_space2, used when input and output tensors have different shapes"
            )

        # number of tasks, used to compute mask
        code.writeline("num_tasks: int,")

        # tile size & tiles_per_cta, gsl style
        if rank > 0:
            code.writeline("tiles_per_cta,")

        code.writeline("tile_size: tl.constexpr,")

        code.writeline("one_tile_per_cta: tl.constexpr,")
    code.writeline("):")

    with code.indent():
        # get pid
        code.writeline("# task id & masking")
        pid_stmt = "pid = tle.program_id(0)"
        code.writeline(pid_stmt)

        code.writeline("num_ctas = tle.num_programs(0)")

        # get tid (a.k.a. task id)
        tid_stmt = "init_tid = pid * tile_size + tl.arange(0, tile_size)"
        code.writeline(tid_stmt)

        # one-tile-per-cta, monolithic kernel style
        code.writeline("if one_tile_per_cta: # monolithic kernel style")
        with code.indent():
            tid_stmt = "tid = init_tid"
            code.writeline(tid_stmt)

            # only apply masking when rank > 0
            # since we only load a value instead of a block of values when the rank is 0
            mask_stmt: str = "mask = tid < num_tasks"
            code.writeline(mask_stmt)
            code.newline()

            # reconstruct multi index
            code.writeline("# multi index reconstruction")
            for i in reversed(range(rank)):
                if i > 0:
                    code.writeline(f"i{i} = tid % s{i}")
                    code.writeline(f"tid //= s{i}")
                else:
                    code.writeline(f"i{i} = tid")
            code.newline()

            # loads
            code.writeline("# loads")
            ptrs_expr: str = " + ".join(
                f"(i{j} % in_s{j}) * in0_stride{j}" for j in range(rank)
            )
            ptrs_expr: str = f"in0_ptr + {ptrs_expr}"
            load_stmt: str = f"in0 = tl.load({ptrs_expr}, mask=mask)"
            code.writeline(load_stmt)
            code.newline()

            # compute
            code.writeline("# compute")
            code.writeline("out0 = in0")
            code.newline()

            # stores
            code.writeline("# stores")
            ptrs_expr: str = " + ".join(f"i{j} * out0_stride{j}" for j in range(rank))
            ptrs_expr: str = f"out0_ptr + {ptrs_expr}"
            store_stmt: str = f"tl.store({ptrs_expr}, out0, mask=mask)"
            code.writeline(store_stmt)

        # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
        code.writeline("else: # grid-stride-loop style kernel")
        with code.indent():
            code.writeline("for j in range(0, tiles_per_cta):")
            with code.indent():
                tid_stmt = "tid = init_tid + j * tile_size * num_ctas"
                code.writeline(tid_stmt)

                # only apply masking when rank > 0
                # since we only load a value instead of a block of values when the rank is 0
                mask_stmt: str = "mask = tid < num_tasks"
                code.writeline(mask_stmt)
                code.newline()

                # reconstruct multi index
                code.writeline("# multi index reconstruction")
                for i in reversed(range(rank)):
                    if i > 0:
                        code.writeline(f"i{i} = tid % s{i}")
                        code.writeline(f"tid //= s{i}")
                    else:
                        code.writeline(f"i{i} = tid")
                code.newline()

                # loads
                code.writeline("# loads")
                ptrs_expr: str = " + ".join(
                    f"(i{j} % in_s{j}) * in0_stride{j}" for j in range(rank)
                )
                ptrs_expr: str = f"in0_ptr + {ptrs_expr}"
                load_stmt: str = f"in0 = tl.load({ptrs_expr}, mask=mask)"
                code.writeline(load_stmt)
                code.newline()

                # compute
                code.writeline("# compute")
                code.writeline("out0 = in0")
                code.newline()

                # stores
                code.writeline("# stores")
                ptrs_expr: str = " + ".join(
                    f"i{j} * out0_stride{j}" for j in range(rank)
                )
                ptrs_expr: str = f"out0_ptr + {ptrs_expr}"
                store_stmt: str = f"tl.store({ptrs_expr}, out0, mask=mask)"
                code.writeline(store_stmt)
    code.newline()
    return code

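# Generated kernel, abridged for rank == 2 (illustrative): the linear task id is
# decomposed over the output shape (s0, s1), and loads wrap each index modulo the
# input shape (in_s0, in_s1), which is what realizes the repetition:
#
#     i1 = tid % s1; tid //= s1; i0 = tid
#     in0 = tl.load(in0_ptr + (i0 % in_s0) * in0_stride0 + (i1 % in_s1) * in0_stride1, mask=mask)
#     tl.store(out0_ptr + i0 * out0_stride0 + i1 * out0_stride1, in0, mask=mask)
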

def generate_code(
    rank: int,
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # the only runtime determined factor is the rank of the task space
    code = generate_imports(code)
    code = generate_functional_repeat_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_repeat_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    code = generate_repeat_kernel(rank, kernel_name, code)
    return code

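# A minimal sketch for inspecting the generated source of a given rank (not part
# of the public API; the names mirror the ones RepeatFunction passes in):
#
#     buf = IndentedBuffer()
#     buf = generate_code(2, "_wrapper", "_wrapper_out", "_repeat_flaggems_jit_function", buf)
#     print(buf.getvalue())
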

class RepeatFunction:
    def __init__(self):
        self.pid = os.getpid()
        # instantiated & cached overloads
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, x, sizes):
        # note: kwargs should not be used in JITFunction directly
        ndim = self.arg_key(x, sizes)
        key = str(ndim)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                ndim,
                "_wrapper",
                "_wrapper_out",
                "_repeat_flaggems_jit_function",
                code,
            )

            file_name = f"repeat_rank_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, "_wrapper")
            self.overloads[key] = overload
        return overload(x, sizes)

    def arg_key(self, x, sizes):
        max_rank = max(x.ndim, len(sizes))
        return max_rank

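# Cache-key example (illustrative): for x.ndim == 2 and sizes == (2, 1, 4),
# arg_key returns 3, so the module repeat_rank_3.py is generated once and then
# reused for every call whose padded rank is 3, independent of dtype and shape.
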

_repeat_func = RepeatFunction()


def repeat(inp: torch.Tensor, sizes) -> torch.Tensor:
    logger.debug("GEMS REPEAT")

    out = _repeat_func(inp, sizes)
    return out
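
# Usage sketch (assumes a device supported by flag_gems, e.g. a CUDA tensor):
#
#     x = torch.randn(2, 3, device="cuda")
#     y = repeat(x, (2, 1, 4))   # same result as x.repeat(2, 1, 4)
#     assert y.shape == (2, 2, 12)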