Coverage for src/flag_gems/utils/pointwise_dynamic.py: 95%

805 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import importlib 

2import os 

3from dataclasses import dataclass 

4from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple 

5 

6import torch 

7import triton 

8from triton.runtime.jit import JITFunction 

9 

10from flag_gems.utils.code_cache import code_cache_dir 

11from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

12from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config 

13from flag_gems.utils.device_info import get_device_capability 

14from flag_gems.utils.shape_utils import ( 

15 MemOverlap, 

16 all_c_contiguous, 

17 all_the_same_shape, 

18 all_the_same_stride, 

19 broadcast_shapes, 

20 broadcasted_stride, 

21 check_tensor_attributes, 

22 has_internal_overlapping, 

23) 

24from flag_gems.utils.tensor_wrapper import StridedBuffer 

25from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion 

26 

27 

28# ------------------ Operation Description --------------------------- 

29def _type_name(type) -> str: 

30 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object" 

31 if type in (bool, int, float, str): 

32 return type.__name__ 

33 if isinstance(type, torch.dtype): 

34 return str(type) 

35 return str(type) 

36 

37 

38def _check_typed_list(container, type): 

39 for item in container: 

40 assert isinstance(item, type) 

41 

42 

43def _check_sized_list(container, size): 

44 assert len(container) == size 

45 

46 

47def _tuple_content(strings: Sequence[str]) -> str: 

48 # comma separated list 

49 if len(strings) == 0: 

50 return "" 

51 if len(strings) == 1: 

52 return f"{strings[0]}," 

53 else: 

54 return ", ".join(strings) 

55 

56 

57def _cs(strings: Iterable[str]) -> str: 

58 return ", ".join(strings) 

59 

60 

61def _broadcast_vec(i, ndim): 

62 axes = [":" if j == i else "None" for j in range(ndim)] 

63 return f"[{_cs(axes)}]" 

64 

65 

class FunctionSchema:
    """Describes the calling convention of a pointwise scalar function:
    how many inputs and outputs it has, which inputs are tensors, the
    declared dtypes of non-tensor inputs, and the type-promotion method
    used to derive each output's dtype.

    Fixes vs. the previous revision:
    - ``signature`` annotated every destination-passing output as ``(a1!)``
      because the format string hard-coded ``{1}``; it now uses the loop
      index, yielding ``(a0!)``, ``(a1!)``, ...
    - removed an unreachable ``if is_tensor is not None`` branch inside the
      ``elif dtypes is not None`` arm of ``__init__`` (that arm is only
      reached when ``is_tensor`` is None).
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema from any subset of (num_inputs, is_tensor, dtypes);
        at least one of them must be given. ``promotion_methods`` is required:
        one ``(*arg_indices, kind_name)`` tuple per output.

        Raises:
            ValueError: if ``promotion_methods`` is None, or if none of
                (num_inputs, is_tensor, dtypes) is specified.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        else:
            self._promotion_methods = self.canonicalize_promotion_methods(
                promotion_methods
            )

        # Infer the missing pieces of (num_inputs, is_tensor, dtypes) from
        # whichever were provided; their sizes must agree.
        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            # is_tensor is necessarily None here (handled by the branches
            # above): inputs with no declared dtype are treated as tensors.
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace the trailing method-name string of each promotion spec
        with the corresponding ELEMENTWISE_TYPE_PROMOTION_KIND member."""
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        # num of arguments, outputs not included
        return self._num_inputs

    def num_outputs(self):
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Whether input *arg_id* is a tensor (vs. a non-tensor scalar)."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Declared dtype of input *arg_id*, or None if unspecified."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Canonicalized promotion spec for output *i* (arg indices + kind)."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        # every output is a tensor, so this equals num_outputs
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a human-readable signature string.

        With ``outputs_in_arg=True`` the outputs are also listed among the
        arguments (destination-passing style), each annotated ``(aK!)`` to
        mark it as mutated.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            else:
                if dtype is None:
                    input_types.append("scalar")
                else:
                    input_types.append(_type_name(dtype))

        output_types = []

        if outputs_in_arg:
            for i in range(self.num_outputs()):
                # fixed: the annotation index was hard-coded as "a{1}"
                output_types.append(f"StridedBuffer(a{i}!)")
            input_types.extend(output_types)
        else:
            for _ in range(self.num_outputs()):
                output_types.append("StridedBuffer")
        sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'
        return sig

    def _compute_input_id(self):
        """Per-input index within its own group: the k-th tensor input maps
        to k, the k-th non-tensor input maps to k (two independent counters).
        """
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Group-local index (see _compute_input_id) of input *idx*."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)

222 

223 

class KernelGenerator:
    """Generates the source of a triton.jit kernel that applies a scalar
    function elementwise over ``rank``-dimensional operands described by
    ``function_schema``.

    Three kernel flavors are produced by the ``codegen_*`` entry points:
    nd-tile with block pointers, nd-tile with plain pointer arithmetic,
    and 1d-tile; each supports both a monolithic (one tile per CTA) style
    and a grid-stride-loop style, selected at launch time by the
    ``one_tile_per_cta`` constexpr.

    Fixes vs. the previous revision:
    - ``gen_num_tiles`` dropped an always-true ``if i < ndim`` guard.
    - the non-block-pointer and 1d-tile store emitters no longer generate
      ``in{i} = tl.store(...)`` (a copy-paste from the load line that bound
      the store's result to — and clobbered — the input variable name);
      they now emit a bare ``tl.store(...)`` like the block-pointer path.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    def gen_import_function(self, code: IndentedBuffer):
        """Quote the scalar function's source into the generated module
        (as a docstring, for readability of the generated file)."""
        code.writeline(f'"""Quoted source of {self.fn_name}:')
        code.writemultiline(self.fn.src)
        code.writeline('"""')
        code.newline()

    def gen_decorators(self, code):
        """Emit the @libentry and @triton.jit decorators."""
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # we do not specialize non tensor args since they are passed into the inlined function
            # which means that their values may not deserve specialization
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def input_name(self, i):
        """Kernel-level name of input *i*: ``inK`` for tensor inputs,
        ``valK`` for non-tensor inputs (K is the group-local index)."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        """Kernel-level name of output *i*."""
        return f"out{i}"

    def gen_signature(self, code, with_block_pointer=False):
        """Emit the kernel's parameter list for the nd-tile flavors.

        Order: input pointers & non-tensor values, output pointers, strides
        (+ stride orders when *with_block_pointer*), task space sizes,
        num_tasks, then the tiling parameters. Rank-0 kernels only get the
        pointer/value parameters.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: inputs ptrs & non tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output ptrs
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor arguments
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")
                    if with_block_pointer:
                        stride_order_args = _cs(
                            f"in{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(f"{stride_order_args}, # stride order for in{i}")

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")
                    if with_block_pointer:
                        stride_order_args = _cs(
                            f"out{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{stride_order_args}, # stride order for out{i}"
                        )

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

                # number of tasks, used to compute mask
                code.writeline("num_tasks,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                code.writeline("tiles_per_cta: int,")
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                code.writeline(f"{tile_sizes},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature_1d_tile(self, code):
        """Emit the kernel's parameter list for the 1d-tile flavor: same as
        gen_signature but with a single scalar tile_size and no stride
        orders (the 1d flavor never uses block pointers)."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: inputs ptrs & non tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output ptrs
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor arguments
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

                # number of tasks, used to compute mask
                code.writeline("num_tasks,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_num_tiles(self, code):
        """Emit the per-axis tile-grid sizes.

        Fixed: each iteration was guarded by ``if i < ndim``, which is
        always true inside ``range(ndim)``.
        """
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def gen_body_for_0d(self, code):
        """Emit the body for rank-0 operands: scalar load, compute, store."""
        schema = self.fx
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # nd tile 1d grid kernel with block pointer
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Emit one tile's worth of work using tl.make_block_ptr for loads
        and stores. Expects ``tile_id`` to be in scope in the generated code."""
        ndim = self.ndim
        schema = self.fx

        # block pointer for each operand
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        # reconstruct pid multi index
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

        # cta_offsets
        code.writeline("# tile offsets")
        for i in range(ndim):
            # Or else: AssertionError: Block pointers only support 32 bit
            # `offsets/block_shape`, add a `.to(tl.int32)` or use regular indexing
            # for 64 bit support
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(tuple(f"in{i}_stride_order{j}" for j in range(ndim)))
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        # compute
        # TODO: sepearate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        """Wrap the block-pointer tile body in a grid-stride loop."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_with_bptr(code)

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Emit one tile's worth of work with plain pointer arithmetic and
        broadcasted per-axis offsets/masks."""
        ndim = self.ndim
        schema = self.fx

        # reconstruct pid multi index
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
        mask_combine = " & ".join(masks)
        code.writeline(f"mask = {mask_combine}")

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )

        code.newline()

        # compute
        # TODO: sepearate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            # fixed: the generated line used to read "in{i} = tl.store(...)",
            # clobbering the input variable with the store's result
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_without_bptr(self, code):
        """Wrap the plain-pointer tile body in a grid-stride loop."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_without_bptr(code)

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=True)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolithic kernel: one_tile_per_cta; may require a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_with_bptr(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_with_bptr(code)
        code.newline()
        return code

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support, no block pointer."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=False)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolithic kernel: one_tile_per_cta; may require a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_without_bptr(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_without_bptr(code)
        code.newline()
        return code

    def codegen_nd_tile(self, code):
        """Dispatch to the block-pointer or plain-pointer nd-tile generator
        based on the codegen config."""
        use_block_pointer = self.config.prefer_block_pointer
        if use_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Emit one tile's worth of work for the 1d-tile flavor: linear task
        ids are unraveled into a multi index per element."""
        ndim = self.ndim
        schema = self.fx

        # tile id
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )

        code.newline()

        # compute
        # TODO: sepearate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            offset_combine = " + ".join(offsets)
            # fixed: the generated line used to read "in{i} = tl.store(...)",
            # clobbering the input variable with the store's result
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_1d_tile(self, code):
        """Wrap the 1d-tile body in a grid-stride loop."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_1d_tile(code)

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            # monolithic kernel: one_tile_per_cta; may require a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_1d_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_1d_tile(code)
        code.newline()
        return code

740 

741 

742class WrapperGenerator: 

743 def __init__( 

744 self, 

745 function_schema: FunctionSchema, 

746 jit_fn_name: str, 

747 ndim: int, 

748 name: str, 

749 config: CodeGenConfig, 

750 ): 

751 self.fx = function_schema 

752 self.jit_fn_name = jit_fn_name 

753 self.ndim = ndim 

754 self.name = name 

755 self.config = config 

756 

757 def input_name(self, i): 

758 is_tensor = self.fx.is_tensor(i) 

759 name = "in" if is_tensor else "val" 

760 index = self.fx.input_index(i) 

761 return f"{name}{index}" 

762 

763 def output_name(self, i): 

764 return f"out{i}" 

765 

766 def gen_signature(self, code: IndentedBuffer): 

767 # TODO: check if triton handles constexprs transitively 

768 schema = self.fx 

769 params: List[str] = [] 

770 for i in range(schema.num_inputs()): 

771 if schema.is_tensor(i): 

772 params.append( 

773 f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]" 

774 ) 

775 else: 

776 arg_type = schema.input_type(i) 

777 if arg_type is not None: 

778 params.append(f"{self.input_name(i)}: {_type_name(arg_type)}") 

779 else: 

780 params.append(f"{self.input_name(i)}") 

781 # NOTE: [the wrapper's signature and rules for passing parameters ] 

782 # input params: must be passed by position, since the names are renamed to 

783 # in0, in1, val0, val1, ..., So passing these parameters by keyword is wierd 

784 # So we enforce that these parameters must be passed by position. 

785 # maybe we can fix it later 

786 # output parameters: must be passed by keyword, since the scalar function 

787 # do not have output parameters(think of it as some scalar function, output 

788 # parameter does not make sense in this case.) They are added to allow destination 

789 # passing style API. Output parameter is convenient in cases where we want 

790 # to use some pre-defiend outputs(especially when they are some views of other 

791 # tensors). We emphasize that these parameters are added in-addition, we enforce 

792 # that they be passed by keyword. After all, out0, out1, ... does not mismatch 

793 # names form the scalar function, since it does not have output parameters. 

794 params.append("/") 

795 params.append("*") # output params must be passed by keyword 

796 

797 for i in range(schema.num_output_tensors()): 

798 params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]") 

799 code.writeline(f"def {self.name}({_cs(params)}): ") 

800 

801 def gen_docstring(self, code: IndentedBuffer): 

802 schema = self.fx 

803 doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""' 

804 code.writeline(doc) 

805 

806 def gen_same_shape_check(self, code: IndentedBuffer): 

807 schema: FunctionSchema = self.fx 

808 params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [ 

809 f"out{i}.shape" for i in range(schema.num_output_tensors()) 

810 ] 

811 check: str = " == ".join(params) 

812 code.writeline(f"assert {check}, 'operand shapes mismatch'") 

813 

814 def gen_task_partition(self, code: IndentedBuffer): 

815 code.writeline("# task partitioning") 

816 ndim = self.ndim 

817 if ndim == 0: 

818 code.writeline("num_warps = 1") 

819 code.writeline("num_ctas = 1") 

820 else: 

821 code.writeline("shape = out0.shape") 

822 code.writeline("num_tasks = out0.numel()") 

823 code.writeline("if num_tasks == 0:") 

824 with code.indent(): 

825 self.gen_return(code) 

826 max_tile_size = self.config.max_tile_size 

827 major, _ = get_device_capability() 

828 if self.name.find("fill_scalar") != -1 and major >= 9: 

829 code.writeline("tile_sizes = tuple([64])") 

830 else: 

831 code.writeline( 

832 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)" 

833 ) 

834 code.writeline("tile_size = math.prod(tile_sizes)") 

835 code.writeline( 

836 "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))" 

837 ) 

838 

839 if self.name.find("fill_scalar") != -1 and major >= 9: 

840 code.writeline("num_ctas = num_tiles") 

841 else: 

842 max_grid_size0 = self.config.max_grid_size[0] 

843 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)") 

844 

845 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)") 

846 code.writeline("num_warps = heuristics_for_num_warps(tile_size)") 

847 code.writeline("one_tile_per_cta = tiles_per_cta==1") 

848 code.writeline("grid = (num_ctas, 1, 1)") 

849 

850 def gen_task_partition_1d(self, code: IndentedBuffer): 

851 code.writeline("# task partitioning") 

852 ndim = self.ndim 

853 if ndim == 0: 

854 code.writeline("num_warps = 1") 

855 code.writeline("num_ctas = 1") 

856 else: 

857 code.writeline("shape = out0.shape") 

858 code.writeline("num_tasks = out0.numel()") 

859 code.writeline("if num_tasks == 0:") 

860 with code.indent(): 

861 self.gen_return(code) 

862 max_tile_size = self.config.max_tile_size 

863 

864 major, _ = get_device_capability() 

865 if self.name.find("fill_scalar") != -1 and major >= 9: 

866 code.writeline("tile_sizes = tuple([64])") 

867 else: 

868 code.writeline( 

869 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)" 

870 ) 

871 

872 code.writeline("tile_size = tile_sizes[0]") 

873 code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)") 

874 

875 if self.name.find("fill_scalar") != -1 and major >= 9: 

876 code.writeline("num_ctas = num_tiles") 

877 else: 

878 max_grid_size0 = self.config.max_grid_size[0] 

879 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)") 

880 

881 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)") 

882 code.writeline("num_warps = heuristics_for_num_warps(tile_size)") 

883 code.writeline("one_tile_per_cta = tiles_per_cta==1") 

884 code.writeline("grid = (num_ctas, 1, 1)") 

885 

886 def gen_kernel_launch( 

887 self, 

888 code: IndentedBuffer, 

889 ): 

890 schema = self.fx 

891 ndim = self.ndim 

892 

893 with_block_pointer = self.config.prefer_block_pointer 

894 

895 code.writeline("# kernel launch") 

896 for i in range(schema.num_input_tensors()): 

897 code.writeline(f"in{i}_strides = in{i}.stride()") 

898 if not with_block_pointer: 

899 continue 

900 if ndim >= 2: # where ndim is 1, we don't need to compute stride order 

901 code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)") 

902 else: 

903 code.writeline(f"in{i}_stride_order = (0,)") 

904 for i in range(schema.num_output_tensors()): 

905 code.writeline(f"out{i}_strides = out{i}.stride()") 

906 if not with_block_pointer: 

907 continue 

908 if ndim >= 2: 

909 code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)") 

910 else: 

911 code.writeline(f"out{i}_stride_order = (0,)") 

912 

913 code.writeline("with torch_device_fn.device(in0.device.index):") 

914 with code.indent(): 

915 code.writeline(f"{self.jit_fn_name}[grid](") 

916 with code.indent(): 

917 params = [] 

918 # NOTE: WRAP 

919 for i in range(schema.num_inputs()): 

920 if schema.is_tensor(i): 

921 params.append(f"{self.input_name(i)}") 

922 else: 

923 params.append(self.input_name(i)) 

924 for i in range(schema.num_output_tensors()): 

925 params.append(f"{self.output_name(i)}") 

926 

927 code.writeline(f"{_cs(params)},") 

928 

929 if ndim > 0: 

930 for i in range(schema.num_input_tensors()): 

931 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim)) 

932 code.writeline(f"{s}, # stride for in{i}") 

933 if not with_block_pointer: 

934 continue 

935 order = ", ".join( 

936 f"in{i}_stride_order[{j}]" for j in range(ndim) 

937 ) 

938 code.writeline(f"{order}, # stride order for in{i}") 

939 

940 for i in range(schema.num_output_tensors()): 

941 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim)) 

942 code.writeline(f"{s}, # stride for out{i}") 

943 if not with_block_pointer: 

944 continue 

945 order = ", ".join( 

946 f"out{i}_stride_order[{j}]" for j in range(ndim) 

947 ) 

948 code.writeline(f"{order}, # stride orderfor out{i}") 

949 

950 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim)) 

951 code.writeline(f"{shape_args}, # task indexing space") 

952 code.writeline("num_tasks, # num tasks") 

953 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta") 

954 for i in range(ndim): 

955 code.writeline(f"tile_size{i}=tile_sizes[{i}],") 

956 code.writeline("one_tile_per_cta=one_tile_per_cta,") 

957 code.writeline("num_warps=num_warps,") 

958 code.writeline(")") 

959 

960 def gen_kernel_launch_1d( 

961 self, 

962 code: IndentedBuffer, 

963 ): 

964 schema = self.fx 

965 ndim = self.ndim 

966 

967 code.writeline("# kernel launch") 

968 for i in range(schema.num_input_tensors()): 

969 code.writeline(f"in{i}_strides = in{i}.stride()") 

970 for i in range(schema.num_output_tensors()): 

971 code.writeline(f"out{i}_strides = out{i}.stride()") 

972 

973 code.writeline("with torch_device_fn.device(in0.device.index):") 

974 with code.indent(): 

975 code.writeline(f"{self.jit_fn_name}[grid](") 

976 with code.indent(): 

977 params = [] 

978 # NOTE: WRAP 

979 for i in range(schema.num_inputs()): 

980 if schema.is_tensor(i): 

981 params.append(f"{self.input_name(i)}") 

982 else: 

983 params.append(self.input_name(i)) 

984 for i in range(schema.num_output_tensors()): 

985 params.append(f"{self.output_name(i)}") 

986 

987 code.writeline(f"{_cs(params)},") 

988 

989 if ndim > 0: 

990 for i in range(schema.num_input_tensors()): 

991 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim)) 

992 code.writeline(f"{s}, # stride for in{i}") 

993 for i in range(schema.num_output_tensors()): 

994 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim)) 

995 code.writeline(f"{s}, # stride for out{i}") 

996 

997 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim)) 

998 code.writeline(f"{shape_args}, # task indexing space") 

999 code.writeline("num_tasks, # num tasks") 

1000 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta") 

1001 code.writeline("tile_size=tile_size,") 

1002 code.writeline("one_tile_per_cta=one_tile_per_cta,") 

1003 code.writeline("num_warps=num_warps,") 

1004 code.writeline(")") 

1005 

1006 def gen_return(self, code: IndentedBuffer): 

1007 return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors())) 

1008 code.writeline(f"return {return_exprs}") 

1009 

1010 def codegen_nd_tile(self, code): 

1011 self.gen_signature(code) 

1012 

1013 with code.indent(): 

1014 self.gen_docstring(code) 

1015 self.gen_same_shape_check(code) 

1016 self.gen_task_partition(code) 

1017 self.gen_kernel_launch(code) 

1018 self.gen_return(code) 

1019 code.newline() 

1020 return code 

1021 

1022 def codegen_1d_tile(self, code): 

1023 self.gen_signature(code) 

1024 

1025 with code.indent(): 

1026 self.gen_docstring(code) 

1027 self.gen_same_shape_check(code) 

1028 self.gen_task_partition_1d(code) 

1029 self.gen_kernel_launch_1d(code) 

1030 self.gen_return(code) 

1031 code.newline() 

1032 return code 

1033 

1034 

class ModuleGenerator:
    """Renders one importable python module containing a wrapper function and
    the triton jit kernel it launches, by delegating to a WrapperGenerator and
    a KernelGenerator sharing the same schema/ndim/config."""

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        # the python-level entry point that prepares and launches the kernel
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        # the triton jit kernel performing the elementwise computation
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
        """Emit the import header shared by every generated module."""
        plain_imports = (
            "import math",
            "from typing import Union",
            "import torch",
            "import triton",
            "from triton import language as tl",
        )
        for line in plain_imports:
            code.writeline(line)
        code.newline()
        flag_gems_imports = (
            "from flag_gems.utils.shape_utils import (",
            "    heuristics_for_tile_size,",
            "    heuristics_for_num_warps,",
            "    stride_order,",
            ")",
            "from flag_gems.utils.tensor_wrapper import StridedBuffer",
            "from flag_gems.utils.libentry import libentry",
            "from flag_gems.utils import triton_lang_extension as tle",
            "from flag_gems.runtime import torch_device_fn",
        )
        for line in flag_gems_imports:
            code.writeline(line)
        code.newline()
        code.newline()
        return code

    def codegen(self, code: IndentedBuffer):
        """Generate the whole module: imports, then wrapper, then kernel.

        The only runtime-determined factor is the rank of the task space.
        """
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            emit_wrapper = self.wrapper_gen.codegen_1d_tile
            emit_kernel = self.kernel_gen.codegen_1d_tile
        else:
            emit_wrapper = self.wrapper_gen.codegen_nd_tile
            emit_kernel = self.kernel_gen.codegen_nd_tile
        code = emit_wrapper(code)
        code = emit_kernel(code)
        return code

1084 

1085 

@dataclass
class KernelInfo:
    """Information about a generated kernel for C++ integration."""

    # absolute path of the generated python module on disk
    file_path: str
    # name of the triton jit kernel defined inside that module
    kernel_name: str
    # name of the python wrapper function defined inside that module
    wrapper_name: str
    # rank of the task space this kernel was specialized for
    ndim: int

1094 

1095 

class PointwiseDynamicFunction:
    """Utility to generate functions for general pointwise operations. It generates a wrapper & JITFunction
    which are specialized according to the rank of the task space (the broadcasted shape of all input tensors).
    The generated code is written out to the cache directory (defaults to ~/.flaggems).
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        # the function schema describing inputs/outputs and promotion rules
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        # cache key of the scalar fn; used in generated file/module names
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads, keyed by "{ndim}_{prefer_block_pointer}"
        self.overloads: Mapping[str, Callable] = {}
        # cached kernel info for C++ integration, same key scheme as overloads
        self._kernel_info_cache: Mapping[str, KernelInfo] = {}

    def __call__(self, *args, **kwargs):
        # inputs must be passed by position, outputs must be passed by keyword
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        # NOTE: overload keeps the type of outputs:
        # if a pre-defined output is a Tensor or StridedBuffer, the corresponding
        # output is also a Tensor or StridedBuffer, respectively
        # since prepare_args wraps all the arguments, the outputs are all StridedBuffer
        # but if a manually instantiated overload is directly called, take care of
        # that manually
        return self._unwrap(out)

    @staticmethod
    def use_fast_path(tensors):
        # fast path: all tensors share a shape and are either all C-contiguous,
        # or all share strides over a non-overlapping dense layout — in which
        # case the task can be flattened to 1D with unit stride
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    def prepare_args(self, *args, **kwargs):
        """Normalize call arguments before dispatching to an overload.

        Performs output allocation (when needed), task simplification,
        task-rank inference and input-output reinterpretation: every tensor
        argument is wrapped in a StridedBuffer viewing the common task shape.
        Returns ``(ndim, args, kwargs)`` ready for the generated wrapper.
        """
        schema = self.fx
        # outputs not supplied via kwargs ("out{i}") must be allocated here
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions: each output's dtype is derived from the
        # argument indices + promotion method recorded in the schema
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        INT32_MAX = torch.iinfo(torch.int32).max
        # block pointers use 32-bit offsets, so disable them for huge tensors
        # NOTE(review): this mutates self.config in place, which may be shared
        # with other functions via get_codegen_config() — confirm intended
        if tensors[0].numel() > INT32_MAX:
            self.config.prefer_block_pointer = False
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            # flatten the task to 1D with unit stride
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # a simple strategy: all the undefined tensors will follow the first
            # tensor that is not broadcasted, no attempts to simplify task, no
            # reordering, no dimension collapsing
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operation
                    # NOTE(review): the message below says "Input" but the
                    # check is applied to output tensors — confirm wording
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise Input arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            # allocate outputs like the first tensor already at the task shape
            # (preserving its layout); fall back to a fresh contiguous tensor
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            # wrap every tensor as a StridedBuffer broadcast to the task shape
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        # unwrap StridedBuffer(s) to get the underlying Tensor(s); a single
        # output is returned bare, multiple outputs as a tuple
        if self.fx.num_output_tensors() == 1:
            item = tensors
            return item.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
        """Compute kernel name, wrapper name, and file path for a given ndim.

        This is the single source of truth for naming, used by both instantiate()
        and get_kernel_info() to ensure consistency.

        Returns:
            Tuple of (kernel_name, wrapper_name, file_path)
        """
        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"

        # encode the tiling strategy in the file name so different configs
        # do not collide in the code cache
        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )
        file_path = str(code_cache_dir() / file_name)

        return kernel_name, wrapper_name, file_path

    def instantiate(self, ndim):
        """Generate (or fetch from cache) the overload specialized for ``ndim``."""
        # NOTE: a manually instantiated overload does not have `prepare_args` as
        # preprocessing, so you have to manually allocate outputs and make sure that
        # the inputs & outputs actually fit the manually instantiated overload
        key = f"{ndim}_{self.config.prefer_block_pointer}"
        if key in self.overloads:
            return self.overloads[key]

        code = IndentedBuffer()

        # Use helper to compute names (single source of truth)
        kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which requires
        # that the source code can be found by inspect
        # We write it into a file, since inspect cannot find the source of functions dynamically
        # created via exec string. We can help inspect to find the source by hacking the linecache
        # library, but we find generating a module simpler, since we can generate 2 functions,
        # the kernel and the wrapper, and the wrapper calls the kernel.
        write_atomic(file_path, code.getvalue())

        # load the generated module from the cache file
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # do not expose it to sys.modules
        # sys.modules["_add_module"] = m

        # NOTE: [why not import the scalar function]
        # we do not re-import the scalar function, although the generated kernel **calls** it.
        # Since a function's __name__ may be changed, importing it by __name__ from the module
        # where it is defined may not yield the same object; also the name may be rebound to
        # something else, so importing via name cannot guarantee that the scalar function is
        # imported. So we copy the scalar function and its __globals__ to the generated module.
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload

        # Cache kernel info for C++ integration
        self._kernel_info_cache[key] = KernelInfo(
            file_path=file_path,
            kernel_name=kernel_name,
            wrapper_name=wrapper_name,
            ndim=ndim,
        )

        return overload

    def get_kernel_info(self, ndim: int) -> KernelInfo:
        """Get kernel information for a given ndim.

        This method is useful for C++ integration to get the file path and
        kernel name without duplicating the naming logic.

        If the kernel hasn't been instantiated yet, this will instantiate it first.

        Args:
            ndim: The rank of the task space

        Returns:
            KernelInfo with file_path, kernel_name, wrapper_name, and ndim
        """
        key = f"{ndim}_{self.config.prefer_block_pointer}"

        # Ensure the kernel is instantiated
        if key not in self._kernel_info_cache:
            self.instantiate(ndim)

        return self._kernel_info_cache[key]

1375 

1376 

def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Decorator that turns a triton scalar JITFunction into a dynamically
    specialized pointwise function.

    Usable bare (``@pointwise_dynamic``) or with keyword configuration
    (``@pointwise_dynamic(num_inputs=2, ...)``).
    """

    def decorator(fn):
        nonlocal num_inputs
        # infer input arity from the scalar fn only when nothing pins it down
        if num_inputs is None and is_tensor is None and dtypes is None:
            num_inputs = len(fn.arg_names)
        schema = FunctionSchema(
            num_inputs=num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(schema, fn, config)

    # bare-decorator form vs. parameterized form
    return decorator(f) if f is not None else decorator