Coverage for src/flag_gems/utils/pointwise_dynamic.py: 96%

788 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import importlib 

2import os 

3from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple 

4 

5import torch 

6import triton 

7from triton.runtime.jit import JITFunction 

8 

9from flag_gems.utils.code_cache import code_cache_dir 

10from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

11from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config 

12from flag_gems.utils.device_info import get_device_capability 

13from flag_gems.utils.shape_utils import ( 

14 MemOverlap, 

15 all_c_contiguous, 

16 all_the_same_shape, 

17 all_the_same_stride, 

18 broadcast_shapes, 

19 broadcasted_stride, 

20 check_tensor_attributes, 

21 has_internal_overlapping, 

22) 

23from flag_gems.utils.tensor_wrapper import StridedBuffer 

24from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion 

25 

26 

27# ------------------ Operation Description --------------------------- 

28def _type_name(type) -> str: 

29 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object" 

30 if type in (bool, int, float, str): 

31 return type.__name__ 

32 if isinstance(type, torch.dtype): 

33 return str(type) 

34 return str(type) 

35 

36 

37def _check_typed_list(container, type): 

38 for item in container: 

39 assert isinstance(item, type) 

40 

41 

42def _check_sized_list(container, size): 

43 assert len(container) == size 

44 

45 

46def _tuple_content(strings: Sequence[str]) -> str: 

47 # comma separated list 

48 if len(strings) == 0: 

49 return "" 

50 if len(strings) == 1: 

51 return f"{strings[0]}," 

52 else: 

53 return ", ".join(strings) 

54 

55 

56def _cs(strings: Iterable[str]) -> str: 

57 return ", ".join(strings) 

58 

59 

def _broadcast_vec(i, ndim):
    """Return an indexing suffix that lifts a 1-D vector to rank ``ndim``.

    The vector keeps axis ``i`` (``:``) and every other axis becomes a new
    size-1 dimension (``None``); e.g. ``_broadcast_vec(1, 3)`` returns
    ``"[None, :, None]"``.
    """
    axes = []
    for axis in range(ndim):
        axes.append(":" if axis == i else "None")
    return "[" + _cs(axes) + "]"

63 

64 

class FunctionSchema:
    """Calling convention of a generated pointwise operation.

    A schema records, for every positional input, whether it is a tensor and
    (optionally) its Python type, plus the number of outputs and the
    type-promotion rule that determines each output's dtype.

    Inputs are numbered per category: tensor inputs and non-tensor inputs each
    get their own 0-based counter.  ``input_index(i)`` returns the number of
    positional input ``i`` within its category (generated code names them
    ``in<k>`` and ``val<k>`` respectively).
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema from any consistent subset of the input descriptors.

        At least one of ``num_inputs``, ``is_tensor`` and ``dtypes`` must be
        given; the others are derived from it:

        * ``is_tensor`` defaults to all-``True`` when ``num_inputs`` is given,
          otherwise (only ``dtypes`` given) to ``[d is None for d in dtypes]``,
          i.e. an input with a declared Python type is a non-tensor.
        * ``dtypes`` defaults to all-``None``.

        Args:
            num_inputs: number of positional inputs (outputs not included).
            is_tensor: per-input flag, ``True`` for tensor inputs.
            dtypes: per-input Python type, or ``None`` when unspecified.
            num_outputs: number of outputs; defaults to
                ``len(promotion_methods)``.
            promotion_methods: one entry per output, each a sequence
                ``(*arg_indices, method_name)``; ``method_name`` is looked up
                in ``ELEMENTWISE_TYPE_PROMOTION_KIND``.

        Raises:
            ValueError: if ``promotion_methods`` is missing, or if none of
                ``num_inputs``/``is_tensor``/``dtypes`` is given.
            AssertionError: if list element types or lengths are inconsistent,
                or if there are no inputs or no outputs.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # is_tensor is necessarily None here (the previous branch handles
            # it), so infer it: inputs with a declared Python type are scalars.
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace each entry's trailing method name with its promotion kind.

        ``(*arg_indices, name)`` becomes
        ``(*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[name])``.
        """
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Number of positional inputs (outputs not included)."""
        return self._num_inputs

    def num_outputs(self):
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Canonicalized promotion rule of output ``i``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render the schema as ``"Pointwise: <inputs> -> <outputs>"``.

        Tensor inputs render as ``StridedBuffer``, untyped non-tensor inputs as
        ``scalar`` and typed non-tensor inputs by their type name.  With
        ``outputs_in_arg=True`` the outputs are also appended to the argument
        list, each with its own alias annotation (``StridedBuffer(a0!)``,
        ``StridedBuffer(a1!)``, ...), and the same annotations are used after
        the arrow.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        if outputs_in_arg:
            # One alias index per output, numbered like out0, out1, ...
            output_types = [
                f"StridedBuffer(a{i}!)" for i in range(self.num_outputs())
            ]
            input_types.extend(output_types)
        else:
            output_types = ["StridedBuffer"] * self.num_outputs()
        return f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'

    def _compute_input_id(self):
        """Map each positional input to its 0-based index within its category
        (tensor inputs and non-tensor inputs are counted separately)."""
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Index of positional input ``idx`` within its category."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)

221 

222 

class KernelGenerator:
    """Generate the source of a Triton kernel for one pointwise operation.

    The kernel wraps ``scalar_fn`` with the loads, index arithmetic and stores
    needed to apply it to strided tensors of rank ``rank``.  Three code shapes
    are produced:

    * rank 0: a single load / compute / store without indexing;
    * n-d tiles on a 1-d grid (``codegen_nd_tile``), using block pointers when
      ``config.prefer_block_pointer`` is set;
    * 1-d tiles on a 1-d grid (``codegen_1d_tile``), reconstructing every
      element's multi-index from a flat task id.

    Non-rank-0 kernels choose at compile time (``one_tile_per_cta`` constexpr)
    between handling one tile per program and a grid-stride loop in which each
    program handles ``tiles_per_cta`` tiles.

    Names in the generated kernel: tensor inputs ``in<k>``, non-tensor inputs
    ``val<k>``, outputs ``out<k>``; pointer parameters add a ``_ptr`` suffix.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        """
        Args:
            function_schema: inputs/outputs description of the operation.
            scalar_fn: jitted function called with one argument per schema
                input; its results are unpacked into one variable per output.
            rank: rank of the task space (0 selects the rank-0 body).
            name: name of the generated kernel function.
            config: code-generation settings (``prefer_block_pointer`` is read).
        """
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    # ------------------------------------------------------------- naming

    def input_name(self, i):
        """Kernel-side name of positional input ``i`` (``in<k>`` / ``val<k>``)."""
        prefix = "in" if self.fx.is_tensor(i) else "val"
        return f"{prefix}{self.fx.input_index(i)}"

    def output_name(self, i):
        """Kernel-side name of output ``i`` (``out<i>``)."""
        return f"out{i}"

    # ------------------------------------------------- header & signature

    def gen_import_function(self, code):
        """Quote the scalar function's source in a string literal for reference."""
        code.writeline(f'"""Quoted source of {self.fn_name}:')
        code.writemultiline(self.fn.src)
        code.writeline('"""')
        code.newline()

    def gen_decorators(self, code):
        """Emit ``@libentry()`` and the ``@triton.jit`` decorator."""
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # Non-tensor args are only forwarded into the inlined scalar
            # function, so their values may not deserve specialization.
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def _gen_operand_params(self, code):
        """Emit input pointers / non-tensor inputs, then output pointers."""
        schema = self.fx
        for i in range(schema.num_inputs()):
            name = self.input_name(i)
            if schema.is_tensor(i):
                code.writeline(f"{name}_ptr: tl.tensor, # of tl.pointer_type")
            elif schema.input_type(i) is not None:
                code.writeline(f"{name}: {_type_name(schema.input_type(i))},")
            else:
                code.writeline(f"{name},")
        for i in range(schema.num_outputs()):
            code.writeline(
                f"{self.output_name(i)}_ptr: tl.tensor, # of tl.pointer_type"
            )

    def _gen_stride_params(self, code, with_block_pointer):
        """Emit per-dim stride params (and stride-order constexprs when using
        block pointers) for every input tensor, then every output tensor."""
        ndim = self.ndim
        operands = [f"in{i}" for i in range(self.fx.num_input_tensors())]
        operands += [f"out{i}" for i in range(self.fx.num_output_tensors())]
        for operand in operands:
            stride_args = _cs(f"{operand}_stride{j}: int" for j in range(ndim))
            code.writeline(f"{stride_args}, # strides for {operand}")
            if with_block_pointer:
                stride_order_args = _cs(
                    f"{operand}_stride_order{j}: tl.constexpr" for j in range(ndim)
                )
                code.writeline(f"{stride_order_args}, # stride order for {operand}")

    def _gen_signature(self, code, with_block_pointer, tile_size_params):
        """Emit the kernel's ``def`` with ``tile_size_params`` as tile-size params."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_operand_params(code)
            ndim = self.ndim
            if ndim > 0:
                self._gen_stride_params(code, with_block_pointer)
                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")
                # number of tasks, used to compute mask
                code.writeline("num_tasks,")
                # tile size & tiles_per_cta, gsl style
                code.writeline("tiles_per_cta: int,")
                code.writeline(f"{tile_size_params},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature(self, code, with_block_pointer=False):
        """Signature of the n-d-tile kernel (one ``tile_size<i>`` per dim)."""
        tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(self.ndim))
        self._gen_signature(code, with_block_pointer, tile_sizes)

    def gen_signature_1d_tile(self, code):
        """Signature of the 1-d-tile kernel (a single ``tile_size``)."""
        self._gen_signature(code, False, "tile_size: tl.constexpr")

    # ------------------------------------------------------- body pieces

    def gen_num_tiles(self, code):
        """Emit the number of tiles along each dimension of the tile grid."""
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def _gen_compute(self, code):
        """Emit the call to the scalar function and unpack its outputs."""
        schema = self.fx
        inputs_to_scalar_fn = _cs(
            self.input_name(i) for i in range(schema.num_inputs())
        )
        outputs_to_scalar_fn = _cs(
            self.output_name(i) for i in range(schema.num_output_tensors())
        )
        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

    def _gen_tile_multi_index(self, code):
        """Decompose the flat ``tile_id`` into per-dim ``tile_id<i>`` (C order)."""
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(self.ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

    def _gen_grid_stride_loop(self, code, gen_tile_body):
        """Emit a loop in which program ``pid`` handles tiles
        ``pid + j * num_ctas`` for ``j`` in ``range(tiles_per_cta)``."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            gen_tile_body(code)

    def _gen_kernel_body(self, code, gen_one_tile, gen_gsl, with_num_tiles):
        """Emit the kernel body: rank-0 body, or both one-tile / gsl paths."""
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            if with_num_tiles:
                self.gen_num_tiles(code)
            # monolitic kernel: one_tile_per_cta, it may require a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                gen_one_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                gen_gsl(code)
            code.newline()
        return code

    # ------------------------------------------------------------ rank 0

    def gen_body_for_0d(self, code):
        """Body of a rank-0 kernel: load every input, compute, store every output."""
        schema = self.fx
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # -------------------------------------- n-d tile, with block pointers

    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Process the tile ``tile_id`` through block pointers."""
        ndim = self.ndim
        schema = self.fx

        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        self._gen_tile_multi_index(code)

        code.writeline("# tile offsets")
        for i in range(ndim):
            # Block pointers only support 32-bit offsets/block shapes (otherwise
            # Triton raises an AssertionError), hence the explicit cast.
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(tuple(f"in{i}_stride_order{j}" for j in range(ndim)))
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_with_bptr)

    # ----------------------------------- n-d tile, without block pointers

    def _broadcast_offsets(self, operand):
        """Element-offset expression of the current n-d tile for ``operand``."""
        ndim = self.ndim
        return " + ".join(
            f"offsets{j}{_broadcast_vec(j, ndim)} * {operand}_stride{j}"
            for j in range(ndim)
        )

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Process the tile ``tile_id`` with masked pointer arithmetic."""
        ndim = self.ndim
        schema = self.fx

        self._gen_tile_multi_index(code)

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks: per-dim bounds checks broadcast to the full tile
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        mask_combine = " & ".join(
            f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim)
        )
        code.writeline(f"mask = {mask_combine}")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = self._broadcast_offsets(f"in{i}")
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores; tl.store returns nothing, so its result is not assigned
        for i in range(schema.num_output_tensors()):
            offset_combine = self._broadcast_offsets(f"out{i}")
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_without_bptr(self, code):
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_without_bptr)

    # ------------------------------------------------------ n-d entries

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=True)
        return self._gen_kernel_body(
            code,
            self.gen_body_one_tile_per_cta_with_bptr,
            self.gen_body_gsl_with_bptr,
            with_num_tiles=True,
        )

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support, plain pointers."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=False)
        return self._gen_kernel_body(
            code,
            self.gen_body_one_tile_per_cta_without_bptr,
            self.gen_body_gsl_without_bptr,
            with_num_tiles=True,
        )

    def codegen_nd_tile(self, code):
        """Generate the n-d-tile kernel, honoring ``config.prefer_block_pointer``."""
        if self.config.prefer_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    # --------------------------------------------------------- 1-d tile

    def _flat_offsets(self, operand):
        """Element-offset expression from the reconstructed ``i<j>`` indices."""
        return " + ".join(f"i{j} * {operand}_stride{j}" for j in range(self.ndim))

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Process the flat tile ``tile_id`` of ``tile_size`` tasks."""
        ndim = self.ndim
        schema = self.fx

        # flat task ids of this tile; mask computed before tid is consumed
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction (C order)
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = self._flat_offsets(f"in{i}")
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores; tl.store returns nothing, so its result is not assigned
        for i in range(schema.num_output_tensors()):
            offset_combine = self._flat_offsets(f"out{i}")
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_1d_tile(self, code):
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_1d_tile)

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)
        return self._gen_kernel_body(
            code,
            self.gen_body_one_tile_per_cta_1d_tile,
            self.gen_body_gsl_1d_tile,
            with_num_tiles=False,
        )

739 

740 

741class WrapperGenerator: 

742 def __init__( 

743 self, 

744 function_schema: FunctionSchema, 

745 jit_fn_name: str, 

746 ndim: int, 

747 name: str, 

748 config: CodeGenConfig, 

749 ): 

750 self.fx = function_schema 

751 self.jit_fn_name = jit_fn_name 

752 self.ndim = ndim 

753 self.name = name 

754 self.config = config 

755 

756 def input_name(self, i): 

757 is_tensor = self.fx.is_tensor(i) 

758 name = "in" if is_tensor else "val" 

759 index = self.fx.input_index(i) 

760 return f"{name}{index}" 

761 

762 def output_name(self, i): 

763 return f"out{i}" 

764 

765 def gen_signature(self, code: IndentedBuffer): 

766 # TODO: check if triton handles constexprs transitively 

767 schema = self.fx 

768 params: List[str] = [] 

769 for i in range(schema.num_inputs()): 

770 if schema.is_tensor(i): 

771 params.append( 

772 f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]" 

773 ) 

774 else: 

775 arg_type = schema.input_type(i) 

776 if arg_type is not None: 

777 params.append(f"{self.input_name(i)}: {_type_name(arg_type)}") 

778 else: 

779 params.append(f"{self.input_name(i)}") 

780 # NOTE: [the wrapper's signature and rules for passing parameters ] 

781 # input params: must be passed by position, since the names are renamed to 

782 # in0, in1, val0, val1, ..., So passing these parameters by keyword is wierd 

783 # So we enforce that these parameters must be passed by position. 

784 # maybe we can fix it later 

785 # output parameters: must be passed by keyword, since the scalar function 

786 # do not have output parameters(think of it as some scalar function, output 

787 # parameter does not make sense in this case.) They are added to allow destination 

788 # passing style API. Output parameter is convenient in cases where we want 

789 # to use some pre-defiend outputs(especially when they are some views of other 

790 # tensors). We emphasize that these parameters are added in-addition, we enforce 

791 # that they be passed by keyword. After all, out0, out1, ... does not mismatch 

792 # names form the scalar function, since it does not have output parameters. 

793 params.append("/") 

794 params.append("*") # output params must be passed by keyword 

795 

796 for i in range(schema.num_output_tensors()): 

797 params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]") 

798 code.writeline(f"def {self.name}({_cs(params)}): ") 

799 

800 def gen_docstring(self, code: IndentedBuffer): 

801 schema = self.fx 

802 doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""' 

803 code.writeline(doc) 

804 

805 def gen_same_shape_check(self, code: IndentedBuffer): 

806 schema: FunctionSchema = self.fx 

807 params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [ 

808 f"out{i}.shape" for i in range(schema.num_output_tensors()) 

809 ] 

810 check: str = " == ".join(params) 

811 code.writeline(f"assert {check}, 'operand shapes mismatch'") 

812 

813 def gen_task_partition(self, code: IndentedBuffer): 

814 code.writeline("# task partitioning") 

815 ndim = self.ndim 

816 if ndim == 0: 

817 code.writeline("num_warps = 1") 

818 code.writeline("num_ctas = 1") 

819 else: 

820 code.writeline("shape = out0.shape") 

821 code.writeline("num_tasks = out0.numel()") 

822 code.writeline("if num_tasks == 0:") 

823 with code.indent(): 

824 self.gen_return(code) 

825 max_tile_size = self.config.max_tile_size 

826 major, _ = get_device_capability() 

827 if self.name.find("fill_scalar") != -1 and major >= 9: 

828 code.writeline("tile_sizes = tuple([64])") 

829 else: 

830 code.writeline( 

831 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)" 

832 ) 

833 code.writeline("tile_size = math.prod(tile_sizes)") 

834 code.writeline( 

835 "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))" 

836 ) 

837 

838 if self.name.find("fill_scalar") != -1 and major >= 9: 

839 code.writeline("num_ctas = num_tiles") 

840 else: 

841 max_grid_size0 = self.config.max_grid_size[0] 

842 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)") 

843 

844 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)") 

845 code.writeline("num_warps = heuristics_for_num_warps(tile_size)") 

846 code.writeline("one_tile_per_cta = tiles_per_cta==1") 

847 code.writeline("grid = (num_ctas, 1, 1)") 

848 

849 def gen_task_partition_1d(self, code: IndentedBuffer): 

850 code.writeline("# task partitioning") 

851 ndim = self.ndim 

852 if ndim == 0: 

853 code.writeline("num_warps = 1") 

854 code.writeline("num_ctas = 1") 

855 else: 

856 code.writeline("shape = out0.shape") 

857 code.writeline("num_tasks = out0.numel()") 

858 code.writeline("if num_tasks == 0:") 

859 with code.indent(): 

860 self.gen_return(code) 

861 max_tile_size = self.config.max_tile_size 

862 

863 major, _ = get_device_capability() 

864 if self.name.find("fill_scalar") != -1 and major >= 9: 

865 code.writeline("tile_sizes = tuple([64])") 

866 else: 

867 code.writeline( 

868 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)" 

869 ) 

870 

871 code.writeline("tile_size = tile_sizes[0]") 

872 code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)") 

873 

874 if self.name.find("fill_scalar") != -1 and major >= 9: 

875 code.writeline("num_ctas = num_tiles") 

876 else: 

877 max_grid_size0 = self.config.max_grid_size[0] 

878 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)") 

879 

880 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)") 

881 code.writeline("num_warps = heuristics_for_num_warps(tile_size)") 

882 code.writeline("one_tile_per_cta = tiles_per_cta==1") 

883 code.writeline("grid = (num_ctas, 1, 1)") 

884 

885 def gen_kernel_launch( 

886 self, 

887 code: IndentedBuffer, 

888 ): 

889 schema = self.fx 

890 ndim = self.ndim 

891 

892 with_block_pointer = self.config.prefer_block_pointer 

893 

894 code.writeline("# kernel launch") 

895 for i in range(schema.num_input_tensors()): 

896 code.writeline(f"in{i}_strides = in{i}.stride()") 

897 if not with_block_pointer: 

898 continue 

899 if ndim >= 2: # where ndim is 1, we don't need to compute stride order 

900 code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)") 

901 else: 

902 code.writeline(f"in{i}_stride_order = (0,)") 

903 for i in range(schema.num_output_tensors()): 

904 code.writeline(f"out{i}_strides = out{i}.stride()") 

905 if not with_block_pointer: 

906 continue 

907 if ndim >= 2: 

908 code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)") 

909 else: 

910 code.writeline(f"out{i}_stride_order = (0,)") 

911 

912 code.writeline("with torch_device_fn.device(in0.device.index):") 

913 with code.indent(): 

914 code.writeline(f"{self.jit_fn_name}[grid](") 

915 with code.indent(): 

916 params = [] 

917 # NOTE: WRAP 

918 for i in range(schema.num_inputs()): 

919 if schema.is_tensor(i): 

920 params.append(f"{self.input_name(i)}") 

921 else: 

922 params.append(self.input_name(i)) 

923 for i in range(schema.num_output_tensors()): 

924 params.append(f"{self.output_name(i)}") 

925 

926 code.writeline(f"{_cs(params)},") 

927 

928 if ndim > 0: 

929 for i in range(schema.num_input_tensors()): 

930 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim)) 

931 code.writeline(f"{s}, # stride for in{i}") 

932 if not with_block_pointer: 

933 continue 

934 order = ", ".join( 

935 f"in{i}_stride_order[{j}]" for j in range(ndim) 

936 ) 

937 code.writeline(f"{order}, # stride order for in{i}") 

938 

939 for i in range(schema.num_output_tensors()): 

940 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim)) 

941 code.writeline(f"{s}, # stride for out{i}") 

942 if not with_block_pointer: 

943 continue 

944 order = ", ".join( 

945 f"out{i}_stride_order[{j}]" for j in range(ndim) 

946 ) 

947 code.writeline(f"{order}, # stride orderfor out{i}") 

948 

949 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim)) 

950 code.writeline(f"{shape_args}, # task indexing space") 

951 code.writeline("num_tasks, # num tasks") 

952 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta") 

953 for i in range(ndim): 

954 code.writeline(f"tile_size{i}=tile_sizes[{i}],") 

955 code.writeline("one_tile_per_cta=one_tile_per_cta,") 

956 code.writeline("num_warps=num_warps,") 

957 code.writeline(")") 

958 

959 def gen_kernel_launch_1d( 

960 self, 

961 code: IndentedBuffer, 

962 ): 

963 schema = self.fx 

964 ndim = self.ndim 

965 

966 code.writeline("# kernel launch") 

967 for i in range(schema.num_input_tensors()): 

968 code.writeline(f"in{i}_strides = in{i}.stride()") 

969 for i in range(schema.num_output_tensors()): 

970 code.writeline(f"out{i}_strides = out{i}.stride()") 

971 

972 code.writeline("with torch_device_fn.device(in0.device.index):") 

973 with code.indent(): 

974 code.writeline(f"{self.jit_fn_name}[grid](") 

975 with code.indent(): 

976 params = [] 

977 # NOTE: WRAP 

978 for i in range(schema.num_inputs()): 

979 if schema.is_tensor(i): 

980 params.append(f"{self.input_name(i)}") 

981 else: 

982 params.append(self.input_name(i)) 

983 for i in range(schema.num_output_tensors()): 

984 params.append(f"{self.output_name(i)}") 

985 

986 code.writeline(f"{_cs(params)},") 

987 

988 if ndim > 0: 

989 for i in range(schema.num_input_tensors()): 

990 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim)) 

991 code.writeline(f"{s}, # stride for in{i}") 

992 for i in range(schema.num_output_tensors()): 

993 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim)) 

994 code.writeline(f"{s}, # stride for out{i}") 

995 

996 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim)) 

997 code.writeline(f"{shape_args}, # task indexing space") 

998 code.writeline("num_tasks, # num tasks") 

999 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta") 

1000 code.writeline("tile_size=tile_size,") 

1001 code.writeline("one_tile_per_cta=one_tile_per_cta,") 

1002 code.writeline("num_warps=num_warps,") 

1003 code.writeline(")") 

1004 

1005 def gen_return(self, code: IndentedBuffer): 

1006 return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors())) 

1007 code.writeline(f"return {return_exprs}") 

1008 

1009 def codegen_nd_tile(self, code): 

1010 self.gen_signature(code) 

1011 

1012 with code.indent(): 

1013 self.gen_docstring(code) 

1014 self.gen_same_shape_check(code) 

1015 self.gen_task_partition(code) 

1016 self.gen_kernel_launch(code) 

1017 self.gen_return(code) 

1018 code.newline() 

1019 return code 

1020 

1021 def codegen_1d_tile(self, code): 

1022 self.gen_signature(code) 

1023 

1024 with code.indent(): 

1025 self.gen_docstring(code) 

1026 self.gen_same_shape_check(code) 

1027 self.gen_task_partition_1d(code) 

1028 self.gen_kernel_launch_1d(code) 

1029 self.gen_return(code) 

1030 code.newline() 

1031 return code 

1032 

1033 

1034class ModuleGenerator: 

1035 def __init__( 

1036 self, 

1037 function_schema: FunctionSchema, 

1038 scalar_fn: triton.JITFunction, 

1039 ndim: int, 

1040 jit_fn_name: str, 

1041 wrapper_name: str, 

1042 config: CodeGenConfig, 

1043 ): 

1044 self.config = config 

1045 self.wrapper_gen = WrapperGenerator( 

1046 function_schema, jit_fn_name, ndim, wrapper_name, config 

1047 ) 

1048 self.kernel_gen = KernelGenerator( 

1049 function_schema, scalar_fn, ndim, jit_fn_name, config 

1050 ) 

1051 

1052 @staticmethod 

1053 def generate_imports(code: IndentedBuffer) -> IndentedBuffer: 

1054 code.writeline("import math") 

1055 code.writeline("from typing import Union") 

1056 code.writeline("import torch") 

1057 code.writeline("import triton") 

1058 code.writeline("from triton import language as tl") 

1059 code.newline() 

1060 code.writeline("from flag_gems.utils.shape_utils import (") 

1061 code.writeline(" heuristics_for_tile_size,") 

1062 code.writeline(" heuristics_for_num_warps,") 

1063 code.writeline(" stride_order,") 

1064 code.writeline(")") 

1065 code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer") 

1066 code.writeline("from flag_gems.utils.libentry import libentry") 

1067 code.writeline("from flag_gems.utils import triton_lang_extension as tle") 

1068 code.writeline("from flag_gems.runtime import torch_device_fn") 

1069 code.newline() 

1070 code.newline() 

1071 return code 

1072 

1073 def codegen(self, code: IndentedBuffer): 

1074 # the only runtime determined factor is the rank of the task space 

1075 code = self.generate_imports(code) 

1076 if self.config.prefer_1d_tile: 

1077 code = self.wrapper_gen.codegen_1d_tile(code) 

1078 code = self.kernel_gen.codegen_1d_tile(code) 

1079 else: 

1080 code = self.wrapper_gen.codegen_nd_tile(code) 

1081 code = self.kernel_gen.codegen_nd_tile(code) 

1082 return code 

1083 

1084 

1085class PointwiseDynamicFunction: 

1086 """Utility to generate function for general pointwise operation. It generate wrapper & JITFunction 

1087 which are specialized according to the rank of the task space(the broadcasted shape of all input tensors). 

1088 The generated code are written out to the cache directory (defaults to ~/.flaggems). 

1089 """ 

1090 

1091 def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None): 

1092 self.fx = op_desc 

1093 

1094 assert isinstance(scalar_fn, JITFunction) 

1095 self._scalar_fn = scalar_fn 

1096 self._scalar_fn_cache_key = scalar_fn.cache_key 

1097 self.pid = os.getpid() 

1098 

1099 self.config: CodeGenConfig = config or get_codegen_config() 

1100 

1101 # instantiated & cached overloads 

1102 self.overloads: Mapping[str, Callable] = {} 

1103 

1104 def __call__(self, *args, **kwargs): 

1105 # inputs must be passed by position, outputs must be passed by keyword 

1106 ndim, args, kwargs = self.prepare_args(*args, **kwargs) 

1107 overload = self.instantiate(ndim) 

1108 out = overload(*args, **kwargs) 

1109 # NOTE: overload keeps the type of outputs: 

1110 # if a pre-defiend output is a Tensor or StridedBuffer, the corresponding 

1111 # output is also a Tensor StridedBuffer, respectively 

1112 # since prepare_args Wraps all the arguments, the outputs are all StridedBuffer 

1113 # but if manually instantiated overload is directly called, take care of 

1114 # that manually 

1115 return self._unwrap(out) 

1116 

1117 @staticmethod 

1118 def use_fast_path(tensors): 

1119 return all_the_same_shape(tensors) and ( 

1120 all_c_contiguous(tensors) 

1121 or ( 

1122 all_the_same_stride(tensors) 

1123 and torch.ops.aten.is_non_overlapping_and_dense(tensors[0]) 

1124 ) 

1125 ) 

1126 

1127 def prepare_args(self, *args, **kwargs): 

1128 # output allocation(when needed) 

1129 # task simplification & task-rank infernece & input-output reinterpretation 

1130 schema = self.fx 

1131 outputs_that_need_allocation: List[int] = [] 

1132 out_tensors = [] 

1133 for i in range(schema.num_output_tensors()): 

1134 k = f"out{i}" 

1135 if k in kwargs: 

1136 out_tensors.append(kwargs[k]) 

1137 else: 

1138 outputs_that_need_allocation.append(i) 

1139 # input arguments must be passed by position 

1140 if schema._is_tensor is not None: 

1141 if not check_tensor_attributes(args, (schema._is_tensor)): 

1142 raise ValueError( 

1143 "Input arguments must be passed by position, and the corresponding dtype must be specified." 

1144 ) 

1145 in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)] 

1146 

1147 # output dtype promotions 

1148 outputs_dtypes_for_allocation = [] 

1149 for i in outputs_that_need_allocation: 

1150 *arg_indices, method = schema._promotion_methods[i] 

1151 promote_args = (args[j] for j in arg_indices) 

1152 _, dtype = type_promotion(*promote_args, type_promotion=method) 

1153 outputs_dtypes_for_allocation.append(dtype) 

1154 

1155 tensors = out_tensors + in_tensors 

1156 INT32_MAX = torch.iinfo(torch.int32).max 

1157 if tensors[0].numel() > INT32_MAX: 

1158 self.config.prefer_block_pointer = False 

1159 if self.use_fast_path(tensors): # dimension collapse & use physical ordering 

1160 allocated_outputs = [ 

1161 torch.empty_like(tensors[0], dtype=dtype) 

1162 for dtype in outputs_dtypes_for_allocation 

1163 ] 

1164 task_shape = (tensors[0].numel(),) 

1165 strides = (1,) 

1166 ndim = 1 

1167 args = tuple( 

1168 ( 

1169 StridedBuffer(item, task_shape, strides) 

1170 if schema.is_tensor(i) 

1171 else item 

1172 ) 

1173 for i, item in enumerate(args) 

1174 ) 

1175 kwargs = { 

1176 k: StridedBuffer(item, task_shape, strides) 

1177 for k, item in kwargs.items() 

1178 } 

1179 for seq_id, output_id in enumerate(outputs_that_need_allocation): 

1180 kwargs[f"out{output_id}"] = StridedBuffer( 

1181 allocated_outputs[seq_id], task_shape, strides 

1182 ) 

1183 else: 

1184 # a simple strategy: all the undefined tensors will follow the first 

1185 # tensor that is not broadcated, no attempts to simplify task, no reordering, 

1186 # no dimenion collapsing 

1187 shapes = tuple(item.shape for item in in_tensors) 

1188 

1189 task_shape = broadcast_shapes(shapes) 

1190 

1191 if out_tensors: 

1192 for index, item in enumerate(out_tensors): 

1193 if list(item.shape) != list(task_shape): 

1194 raise RuntimeError( 

1195 f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!" 

1196 ) 

1197 # output arguments must not have internal overlapping for pointwise operation 

1198 if has_internal_overlapping(item) == MemOverlap.Yes: 

1199 raise RuntimeError( 

1200 "Pointwise Input arguments should not have internal overlapping." 

1201 ) 

1202 

1203 ndim = len(task_shape) 

1204 for item in tensors: 

1205 if item.shape == task_shape: 

1206 allocated_outputs = [ 

1207 torch.empty_like(item, dtype=dtype) 

1208 for dtype in outputs_dtypes_for_allocation 

1209 ] 

1210 break 

1211 else: # nobreak 

1212 device = tensors[0].device 

1213 allocated_outputs = [ 

1214 torch.empty(task_shape, dtype=dtype, device=device) 

1215 for dtype in outputs_dtypes_for_allocation 

1216 ] 

1217 args = tuple( 

1218 ( 

1219 StridedBuffer( 

1220 item, 

1221 task_shape, 

1222 broadcasted_stride(item.shape, item.stride(), task_shape), 

1223 ) 

1224 if schema.is_tensor(i) 

1225 else item 

1226 ) 

1227 for i, item in enumerate(args) 

1228 ) 

1229 kwargs = { 

1230 k: StridedBuffer( 

1231 item, 

1232 task_shape, 

1233 broadcasted_stride(item.shape, item.stride(), task_shape), 

1234 ) 

1235 for k, item in kwargs.items() 

1236 } 

1237 for seq_id, output_id in enumerate(outputs_that_need_allocation): 

1238 item = allocated_outputs[seq_id] 

1239 kwargs[f"out{output_id}"] = StridedBuffer( 

1240 item, 

1241 task_shape, 

1242 broadcasted_stride(item.shape, item.stride(), task_shape), 

1243 ) 

1244 return (ndim, args, kwargs) 

1245 

1246 def _unwrap(self, tensors): 

1247 # unwrap StridedBuffer to get Tensor 

1248 if self.fx.num_output_tensors() == 1: 

1249 item = tensors 

1250 return item.unwrap() 

1251 return tuple(item.unwrap() for item in tensors) 

1252 

1253 def instantiate(self, ndim): 

1254 # NOTE: manually instantiated overload does not have `prepare_args` as 

1255 # preprocessing, so you have to manually allocate output and make sure that 

1256 # the inputs & ouputs actually fits the manually instantiated overload 

1257 key = f"{ndim}_{self.config.prefer_block_pointer}" 

1258 if key in self.overloads: 

1259 return self.overloads[key] 

1260 

1261 code = IndentedBuffer() 

1262 

1263 scalar_fn_name = self._scalar_fn.__name__ 

1264 kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}" 

1265 wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}" 

1266 module_gen = ModuleGenerator( 

1267 self.fx, 

1268 self._scalar_fn, 

1269 ndim, 

1270 kernel_name, 

1271 wrapper_name, 

1272 self.config, 

1273 ) 

1274 module_gen.codegen(code) 

1275 

1276 # NOTE: [why write the generated code to a file] 

1277 # triton uses inpsect to get the source of the jitted function, which requires 

1278 # that the source code can be found by inspect 

1279 # We write it into a file, since inspect cannot find the source of functions dynamically 

1280 # created via exec string. We can help inspect to find the source by hacking linecache 

1281 # library, but we find generating a module simpler, since we can generating 2 functions 

1282 # the kernel and the wrapper, and the wrapper calls the kernel. 

1283 file_name = ( 

1284 f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_" 

1285 f"{'1d_tile_' if self.config.prefer_1d_tile else ''}" 

1286 f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}" 

1287 ".py" 

1288 ) 

1289 

1290 file_path = code_cache_dir() / file_name 

1291 write_atomic(file_path, code.getvalue()) 

1292 

1293 # load 

1294 spec = importlib.util.spec_from_file_location( 

1295 f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}", 

1296 file_path, 

1297 ) 

1298 m = importlib.util.module_from_spec(spec) 

1299 # do not expose it to sys.modules 

1300 # sys.modules["_add_module"] = m 

1301 

1302 # NOTE: [why not import the scalar function] 

1303 # we do not re-import the scalar function, although the generated kernel **calls** it 

1304 # Since a function's __name__ may be changed, from the module where it is defined import its 

1305 # __name__ is not same; Also the same may be rebind to something else, importing via name 

1306 # cannot guarantee that scalar function is imported. 

1307 # So we copy the scalar function and its __globals__ to the generated module to do this 

1308 # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime 

1309 spec.loader.exec_module(m) 

1310 m.__dict__.update(self._scalar_fn.__globals__) 

1311 m.__dict__[self._scalar_fn.__name__] = self._scalar_fn 

1312 

1313 overload = getattr(m, wrapper_name) 

1314 self.overloads[key] = overload 

1315 return overload 

1316 

1317 

1318def pointwise_dynamic( 

1319 f: Optional[JITFunction] = None, 

1320 *, 

1321 num_inputs: Optional[int] = None, 

1322 is_tensor: Optional[List[bool]] = None, 

1323 dtypes: Optional[List[Optional[type]]] = None, 

1324 num_outputs: Optional[int] = None, 

1325 promotion_methods: Optional[Tuple[int, ...]] = None, 

1326 config: Optional[CodeGenConfig] = None, 

1327): 

1328 def decorator(fn): 

1329 nonlocal num_inputs 

1330 if (num_inputs is None) and (is_tensor is None) and (dtypes is None): 

1331 num_inputs = len(fn.arg_names) 

1332 op_desc = FunctionSchema( 

1333 num_inputs=num_inputs, 

1334 is_tensor=is_tensor, 

1335 dtypes=dtypes, 

1336 num_outputs=num_outputs, 

1337 promotion_methods=promotion_methods, 

1338 ) 

1339 return PointwiseDynamicFunction(op_desc, fn, config) 

1340 

1341 if f is not None: 

1342 return decorator(f) 

1343 return decorator