Coverage for src/flag_gems/utils/pointwise_dynamic_cpp_compat.py: 0%

862 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import importlib 

2import os 

3from dataclasses import dataclass 

4from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple 

5 

6import torch 

7import triton 

8from triton.runtime.jit import JITFunction 

9 

10from flag_gems.utils.code_cache import code_cache_dir 

11from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

12from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config 

13from flag_gems.utils.device_info import get_device_capability 

14from flag_gems.utils.shape_utils import ( 

15 MemOverlap, 

16 all_c_contiguous, 

17 all_the_same_shape, 

18 all_the_same_stride, 

19 broadcast_shapes, 

20 broadcasted_stride, 

21 check_tensor_attributes, 

22 has_internal_overlapping, 

23) 

24from flag_gems.utils.tensor_wrapper import StridedBuffer 

25from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion 

26 

27 

28# ------------------ Operation Description --------------------------- 

29def _type_name(type) -> str: 

30 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object" 

31 if type in (bool, int, float, str): 

32 return type.__name__ 

33 if isinstance(type, torch.dtype): 

34 return str(type) 

35 return str(type) 

36 

37 

38def _check_typed_list(container, type): 

39 for item in container: 

40 assert isinstance(item, type) 

41 

42 

43def _check_sized_list(container, size): 

44 assert len(container) == size 

45 

46 

47def _tuple_content(strings: Sequence[str]) -> str: 

48 # comma separated list 

49 if len(strings) == 0: 

50 return "" 

51 if len(strings) == 1: 

52 return f"{strings[0]}," 

53 else: 

54 return ", ".join(strings) 

55 

56 

57def _cs(strings: Iterable[str]) -> str: 

58 return ", ".join(strings) 

59 

60 

def _broadcast_vec(i, ndim):
    """Render an indexing expression that keeps axis ``i`` and inserts ``None``
    (newaxis) on every other of the ``ndim`` axes,
    e.g. ``_broadcast_vec(1, 3) -> "[None, :, None]"``."""
    axes = []
    for axis in range(ndim):
        axes.append(":" if axis == i else "None")
    return f"[{_cs(axes)}]"

64 

65 

class FunctionSchema:
    """Describes the calling convention of a pointwise scalar function.

    Records, for every input, whether it is a tensor and (optionally) its
    dtype, plus one type-promotion rule per output. At least one of
    ``num_inputs`` / ``is_tensor`` / ``dtypes`` must be provided; the missing
    pieces are inferred from whichever are given.
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema; raises ValueError when under-specified.

        Args:
            num_inputs: total number of inputs (tensor + non-tensor).
            is_tensor: per-input flag; defaults to all-tensor when only
                ``num_inputs`` is given.
            dtypes: per-input dtype for non-tensor inputs (None means
                "tensor" / "untyped scalar").
            num_outputs: number of outputs; defaults to
                ``len(promotion_methods)``.
            promotion_methods: one rule per output, each of the form
                ``(*arg_indices, method_name)``; required.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        else:
            self._promotion_methods = self.canonicalize_promotion_methods(
                promotion_methods
            )

        # Infer (num_inputs, is_tensor, dtypes) from whichever were supplied.
        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            # is_tensor is necessarily None on this branch (the previous elif
            # handled the non-None case), so the original inner
            # `if is_tensor is not None` check was dead code and is removed.
            # A None dtype marks a tensor input; a concrete dtype marks a scalar.
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace the trailing method-name of each rule with the matching
        ELEMENTWISE_TYPE_PROMOTION_KIND member."""
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        # num of arguments, outputs not included
        return self._num_inputs

    def num_outputs(self):
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Return the canonicalized promotion rule for output ``i``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        # all outputs are tensors, so this equals num_outputs
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a human-readable schema signature.

        With ``outputs_in_arg=True`` the outputs are appended to the input
        list as in-place aliased parameters (``StridedBuffer(a0!)``, ...).
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            else:
                if dtype is None:
                    input_types.append("scalar")
                else:
                    input_types.append(_type_name(dtype))

        output_types = []

        if outputs_in_arg:
            for i in range(self.num_outputs()):
                # fix: alias index must track the output index i
                # (was hardcoded as a{1}, rendering every output as "a1!")
                output_types.append(f"StridedBuffer(a{i}!)")
            input_types.extend(output_types)
        else:
            for _ in range(self.num_outputs()):
                output_types.append("StridedBuffer")
        sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'
        return sig

    def _compute_input_id(self):
        """Map each input position to its index within its own group:
        tensors are numbered in0, in1, ...; non-tensors val0, val1, ..."""
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)

222 

223 

class KernelGenerator:
    """Generates the triton-jit kernel source for a pointwise operation.

    The generated kernel inlines the user's scalar function and comes in
    three flavors: nd-tile with block pointers, nd-tile with plain
    offset/mask indexing, and 1d-tile linear indexing. Each flavor has a
    monolithic body (one tile per CTA) and a grid-stride-loop body.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        # cache the scalar function's identity for code generation
        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    def gen_import_function(self, code: IndentedBuffer):
        """Inline the scalar function's source, re-decorated with @triton.jit."""
        code.writeline("@triton.jit")
        code.writemultiline(self.fn.src)
        code.newline()

    def gen_decorators(self, code):
        """Emit @libentry and the @triton.jit decorator for the kernel."""
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # we do not specialize non tensor args since they are passed into
            # the inlined function, which means that their values may not
            # deserve specialization
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def input_name(self, i):
        """Kernel-level name for the i-th schema input (in{k} or val{k})."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        return f"out{i}"

    def gen_signature(self, code, with_block_pointer=False):
        """Emit the kernel ``def`` line and parameter list (nd-tile flavor)."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: inputs ptrs & non tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output ptrs
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor arguments
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")
                    if with_block_pointer:
                        stride_order_args = _cs(
                            f"in{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(f"{stride_order_args}, # stride order for in{i}")

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")
                    if with_block_pointer:
                        stride_order_args = _cs(
                            f"out{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{stride_order_args}, # stride order for out{i}"
                        )

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

                # number of tasks, used to compute mask
                code.writeline("num_tasks,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                code.writeline("tiles_per_cta: int,")
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                code.writeline(f"{tile_sizes},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature_1d_tile(self, code):
        """Emit the kernel ``def`` line and parameter list (1d-tile flavor):
        same as gen_signature but with a single tile_size and no stride orders."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: inputs ptrs & non tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output ptrs
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor arguments
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

                # number of tasks, used to compute mask
                code.writeline("num_tasks,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_num_tiles(self, code):
        """Emit per-axis tile-grid sizes: num_tiles{i} = cdiv(s{i}, tile_size{i}).

        (The original always-true ``if i < ndim`` guard inside the loop was
        removed; it had no effect.)
        """
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def gen_body_for_0d(self, code):
        """Emit the trivial body for rank-0 operands: scalar load/compute/store."""
        schema = self.fx
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # nd tile 1d grid kernel with block pointer
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Emit one tile's worth of work (block-pointer flavor); expects a
        ``tile_id`` variable to be in scope in the generated kernel."""
        ndim = self.ndim
        schema = self.fx

        # block pointer for each operand
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        # reconstruct pid multi index
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

        # cta_offsets
        code.writeline("# tile offsets")
        for i in range(ndim):
            # Or else: AssertionError: Block pointers only support 32 bit
            # `offsets/block_shape`, add a `.to(tl.int32)` or use regular indexing
            # for 64 bit support
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(tuple(f"in{i}_stride_order{j}" for j in range(ndim)))
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        # compute
        # TODO: sepearate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            # NOTE(review): boundary_check is fed the stride-order tuple here;
            # presumably intentional (checking all dims) — confirm against
            # triton's make_block_ptr semantics before changing.
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        """Wrap the block-pointer tile body in a grid-stride loop."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_with_bptr(code)

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Emit one tile's worth of work (offset/mask flavor); expects a
        ``tile_id`` variable to be in scope in the generated kernel."""
        ndim = self.ndim
        schema = self.fx

        # reconstruct pid multi index
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
        mask_combine = " & ".join(masks)
        code.writeline(f"mask = {mask_combine}")

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )

        code.newline()

        # compute
        # TODO: sepearate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            # fix: do not bind the (None) result of tl.store to in{i} in the
            # generated kernel — the "in{i} = " prefix was a copy-paste
            # leftover from the load path and clobbered an input name
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_without_bptr(self, code):
        """Wrap the offset/mask tile body in a grid-stride loop."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_without_bptr(code)

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=True)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_with_bptr(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_with_bptr(code)
        code.newline()
        return code

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support, plain indexing."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=False)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_without_bptr(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_without_bptr(code)
        code.newline()
        return code

    def codegen_nd_tile(self, code):
        """Dispatch to the block-pointer or plain-indexing nd-tile generator."""
        use_block_pointer = self.config.prefer_block_pointer
        if use_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Emit one tile's worth of work (1d linear-tile flavor); expects a
        ``tile_id`` variable to be in scope in the generated kernel."""
        ndim = self.ndim
        schema = self.fx

        # tile id
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )

        code.newline()

        # compute
        # TODO: sepearate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            offset_combine = " + ".join(offsets)
            # fix: drop the bogus "in{i} = " binding of tl.store's result
            # (copy-paste leftover from the load path)
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_1d_tile(self, code):
        """Wrap the 1d-tile body in a grid-stride loop."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_1d_tile(code)

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            # code.writeline("num_ctas = te.num_programs(0)")
            # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_1d_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_1d_tile(code)
        code.newline()
        return code

739 

740 

741class WrapperGenerator: 

742 def __init__( 

743 self, 

744 function_schema: FunctionSchema, 

745 jit_fn_name: str, 

746 ndim: int, 

747 name: str, 

748 config: CodeGenConfig, 

749 ): 

750 self.fx = function_schema 

751 self.jit_fn_name = jit_fn_name 

752 self.ndim = ndim 

753 self.name = name 

754 self.config = config 

755 

756 def input_name(self, i): 

757 is_tensor = self.fx.is_tensor(i) 

758 name = "in" if is_tensor else "val" 

759 index = self.fx.input_index(i) 

760 return f"{name}{index}" 

761 

762 def output_name(self, i): 

763 return f"out{i}" 

764 

765 def gen_signature(self, code: IndentedBuffer): 

766 # TODO: check if triton handles constexprs transitively 

767 schema = self.fx 

768 params: List[str] = [] 

769 for i in range(schema.num_inputs()): 

770 if schema.is_tensor(i): 

771 params.append( 

772 f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]" 

773 ) 

774 else: 

775 arg_type = schema.input_type(i) 

776 if arg_type is not None: 

777 params.append(f"{self.input_name(i)}: {_type_name(arg_type)}") 

778 else: 

779 params.append(f"{self.input_name(i)}") 

780 # NOTE: [the wrapper's signature and rules for passing parameters ] 

781 # input params: must be passed by position, since the names are renamed to 

782 # in0, in1, val0, val1, ..., So passing these parameters by keyword is wierd 

783 # So we enforce that these parameters must be passed by position. 

784 # maybe we can fix it later 

785 # output parameters: must be passed by keyword, since the scalar function 

786 # do not have output parameters(think of it as some scalar function, output 

787 # parameter does not make sense in this case.) They are added to allow destination 

788 # passing style API. Output parameter is convenient in cases where we want 

789 # to use some pre-defiend outputs(especially when they are some views of other 

790 # tensors). We emphasize that these parameters are added in-addition, we enforce 

791 # that they be passed by keyword. After all, out0, out1, ... does not mismatch 

792 # names form the scalar function, since it does not have output parameters. 

793 params.append("/") 

794 params.append("*") # output params must be passed by keyword 

795 

796 for i in range(schema.num_output_tensors()): 

797 params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]") 

798 code.writeline(f"def {self.name}({_cs(params)}): ") 

799 

800 def gen_docstring(self, code: IndentedBuffer): 

801 schema = self.fx 

802 doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""' 

803 code.writeline(doc) 

804 

805 def gen_same_shape_check(self, code: IndentedBuffer): 

806 schema: FunctionSchema = self.fx 

807 params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [ 

808 f"out{i}.shape" for i in range(schema.num_output_tensors()) 

809 ] 

810 check: str = " == ".join(params) 

811 code.writeline(f"assert {check}, 'operand shapes mismatch'") 

812 

    def gen_task_partition(self, code: IndentedBuffer):
        """Emit wrapper code that picks tile sizes, tile count, CTA count and
        grid for the nd-tile kernel.

        Rank-0 launches trivially with one warp / one CTA; otherwise
        everything is derived from out0's shape, with an early return when
        the output is empty.
        """
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            # generated wrapper returns early on empty tensors
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size
            major, _ = get_device_capability()
            # NOTE(review): fill_scalar on compute capability >= 9 gets a fixed
            # tile size of 64 and an uncapped grid below — presumably a perf
            # special-case; confirm before changing.
            if self.name.find("fill_scalar") != -1 and major >= 9:
                code.writeline("tile_sizes = tuple([64])")
            else:
                code.writeline(
                    f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
                )
            code.writeline("tile_size = math.prod(tile_sizes)")
            code.writeline(
                "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
            )

            if self.name.find("fill_scalar") != -1 and major >= 9:
                code.writeline("num_ctas = num_tiles")
            else:
                max_grid_size0 = self.config.max_grid_size[0]
                code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

            code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            code.writeline("one_tile_per_cta = tiles_per_cta==1")
        # grid is emitted on both branches: the rank-0 launch also indexes
        # the kernel with [grid] in gen_kernel_launch
        code.writeline("grid = (num_ctas, 1, 1)")

848 

    def gen_task_partition_1d(self, code: IndentedBuffer):
        """Emit wrapper code that picks the (single) tile size, tile count,
        CTA count and grid for the 1d-tile kernel.

        Same shape as gen_task_partition, but tasks are linearized so only
        tile_sizes[0] is used.
        """
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            # generated wrapper returns early on empty tensors
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size

            major, _ = get_device_capability()
            # NOTE(review): same sm90+ fill_scalar special-case as in
            # gen_task_partition — confirm before changing.
            if self.name.find("fill_scalar") != -1 and major >= 9:
                code.writeline("tile_sizes = tuple([64])")
            else:
                code.writeline(
                    f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
                )

            code.writeline("tile_size = tile_sizes[0]")
            code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")

            if self.name.find("fill_scalar") != -1 and major >= 9:
                code.writeline("num_ctas = num_tiles")
            else:
                max_grid_size0 = self.config.max_grid_size[0]
                code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

            code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            code.writeline("one_tile_per_cta = tiles_per_cta==1")
        # grid is needed by gen_kernel_launch on the rank-0 path as well
        code.writeline("grid = (num_ctas, 1, 1)")

884 

    def gen_kernel_launch(
        self,
        code: IndentedBuffer,
    ):
        """Emit the kernel-launch section of the wrapper (nd-tile flavor).

        First extracts strides (and, when block pointers are preferred,
        stride orders) for every tensor operand, then emits the jit kernel
        call under the first input's device context.
        """
        schema = self.fx
        ndim = self.ndim

        with_block_pointer = self.config.prefer_block_pointer

        code.writeline("# kernel launch")
        for i in range(schema.num_input_tensors()):
            code.writeline(f"in{i}_strides = in{i}.stride()")
            # stride orders are only needed for the block-pointer kernels
            if not with_block_pointer:
                continue
            if ndim >= 2:  # where ndim is 1, we don't need to compute stride order
                code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)")
            else:
                code.writeline(f"in{i}_stride_order = (0,)")
        for i in range(schema.num_output_tensors()):
            code.writeline(f"out{i}_strides = out{i}.stride()")
            if not with_block_pointer:
                continue
            if ndim >= 2:
                code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)")
            else:
                code.writeline(f"out{i}_stride_order = (0,)")

        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                params = []
                # NOTE: WRAP
                for i in range(schema.num_inputs()):
                    if schema.is_tensor(i):
                        params.append(f"{self.input_name(i)}")
                    else:
                        params.append(self.input_name(i))
                for i in range(schema.num_output_tensors()):
                    params.append(f"{self.output_name(i)}")

                code.writeline(f"{_cs(params)},")

                # strides, shape, task count and tiling args only exist for
                # rank > 0 kernels (matching the kernel signature generators)
                if ndim > 0:
                    for i in range(schema.num_input_tensors()):
                        s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for in{i}")
                        if not with_block_pointer:
                            continue
                        order = ", ".join(
                            f"in{i}_stride_order[{j}]" for j in range(ndim)
                        )
                        code.writeline(f"{order}, # stride order for in{i}")

                    for i in range(schema.num_output_tensors()):
                        s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for out{i}")
                        if not with_block_pointer:
                            continue
                        order = ", ".join(
                            f"out{i}_stride_order[{j}]" for j in range(ndim)
                        )
                        # NOTE(review): the emitted comment reads "orderfor"
                        # (missing space) — cosmetic typo in generated text,
                        # deliberately left byte-identical here.
                        code.writeline(f"{order}, # stride orderfor out{i}")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                    code.writeline("num_tasks, # num tasks")
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                    for i in range(ndim):
                        code.writeline(f"tile_size{i}=tile_sizes[{i}],")
                    code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

958 

959 def gen_kernel_launch_1d( 

960 self, 

961 code: IndentedBuffer, 

962 ): 

963 schema = self.fx 

964 ndim = self.ndim 

965 

966 code.writeline("# kernel launch") 

967 for i in range(schema.num_input_tensors()): 

968 code.writeline(f"in{i}_strides = in{i}.stride()") 

969 for i in range(schema.num_output_tensors()): 

970 code.writeline(f"out{i}_strides = out{i}.stride()") 

971 

972 code.writeline("with torch_device_fn.device(in0.device.index):") 

973 with code.indent(): 

974 code.writeline(f"{self.jit_fn_name}[grid](") 

975 with code.indent(): 

976 params = [] 

977 # NOTE: WRAP 

978 for i in range(schema.num_inputs()): 

979 if schema.is_tensor(i): 

980 params.append(f"{self.input_name(i)}") 

981 else: 

982 params.append(self.input_name(i)) 

983 for i in range(schema.num_output_tensors()): 

984 params.append(f"{self.output_name(i)}") 

985 

986 code.writeline(f"{_cs(params)},") 

987 

988 if ndim > 0: 

989 for i in range(schema.num_input_tensors()): 

990 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim)) 

991 code.writeline(f"{s}, # stride for in{i}") 

992 for i in range(schema.num_output_tensors()): 

993 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim)) 

994 code.writeline(f"{s}, # stride for out{i}") 

995 

996 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim)) 

997 code.writeline(f"{shape_args}, # task indexing space") 

998 code.writeline("num_tasks, # num tasks") 

999 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta") 

1000 code.writeline("tile_size=tile_size,") 

1001 code.writeline("one_tile_per_cta=one_tile_per_cta,") 

1002 code.writeline("num_warps=num_warps,") 

1003 code.writeline(")") 

1004 

1005 def gen_return(self, code: IndentedBuffer): 

1006 return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors())) 

1007 code.writeline(f"return {return_exprs}") 

1008 

1009 def codegen_nd_tile(self, code): 

1010 self.gen_signature(code) 

1011 

1012 with code.indent(): 

1013 self.gen_docstring(code) 

1014 self.gen_same_shape_check(code) 

1015 self.gen_task_partition(code) 

1016 self.gen_kernel_launch(code) 

1017 self.gen_return(code) 

1018 code.newline() 

1019 return code 

1020 

1021 def codegen_1d_tile(self, code): 

1022 self.gen_signature(code) 

1023 

1024 with code.indent(): 

1025 self.gen_docstring(code) 

1026 self.gen_same_shape_check(code) 

1027 self.gen_task_partition_1d(code) 

1028 self.gen_kernel_launch_1d(code) 

1029 self.gen_return(code) 

1030 code.newline() 

1031 return code 

1032 

1033 

class ModuleGenerator:
    """Generate a standalone Python module for a pointwise op.

    The module contains the imports, any local @triton.jit helpers the scalar
    function depends on, the Python wrapper, and the triton kernel itself.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.scalar_fn = scalar_fn
        # generator for the Python wrapper that prepares args & launches the kernel
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        # generator for the triton JIT kernel
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def _collect_jit_deps(scalar_fn):
        """Collect extra imports and local @triton.jit helper sources.

        Parses the source module where scalar_fn is defined using AST.
        Returns a tuple of:
        - extra_imports: dict of module_path -> set of names
        - local_sources: list of source strings for local @triton.jit
          functions (those NOT decorated with @pointwise_dynamic)

        Any failure to locate or parse the source module degrades gracefully
        to empty results rather than raising.
        """
        import ast
        import inspect

        # JITFunction wraps the python function as .fn; fall back to the
        # object itself for plain functions
        py_fn = getattr(scalar_fn, "fn", scalar_fn)
        module_name = getattr(py_fn, "__module__", None)
        if not module_name:
            return {}, []
        try:
            mod = importlib.import_module(module_name)
            source_file = inspect.getfile(mod)
        except (ImportError, TypeError, OSError):
            return {}, []
        try:
            # explicit encoding: Python source is UTF-8 by default (PEP 3120);
            # relying on the locale encoding can fail on non-UTF-8 systems
            with open(source_file, encoding="utf-8") as f:
                module_source = f.read()
            source_lines = module_source.splitlines(keepends=True)
            tree = ast.parse(module_source)
        except (OSError, SyntaxError, UnicodeDecodeError):
            return {}, []

        # Collect import-from lines that the generated module does not
        # already emit itself
        ALREADY_IMPORTED = {
            "math",
            "typing",
            "torch",
            "triton",
            "triton.language",
            "flag_gems.utils.shape_utils",
            "flag_gems.utils.tensor_wrapper",
            "flag_gems.utils.libentry",
            "flag_gems.utils",
            "flag_gems.runtime",
            "flag_gems.utils.pointwise_dynamic",
        }
        extra_imports = {}
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, ast.ImportFrom) and node.module:
                if node.module in ALREADY_IMPORTED:
                    continue
                names = {alias.name for alias in node.names}
                extra_imports.setdefault(node.module, set()).update(names)

        # Collect local @triton.jit functions (without @pointwise_dynamic)
        def _has_decorator(func_node, name):
            # substring match over the decorator's source text; a heuristic,
            # but sufficient for the decorators used in this project
            for dec in func_node.decorator_list:
                src = "".join(source_lines[dec.lineno - 1 : dec.end_lineno])
                if name in src:
                    return True
            return False

        def _extract_source(func_node):
            # include decorator lines so the emitted helper keeps @triton.jit
            start = func_node.lineno - 1
            if func_node.decorator_list:
                start = func_node.decorator_list[0].lineno - 1
            end = func_node.end_lineno
            return "".join(source_lines[start:end])

        local_sources = []
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.FunctionDef):
                continue
            # a single "jit" check suffices: "triton.jit" contains "jit"
            if not _has_decorator(node, "jit"):
                continue
            if _has_decorator(node, "pointwise_dynamic"):
                continue
            local_sources.append(_extract_source(node))

        return extra_imports, local_sources

    def generate_imports(self, code: IndentedBuffer) -> IndentedBuffer:
        """Emit the import section plus any local @triton.jit helper sources."""
        code.writeline("import math")
        code.writeline("from typing import Union")
        code.writeline("import torch")
        code.writeline("import triton")
        code.writeline("from triton import language as tl")
        code.newline()
        code.writeline("from flag_gems.utils.shape_utils import (")
        code.writeline("    heuristics_for_tile_size,")
        code.writeline("    heuristics_for_num_warps,")
        code.writeline("    stride_order,")
        code.writeline(")")
        code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer")
        code.writeline("from flag_gems.utils.libentry import libentry")
        code.writeline("from flag_gems.utils import triton_lang_extension as tle")
        code.writeline("from flag_gems.runtime import torch_device_fn")

        # Generate extra imports and local JIT deps of the scalar function
        jit_dep_imports, local_jit_sources = self._collect_jit_deps(self.scalar_fn)
        for module_path, names in sorted(jit_dep_imports.items()):
            sorted_names = ", ".join(sorted(names))
            code.writeline(f"from {module_path} import {sorted_names}")

        code.newline()
        code.newline()

        # Emit local @triton.jit helper functions
        for source in local_jit_sources:
            for line in source.splitlines():
                code.writeline(line)
            code.newline()

        return code

    def codegen(self, code: IndentedBuffer):
        """Generate the full module: imports, wrapper, and kernel."""
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            code = self.wrapper_gen.codegen_1d_tile(code)
            code = self.kernel_gen.codegen_1d_tile(code)
        else:
            code = self.wrapper_gen.codegen_nd_tile(code)
            code = self.kernel_gen.codegen_nd_tile(code)
        return code

1177 

1178 

@dataclass
class KernelInfo:
    """Information about a generated kernel for C++ integration."""

    # Path of the generated Python module in the code cache directory.
    file_path: str
    # Name of the generated triton JIT kernel function inside that module.
    kernel_name: str
    # Name of the generated Python wrapper function inside that module.
    wrapper_name: str
    # Rank of the task space this kernel was specialized for.
    ndim: int

1187 

1188 

class PointwiseDynamicFunction:
    """Utility to generate functions for general pointwise operations.

    It generates a wrapper & a JITFunction which are specialized according to
    the rank of the task space (the broadcasted shape of all input tensors).
    The generated code is written out to the cache directory (defaults to
    ~/.flaggems).
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        # cache key of the scalar function; used to name generated modules/files
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads, keyed by "{ndim}_{prefer_block_pointer}"
        self.overloads: Mapping[str, Callable] = {}
        # cached kernel info for C++ integration (same keys as self.overloads)
        self._kernel_info_cache: Mapping[str, KernelInfo] = {}

    def __call__(self, *args, **kwargs):
        """Apply the pointwise op, allocating any outputs that were not supplied."""
        # inputs must be passed by position, outputs must be passed by keyword
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        # NOTE: overload keeps the type of outputs:
        # if a pre-defined output is a Tensor or StridedBuffer, the corresponding
        # output is also a Tensor or StridedBuffer, respectively.
        # Since prepare_args wraps all the arguments, the outputs are all
        # StridedBuffers here; but if a manually instantiated overload is
        # directly called, take care of that manually.
        return self._unwrap(out)

    @staticmethod
    def use_fast_path(tensors):
        # Fast path: all tensors share one shape and are either all
        # C-contiguous, or share one stride layout that torch reports as
        # dense and non-overlapping; the task can then be flattened to 1D.
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    def prepare_args(self, *args, **kwargs):
        """Normalize arguments before dispatch to a rank-specialized overload.

        Allocates missing outputs (with promoted dtypes), infers the task
        rank, and reinterprets every tensor as a StridedBuffer over the
        broadcasted task shape. Returns (ndim, args, kwargs).
        """
        # output allocation (when needed)
        # task simplification & task-rank inference & input-output reinterpretation
        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions: each output's dtype is derived from the
        # argument indices and promotion method recorded in the schema
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        # NOTE(review): disabling block pointers for very large tensors is a
        # sticky mutation of the shared config object — confirm intended
        INT32_MAX = torch.iinfo(torch.int32).max
        if tensors[0].numel() > INT32_MAX:
            self.config.prefer_block_pointer = False
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            # flatten the task to 1D with unit stride
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # a simple strategy: all the undefined tensors will follow the first
            # tensor that is not broadcasted, no attempts to simplify task,
            # no reordering, no dimension collapsing
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operation
                    # NOTE(review): this message says "Input" but the check is
                    # on output tensors — confirm intended wording
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise Input arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            # allocate outputs like the first tensor that already has the
            # task shape, to inherit its layout
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak: no tensor matches the task shape; allocate fresh contiguous outputs
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        # unwrap StridedBuffer(s) to get Tensor(s); a single output is
        # returned bare, multiple outputs as a tuple
        if self.fx.num_output_tensors() == 1:
            item = tensors
            return item.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
        """Compute kernel name, wrapper name, and file path for a given ndim.

        This is the single source of truth for naming, used by both instantiate()
        and get_kernel_info() to ensure consistency.

        Returns:
            Tuple of (kernel_name, wrapper_name, file_path)
        """
        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"

        # tiling strategy is encoded in the file name so different configs
        # do not collide in the cache
        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )
        file_path = str(code_cache_dir() / file_name)

        return kernel_name, wrapper_name, file_path

    def instantiate(self, ndim):
        """Generate (or fetch from cache) the overload specialized for ndim."""
        # NOTE: a manually instantiated overload does not have `prepare_args` as
        # preprocessing, so you have to manually allocate output and make sure that
        # the inputs & outputs actually fit the manually instantiated overload
        key = f"{ndim}_{self.config.prefer_block_pointer}"
        if key in self.overloads:
            return self.overloads[key]

        code = IndentedBuffer()

        # Use helper to compute names (single source of truth)
        kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which requires
        # that the source code can be found by inspect.
        # We write it into a file, since inspect cannot find the source of functions
        # dynamically created via an exec'ed string. We could help inspect find the
        # source by hacking the linecache library, but we find generating a module
        # simpler, since we generate 2 functions — the kernel and the wrapper — and
        # the wrapper calls the kernel.
        write_atomic(file_path, code.getvalue())

        # load the generated module from the cache file
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # do not expose it to sys.modules
        # sys.modules["_add_module"] = m

        # NOTE: [why not import the scalar function]
        # We do not re-import the scalar function, although the generated kernel
        # **calls** it. Since a function's __name__ may be changed, importing it by
        # __name__ from the module where it is defined may not yield the same object;
        # also the name may be rebound to something else, so importing by name
        # cannot guarantee that the scalar function is imported.
        # So we copy the scalar function and its __globals__ to the generated module:
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload

        # Cache kernel info for C++ integration
        self._kernel_info_cache[key] = KernelInfo(
            file_path=file_path,
            kernel_name=kernel_name,
            wrapper_name=wrapper_name,
            ndim=ndim,
        )

        return overload

    def get_kernel_info(self, ndim: int) -> KernelInfo:
        """Get kernel information for a given ndim.

        This method is useful for C++ integration to get the file path and
        kernel name without duplicating the naming logic.

        If the kernel hasn't been instantiated yet, this will instantiate it first.

        Args:
            ndim: The rank of the task space

        Returns:
            KernelInfo with file_path, kernel_name, wrapper_name, and ndim
        """
        key = f"{ndim}_{self.config.prefer_block_pointer}"

        # Ensure the kernel is instantiated
        if key not in self._kernel_info_cache:
            self.instantiate(ndim)

        return self._kernel_info_cache[key]

1468 

1469 

def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Decorator that turns a scalar triton JIT function into a pointwise op.

    May be applied bare (``@pointwise_dynamic``) or with keyword arguments
    (``@pointwise_dynamic(num_outputs=..., ...)``). When none of num_inputs,
    is_tensor, and dtypes is supplied, the input count is inferred from the
    scalar function's argument names.

    Returns:
        A PointwiseDynamicFunction (or the decorator when called with
        keyword arguments only).
    """

    def decorator(fn):
        # Fix: infer into a local instead of mutating `num_inputs` via
        # `nonlocal` — the old mutation made a parameterized decorator
        # non-reusable (a second decorated function silently inherited the
        # first function's inferred arity).
        inferred_num_inputs = num_inputs
        if (inferred_num_inputs is None) and (is_tensor is None) and (dtypes is None):
            inferred_num_inputs = len(fn.arg_names)
        op_desc = FunctionSchema(
            num_inputs=inferred_num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(op_desc, fn, config)

    if f is not None:
        return decorator(f)
    return decorator