Coverage for src/flag_gems/runtime/backend/_cambricon/utils/pointwise_dynamic.py: 0%

1003 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import importlib 

2import os 

3from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple 

4 

5import torch 

6import triton 

7from triton.runtime.jit import JITFunction 

8 

9from flag_gems.utils.code_cache import code_cache_dir 

10from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

11from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config 

12from flag_gems.utils.shape_utils import ( 

13 MemOverlap, 

14 all_c_contiguous, 

15 all_the_same_shape, 

16 all_the_same_stride, 

17 broadcast_shapes, 

18 broadcasted_stride, 

19 check_tensor_attributes, 

20 has_internal_overlapping, 

21) 

22from flag_gems.utils.tensor_wrapper import StridedBuffer 

23from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion 

24 

25 

26# ------------------ Operation Description --------------------------- 

27def _type_name(type) -> str: 

28 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object" 

29 if type in (bool, int, float, str): 

30 return type.__name__ 

31 if isinstance(type, torch.dtype): 

32 return str(type) 

33 return str(type) 

34 

35 

36def _check_typed_list(container, type): 

37 for item in container: 

38 assert isinstance(item, type) 

39 

40 

41def _check_sized_list(container, size): 

42 assert len(container) == size 

43 

44 

45def _tuple_content(strings: Sequence[str]) -> str: 

46 # comma separated list 

47 if len(strings) == 0: 

48 return "" 

49 if len(strings) == 1: 

50 return f"{strings[0]}," 

51 else: 

52 return ", ".join(strings) 

53 

54 

55def _cs(strings: Iterable[str]) -> str: 

56 return ", ".join(strings) 

57 

58 

59def _broadcast_vec(i, ndim): 

60 axes = [":" if j == i else "None" for j in range(ndim)] 

61 return f"[{_cs(axes)}]" 

62 

63 

class FunctionSchema:
    """Describes the calling convention of a pointwise operation.

    Tracks how many inputs/outputs the operation has, which inputs are
    tensors (vs. non-tensor scalar arguments), optional dtypes for the
    non-tensor inputs, and the type-promotion rule used for each output.
    At least one of ``num_inputs``, ``is_tensor`` or ``dtypes`` must be
    supplied; the missing pieces are inferred from the others.
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        else:
            self._promotion_methods = self.canonicalize_promotion_methods(
                promotion_methods
            )

        # Infer arity and per-argument metadata from whichever of
        # (num_inputs, is_tensor, dtypes) were provided.
        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            if is_tensor is not None:
                _check_sized_list(is_tensor, self._num_inputs)
                self._is_tensor = is_tensor
            else:
                # A None dtype marks a tensor input; a concrete dtype marks a scalar.
                self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace the trailing string in each (arg_indices..., method) tuple
        with the corresponding ELEMENTWISE_TYPE_PROMOTION_KIND member."""
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        # num of arguments, outputs not included
        return self._num_inputs

    def num_outputs(self):
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Whether input *arg_id* is a tensor argument."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Declared dtype of input *arg_id*, or None when unconstrained."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """The canonicalized promotion method tuple for output *i*."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        # All outputs are tensors.
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a human-readable signature string.

        With ``outputs_in_arg=True`` the outputs are also listed among the
        arguments, annotated as mutated (``a{i}!``), mirroring out= variants.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            else:
                if dtype is None:
                    input_types.append("scalar")
                else:
                    input_types.append(_type_name(dtype))

        output_types = []

        if outputs_in_arg:
            for i in range(self.num_outputs()):
                # Bug fix: use the loop index `i` so each output gets its own
                # in-place annotation; previously `a{1}` rendered every output
                # as "a1!".
                output_types.append(f"StridedBuffer(a{i}!)")
            input_types.extend(output_types)
        else:
            for _ in range(self.num_outputs()):
                output_types.append("StridedBuffer")
        sig = f"Pointwise: {', '.join(input_types)} -> {', '.join(output_types)}"
        return sig

    def _compute_input_id(self):
        """Map each input position to its index within its own group
        (tensor inputs counted separately from non-tensor inputs)."""
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Group-local index (see _compute_input_id) for input *idx*."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)

220 

221 

222class KernelGenerator: 

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        """Prepare a generator that renders a rank-`rank` pointwise Triton
        kernel named `name`, applying `scalar_fn` elementwise according to
        `function_schema`, under the codegen options in `config`."""
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        # Cache the scalar function's identifiers; the generated kernel body
        # calls it by `fn_name`, and its source is quoted into the module.
        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

239 

    def gen_import_function(self, code: IndentedBuffer):
        """Embed the scalar function's source into the generated module as a
        quoted string literal (for reference when inspecting generated code)."""
        code.writeline(f'"""Quoted source of {self.fn_name}:')
        code.writemultiline(self.fn.src)
        code.writeline('"""')
        code.newline()

245 

246 def gen_config_prune(self, code): 

247 code.writeline("def config_prune(configs, named_args, **kwargs):") 

248 with code.indent(): 

249 code.writeline("new_configs = []") 

250 code.writeline("elem_sizes = []") 

251 for i in range(self.fx.num_input_tensors()): 

252 code.writeline( 

253 f"elem_sizes.append(named_args['in{i}_ptr'].dtype.itemsize)" 

254 ) 

255 for i in range(self.fx.num_output_tensors()): 

256 code.writeline( 

257 f"elem_sizes.append(named_args['out{i}_ptr'].dtype.itemsize)" 

258 ) 

259 

260 code.writeline("max_elem_size = max(elem_sizes)") 

261 shape = ", ".join(f"s{i}" for i in range(self.ndim)) 

262 named_shape = ", ".join(f"named_args['s{i}']" for i in range(self.ndim)) 

263 code.writeline(f"{shape} = {named_shape}") 

264 tile_sizes = ", ".join(f"tile_size{i}" for i in range(self.ndim)) 

265 tile_size_dict = ", ".join( 

266 f"'tile_size{i}': tile_size{i}" for i in range(self.ndim) 

267 ) 

268 

269 code.writeline("if max_elem_size < 8:") 

270 with code.indent(): 

271 code.writeline("max_tile_sizes = [1024, 2048, 4096, 8192, 16000]") 

272 code.writeline("for max_tile_size in max_tile_sizes:") 

273 with code.indent(): 

274 code.writeline( 

275 f"({tile_sizes}, ) = heuristics_for_tile_size(max_tile_size, {shape})" 

276 ) 

277 code.writeline( 

278 f"new_configs.append(triton.Config({{{tile_size_dict}}}, num_stages=3, num_warps=1))" 

279 ) 

280 code.writeline("else:") 

281 with code.indent(): 

282 code.writeline("max_tile_sizes = [1024, 2048, 4096, 8000]") 

283 code.writeline("for max_tile_size in max_tile_sizes:") 

284 with code.indent(): 

285 code.writeline( 

286 f"({tile_sizes}, ) = heuristics_for_tile_size(max_tile_size, {shape})" 

287 ) 

288 code.writeline( 

289 f"new_configs.append(triton.Config({{{tile_size_dict}}}, num_stages=3, num_warps=1))" 

290 ) 

291 

292 code.writeline("return new_configs") 

293 code.newline() 

294 code.newline() 

295 

    def gen_decorators(self, code):
        """Emit the decorator stack above the generated kernel definition.

        For ranks 1-4 with nd tiling (i.e. `prefer_1d_tile` off) this writes a
        `config_prune` helper followed by `@libentry()` and a rank-specific
        `@libtuner(...)` block, then the `@triton.jit` decorator.  Non-tensor
        arguments are excluded from specialization.
        """
        # The config_prune helper must be defined before the @libtuner
        # decorator that references it.
        if self.ndim in [1, 2, 3, 4] and (not self.config.prefer_1d_tile):
            self.gen_config_prune(code)

        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # Names of the scalar (non-tensor) kernel parameters, quoted for
            # inclusion in the autotune key.
            non_tensor_arg_names = ", ".join(
                f"'val{i}'" for i in range(num_non_tensor_args)
            )

        shapes = ", ".join(f"'s{i}'" for i in range(self.ndim))
        stride_args = []
        for i in range(self.fx.num_input_tensors()):
            stride_args.append(_cs(f"'in{i}_stride{j}'" for j in range(self.ndim)))
        for i in range(self.fx.num_output_tensors()):
            stride_args.append(_cs(f"'out{i}_stride{j}'" for j in range(self.ndim)))

        code.writeline("@libentry()")
        # NOTE(review): the four rank-specific @libtuner blocks below are
        # near-duplicates; only the seed configs and the autotune key differ
        # (rank 1 keys on 'num_tasks' instead of the shape symbols).  They are
        # kept verbatim here; consolidation would be a behavioral risk to the
        # emitted text.
        if self.ndim == 1 and (not self.config.prefer_1d_tile):
            code.writeline("@libtuner(")
            with code.indent():
                code.writeline("configs=[")
                with code.indent():
                    code.writeline(
                        "triton.Config({'tile_size0': 1024}, num_stages=3, num_warps=1),"
                    )
                    code.writeline(
                        "triton.Config({'tile_size0': 2048}, num_stages=3, num_warps=1),"
                    )
                code.writeline("],")
                if num_non_tensor_args > 0:
                    code.writeline(
                        f"key=['num_tasks', {_cs(stride_args)}, {non_tensor_arg_names}],"
                    )
                else:
                    code.writeline(f"key=['num_tasks', {_cs(stride_args)}],")
                code.writeline("prune_configs_by={'early_config_prune': config_prune},")
                output_params = [
                    f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
                ]
                output_elements = ", ".join(f"'{name}'" for name in output_params)
                # restore_value: outputs may be overwritten during tuning runs.
                code.writeline(f"restore_value=[{output_elements}],")
            code.writeline(")")

        if self.ndim == 2 and (not self.config.prefer_1d_tile):
            code.writeline("@libtuner(")
            with code.indent():
                code.writeline("configs=[")
                with code.indent():
                    code.writeline(
                        "triton.Config({'tile_size0': 1, 'tile_size1': 1024}, num_stages=3, num_warps=1),"
                    )
                    code.writeline(
                        "triton.Config({'tile_size0': 1, 'tile_size1': 2048}, num_stages=3, num_warps=1),"
                    )
                code.writeline("],")
                if num_non_tensor_args > 0:
                    code.writeline(
                        f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}],"
                    )
                else:
                    code.writeline(f"key=[{shapes}, {_cs(stride_args)}],")
                code.writeline("prune_configs_by={'early_config_prune': config_prune},")
                output_params = [
                    f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
                ]
                output_elements = ", ".join(f"'{name}'" for name in output_params)
                code.writeline(f"restore_value=[{output_elements}],")
            code.writeline(")")

        if self.ndim == 3 and (not self.config.prefer_1d_tile):
            code.writeline("@libtuner(")
            with code.indent():
                code.writeline("configs=[")
                with code.indent():
                    code.writeline(
                        """
                        triton.Config({'tile_size0': 1, 'tile_size1': 1, 'tile_size2': 1024}, num_stages=3, num_warps=1),
                        """
                    )
                    code.writeline(
                        """
                        triton.Config({'tile_size0': 1, 'tile_size1': 1, 'tile_size2': 2048}, num_stages=3, num_warps=1),
                        """
                    )
                code.writeline("],")
                if num_non_tensor_args > 0:
                    code.writeline(
                        f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}],"
                    )
                else:
                    code.writeline(f"key=[{shapes}, {_cs(stride_args)}],")
                code.writeline("prune_configs_by={'early_config_prune': config_prune},")
                output_params = [
                    f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
                ]
                output_elements = ", ".join(f"'{name}'" for name in output_params)
                code.writeline(f"restore_value=[{output_elements}],")
            code.writeline(")")

        if self.ndim == 4 and (not self.config.prefer_1d_tile):
            code.writeline("@libtuner(")
            with code.indent():
                code.writeline("configs=[")
                with code.indent():
                    code.writeline(
                        """
                        triton.Config({'tile_size0': 1,'tile_size1': 1,'tile_size2': 1,'tile_size3': 1024},num_stages=3,num_warps=1),
                        """
                    )
                    code.writeline(
                        """
                        triton.Config({'tile_size0': 1,'tile_size1': 1,'tile_size2': 1,'tile_size3': 2048},num_stages=3,num_warps=1),
                        """
                    )
                code.writeline("],")
                if num_non_tensor_args > 0:
                    code.writeline(
                        f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}],"
                    )
                else:
                    code.writeline(f"key=[{shapes}, {_cs(stride_args)}],")
                code.writeline("prune_configs_by={'early_config_prune': config_prune},")
                output_params = [
                    f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
                ]
                output_elements = ", ".join(f"'{name}'" for name in output_params)
                code.writeline(f"restore_value=[{output_elements}],")
            code.writeline(")")

        if num_non_tensor_args > 0:
            # we do not specialize non tensor args since they are passed into the inlined function
            # which means that their values may not deserve specialization
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

433 

434 def input_name(self, i): 

435 is_tensor = self.fx.is_tensor(i) 

436 name = "in" if is_tensor else "val" 

437 index = self.fx.input_index(i) 

438 return f"{name}{index}" 

439 

440 def output_name(self, i): 

441 return f"out{i}" 

442 

    def gen_signature(self, code, with_block_pointer=False):
        """Emit the generated kernel's `def` line and parameter list for the
        nd-tile variant.

        Parameter order: input pointers / scalar values, output pointers,
        per-tensor strides (plus stride-order and zero-stride constexpr flags
        when `with_block_pointer`), the task-space shape, `num_tasks`, the
        optional `FALLBACK_BPTR` flag, per-axis tile sizes, and (for rank > 4)
        the grid-stride-loop controls.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: inputs ptrs & non tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output ptrs
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor arguments
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")
                    if with_block_pointer:
                        # Block-pointer path additionally needs the stride
                        # order and a constexpr flag per axis marking
                        # broadcast (zero-stride) dimensions.
                        stride_order_args = _cs(
                            f"in{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(f"{stride_order_args}, # stride order for in{i}")
                        zero_stride_args = _cs(
                            f"in{i}_zero_stride{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{zero_stride_args}, # zero stride flag for in{i}"
                        )

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")
                    if with_block_pointer:
                        stride_order_args = _cs(
                            f"out{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{stride_order_args}, # stride order for out{i}"
                        )
                        zero_stride_args = _cs(
                            f"out{i}_zero_stride{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{zero_stride_args}, # zero stride flag for out{i}"
                        )

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

            # number of tasks, used to compute mask
            code.writeline("num_tasks,")
            if self.config.prefer_block_pointer:
                # Runtime flag: fall back from block pointers to plain
                # indexing when offsets would exceed int32 range.
                code.writeline("FALLBACK_BPTR: tl.constexpr,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                code.writeline(f"{tile_sizes},")
                if ndim > 4:
                    # For rank > 4, grid-stride-loop controls are explicit
                    # parameters (lower ranks compute tiles_per_cta inline).
                    code.writeline("tiles_per_cta: int,")
                    code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

528 

    def gen_signature_1d_tile(self, code):
        """Emit the generated kernel's `def` line and parameter list for the
        1d-tile variant (flat task indexing, single `tile_size`).

        Same parameter ordering as `gen_signature` but without the per-axis
        tile sizes or block-pointer stride-order/zero-stride flags.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: inputs ptrs & non tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output ptrs
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor arguments
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

            # number of tasks, used to compute mask
            code.writeline("num_tasks,")

            if self.config.prefer_block_pointer:
                # Kept for a uniform launch interface; the 1d-tile body uses
                # plain indexing regardless.
                code.writeline("FALLBACK_BPTR: tl.constexpr,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

589 

590 def gen_num_tiles(self, code): 

591 # tile-grid size 

592 ndim = self.ndim 

593 for i in range(ndim): 

594 if i < ndim: 

595 code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})") 

596 

    def gen_body_for_0d(self, code):
        """Emit the body for a rank-0 (scalar) kernel: one unmasked load per
        input tensor, one call to the scalar function, one store per output."""
        schema = self.fx
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            # The explicit .to(element_ty) cast works around a bool dtype bug
            # (see the emitted comment below).
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

627 

628 # nd tile 1d grid kernel with block pointer 

    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Emit the kernel body that processes the single tile `tile_id`,
        preferring block pointers.

        The emitted code branches on the runtime constexpr `FALLBACK_BPTR`:
        the primary path uses `tl.make_block_ptr` (int32 offsets), the
        fallback path uses plain pointer arithmetic with per-axis offsets,
        masks and broadcasting — needed because block pointers only support
        32-bit offsets.
        """
        ndim = self.ndim
        schema = self.fx

        # block pointer for each operand
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        # reconstruct pid multi index
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

        # cta_offsets
        code.writeline("# tile offsets")

        # Because block pointer only support `tl.int32` indexing, when max offsets
        # of ptrs exceeding 2^31, we should fallback it to noraml indexing method.
        code.writeline("if not FALLBACK_BPTR:")
        with code.indent():
            for i in range(ndim):
                # Or else: AssertionError: Block pointers only support 32 bit
                # `offsets/block_shape`, add a `.to(tl.int32)` or use regular indexing
                # for 64 bit support
                code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

            # loads
            code.writeline("# loads")
            for i in range(schema.num_input_tensors()):
                strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
                order = _tuple_content(
                    tuple(f"in{i}_stride_order{j}" for j in range(ndim))
                )

                # Broadcast inputs: force stride to 0 on zero-stride axes so
                # the block pointer repeats the value along them.
                for j in range(ndim):
                    code.writeline(f"if in{i}_zero_stride{j}:")
                    with code.indent():
                        code.writeline(f"in{i}_stride{j} = 0")

                code.writeline(
                    f"in{i}_bptr = tl.make_block_ptr("
                    f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
                )

                code.writeline(
                    f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                )
            code.newline()

            # compute
            # TODO: sepearate this part
            inputs_to_scalar_fn = [
                self.input_name(i) for i in range(schema.num_inputs())
            ]
            outputs_to_scalar_fn = [
                self.output_name(i) for i in range(schema.num_output_tensors())
            ]
            inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
            outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

            code.writeline("# compute")
            code.writeline(
                f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
            )
            code.newline()

            # stores
            for i in range(schema.num_output_tensors()):
                strides = _tuple_content(
                    tuple(f"out{i}_stride{j}" for j in range(ndim))
                )
                order = _tuple_content(
                    tuple(f"out{i}_stride_order{j}" for j in range(ndim))
                )

                for j in range(ndim):
                    code.writeline(f"if out{i}_zero_stride{j}:")
                    with code.indent():
                        code.writeline(f"out{i}_stride{j} = 0")

                code.writeline(
                    f"out{i}_bptr = tl.make_block_ptr("
                    f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
                )

                code.writeline(
                    f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
                )
        code.writeline("else:")
        with code.indent():
            # Fallback path: plain pointer arithmetic with explicit offsets,
            # masks, and axis broadcasting via _broadcast_vec.
            # offsets
            for i in range(ndim):
                code.writeline(
                    f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
                )

            # masks
            for i in range(ndim):
                code.writeline(f"mask{i} = offsets{i} < s{i}")
            masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
            mask_combine = " & ".join(masks)
            code.writeline(f"mask = {mask_combine}")

            # loads
            code.writeline("# loads")
            for i in range(schema.num_input_tensors()):
                strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
                order = _tuple_content(
                    tuple(f"in{i}_stride_order{j}" for j in range(ndim))
                )

                for j in range(ndim):
                    code.writeline(f"if in{i}_zero_stride{j}:")
                    with code.indent():
                        code.writeline(f"in{i}_stride{j} = 0")
                offsets = tuple(
                    f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                    for j in range(ndim)
                )
                offset_combine = " + ".join(offsets)
                code.writeline(
                    f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
                )

            code.newline()

            # compute
            inputs_to_scalar_fn = [
                self.input_name(i) for i in range(schema.num_inputs())
            ]
            outputs_to_scalar_fn = [
                self.output_name(i) for i in range(schema.num_output_tensors())
            ]
            inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
            outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

            code.writeline("# compute")
            code.writeline(
                f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
            )
            code.newline()

            # store
            for i in range(schema.num_output_tensors()):
                strides = _tuple_content(
                    tuple(f"out{i}_stride{j}" for j in range(ndim))
                )
                order = _tuple_content(
                    tuple(f"out{i}_stride_order{j}" for j in range(ndim))
                )

                for j in range(ndim):
                    code.writeline(f"if out{i}_zero_stride{j}:")
                    with code.indent():
                        code.writeline(f"out{i}_stride{j} = 0")

                offsets = tuple(
                    f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                    for j in range(ndim)
                )
                offset_combine = " + ".join(offsets)
                # NOTE(review): the emitted line assigns tl.store's result to
                # `in{i}` — looks like a copy-paste from the load path; it is
                # harmless in the generated kernel but worth confirming
                # against upstream.
                code.writeline(
                    f"in{i} = tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
                )

801 

802 def gen_body_gsl_with_bptr(self, code): 

803 code.writeline("num_ctas = tle.num_programs(0)") 

804 if self.ndim <= 4: 

805 num_tiles = " * ".join([f"num_tiles{i}" for i in range(self.ndim)]) 

806 code.writeline( 

807 f"tiles_per_cta = tl.cdiv({num_tiles}, num_ctas).to(tl.int32)" 

808 ) 

809 code.writeline("for j in range(0, tiles_per_cta):") 

810 with code.indent(): 

811 code.writeline("tile_id = pid + j * num_ctas") 

812 self.gen_body_one_tile_per_cta_with_bptr(code) 

813 

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Emit the kernel body that processes the single tile `tile_id`
        using plain pointer arithmetic: per-axis offsets and masks are built
        with `_broadcast_vec` for axis broadcasting, then each input is
        loaded, the scalar function is applied, and each output is stored."""
        ndim = self.ndim
        schema = self.fx

        # reconstruct pid multi index
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
        mask_combine = " & ".join(masks)
        code.writeline(f"mask = {mask_combine}")

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )

        code.newline()

        # compute
        # TODO: sepearate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            # NOTE(review): the emitted line assigns tl.store's result to
            # `in{i}` — looks like a copy-paste from the load path; harmless
            # in the generated kernel but worth confirming against upstream.
            code.writeline(
                f"in{i} = tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

882 

883 def gen_body_gsl_without_bptr(self, code): 

884 code.writeline("num_ctas = tle.num_programs(0)") 

885 if self.ndim <= 4: 

886 num_tiles = " * ".join([f"num_tiles{i}" for i in range(self.ndim)]) 

887 code.writeline(f"tiles_per_cta = tl.cdiv({num_tiles}, num_ctas)") 

888 code.writeline("for j in range(0, tiles_per_cta):") 

889 with code.indent(): 

890 code.writeline("tile_id = pid + j * num_ctas") 

891 self.gen_body_one_tile_per_cta_without_bptr(code) 

892 

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        # Order matters: quoted source, decorators, then the def line.
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=True)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
            if self.ndim > 4:
                code.writeline("if one_tile_per_cta: # monolitic kernel style")
                with code.indent():
                    code.writeline("tile_id = pid")
                    self.gen_body_one_tile_per_cta_with_bptr(code)
                # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
                code.writeline("else: # grid-stride-loop style kernel")
                with code.indent():
                    self.gen_body_gsl_with_bptr(code)
            else:
                # For rank <= 4 the grid-stride loop is emitted unconditionally.
                self.gen_body_gsl_with_bptr(code)
        code.newline()
        return code

922 

    def codegen_nd_tile_without_bptr(self, code):
        """Generate the nd-tile, 1d-grid kernel (grid-stride-loop capable)
        using plain pointer indexing instead of block pointers."""
        # Order matters: quoted source, decorators, then the def line.
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=False)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
            if self.ndim > 4:
                code.writeline("if one_tile_per_cta: # monolitic kernel style")
                with code.indent():
                    code.writeline("tile_id = pid")
                    self.gen_body_one_tile_per_cta_without_bptr(code)
                # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
                code.writeline("else: # grid-stride-loop style kernel")
                with code.indent():
                    self.gen_body_gsl_without_bptr(code)
            else:
                # For rank <= 4 the grid-stride loop is emitted unconditionally.
                self.gen_body_gsl_without_bptr(code)
        code.newline()
        return code

951 

952 def codegen_nd_tile(self, code): 

953 use_block_pointer = self.config.prefer_block_pointer 

954 if use_block_pointer: 

955 self.codegen_nd_tile_with_bptr(code) 

956 else: 

957 self.codegen_nd_tile_without_bptr(code) 

958 return code 

959 

960 def gen_body_one_tile_per_cta_1d_tile(self, code): 

961 ndim = self.ndim 

962 schema = self.fx 

963 

964 # tile id 

965 code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)") 

966 code.writeline("mask = tid < num_tasks") 

967 

968 # multi index reconstruction 

969 for i in reversed(range(ndim)): 

970 if i > 0: 

971 code.writeline(f"i{i} = tid % s{i}") 

972 code.writeline(f"tid //= s{i}") 

973 else: 

974 code.writeline(f"i{i} = tid") 

975 code.newline() 

976 

977 # loads 

978 code.writeline("# loads") 

979 for i in range(schema.num_input_tensors()): 

980 offsets = tuple(f"i{j} * in{i}_stride{j}" for j in range(ndim)) 

981 offset_combine = " + ".join(offsets) 

982 code.writeline( 

983 f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)" 

984 ) 

985 

986 code.newline() 

987 

988 # compute 

989 # TODO: sepearate this part 

990 inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())] 

991 outputs_to_scalar_fn = [ 

992 self.output_name(i) for i in range(schema.num_output_tensors()) 

993 ] 

994 inputs_to_scalar_fn = _cs(inputs_to_scalar_fn) 

995 outputs_to_scalar_fn = _cs(outputs_to_scalar_fn) 

996 

997 code.writeline("# compute") 

998 code.writeline( 

999 f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})" 

1000 ) 

1001 code.newline() 

1002 

1003 # stores 

1004 for i in range(schema.num_output_tensors()): 

1005 offsets = tuple(f"i{j} * out{i}_stride{j}" for j in range(ndim)) 

1006 offset_combine = " + ".join(offsets) 

1007 code.writeline( 

1008 f"in{i} = tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)" 

1009 ) 

1010 

1011 def gen_body_gsl_1d_tile(self, code): 

1012 code.writeline("num_ctas = tle.num_programs(0)") 

1013 code.writeline("for j in range(0, tiles_per_cta):") 

1014 with code.indent(): 

1015 code.writeline("tile_id = pid + j * num_ctas") 

1016 self.gen_body_one_tile_per_cta_1d_tile(code) 

1017 

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support.

        Unlike the nd-tile variants, both the monolithic (one-tile-per-cta)
        body and the grid-stride-loop body are always emitted, selected at
        launch time by the generated kernel's ``one_tile_per_cta`` argument.
        """
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)

        # Rank-0 task space: a dedicated body with no tiling/indexing logic.
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            # Monolithic kernel (one tile per cta) may require a very large
            # grid, hence the runtime-selected alternative below.
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_1d_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_1d_tile(code)
        code.newline()
        return code

1044 

1045 

class WrapperGenerator:
    """Generates the python-side wrapper function of a pointwise op.

    The wrapper validates shapes, partitions the task space into tiles,
    computes launch parameters, and launches the generated Triton kernel
    named ``jit_fn_name``.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        jit_fn_name: str,
        ndim: int,
        name: str,
        config: CodeGenConfig,
    ):
        # function_schema: describes inputs/outputs of the scalar function
        # jit_fn_name: name of the generated Triton kernel the wrapper launches
        # ndim: rank of the task space this wrapper is specialized for
        # name: name of the generated wrapper function
        self.fx = function_schema
        self.jit_fn_name = jit_fn_name
        self.ndim = ndim
        self.name = name
        self.config = config

    def input_name(self, i):
        """Parameter name for input i: tensors become in<k>, scalars val<k>."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        """Parameter name for output tensor i."""
        return f"out{i}"

    def gen_signature(self, code: IndentedBuffer):
        """Emit the wrapper's ``def`` line: positional-only inputs, keyword-only outputs."""
        # TODO: check if triton handles constexprs transitively
        schema = self.fx
        params: List[str] = []
        for i in range(schema.num_inputs()):
            if schema.is_tensor(i):
                params.append(
                    f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]"
                )
            else:
                arg_type = schema.input_type(i)
                if arg_type is not None:
                    params.append(f"{self.input_name(i)}: {_type_name(arg_type)}")
                else:
                    params.append(f"{self.input_name(i)}")
        # NOTE: [the wrapper's signature and rules for passing parameters]
        # input params: must be passed by position, since the names are renamed to
        # in0, in1, val0, val1, ..., so passing these parameters by keyword is weird.
        # So we enforce that these parameters must be passed by position.
        # Maybe we can fix it later.
        # output parameters: must be passed by keyword, since the scalar function
        # does not have output parameters (think of it as some scalar function; an
        # output parameter does not make sense in this case). They are added to allow
        # destination-passing-style API. An output parameter is convenient in cases
        # where we want to use some pre-defined outputs (especially when they are
        # views of other tensors). We emphasize that these parameters are added
        # in-addition; we enforce that they be passed by keyword. After all,
        # out0, out1, ... do not mismatch names from the scalar function, since it
        # does not have output parameters.
        params.append("/")
        params.append("*")  # output params must be passed by keyword

        for i in range(schema.num_output_tensors()):
            params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]")
        code.writeline(f"def {self.name}({_cs(params)}): ")

    def gen_docstring(self, code: IndentedBuffer):
        """Emit a docstring into the generated wrapper showing its signature."""
        schema = self.fx
        doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""'
        code.writeline(doc)

    def gen_same_shape_check(self, code: IndentedBuffer):
        """Emit an assert that all tensor operands share one shape (chained ==)."""
        schema: FunctionSchema = self.fx
        params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [
            f"out{i}.shape" for i in range(schema.num_output_tensors())
        ]
        check: str = " == ".join(params)
        code.writeline(f"assert {check}, 'operand shapes mismatch'")

    def gen_fallback_bptr(self, code: IndentedBuffer):
        """Emit a helper that detects tensors too large for block pointers.

        The generated predicate returns True when any dimension size reaches
        2**31 (2147483648), i.e. exceeds the int32 range block pointers use.
        Empty tensors never trigger the fallback.
        """
        code.writeline("def fallback_bptr(t):")
        with code.indent():
            code.writeline("ndim = t.dim()")
            code.writeline("sizes = t.size()")
            code.writeline("if t.numel() == 0:")
            with code.indent():
                code.writeline("return False")
            code.writeline("for i in range(ndim):")
            with code.indent():
                code.writeline("if sizes[i] >= 2147483648:")
                with code.indent():
                    code.writeline("return True")
            code.writeline("return False")
        code.newline()
        code.newline()

    def gen_task_partition(self, code: IndentedBuffer):
        """Emit nd-tile task partitioning: tile sizes, grid, gsl parameters."""
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size
            code.writeline(
                f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
            )
            code.writeline("tile_size = math.prod(tile_sizes)")
            code.writeline(
                "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
            )
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0} // num_warps, num_tiles)")

        # NOTE(review): for ndim == 0 the line emitted below references
        # num_tiles, which the rank-0 branch above does not define in the
        # generated wrapper — presumably rank-0 never reaches this path at
        # runtime; confirm against prepare_args/callers.
        code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
        code.writeline("one_tile_per_cta = tiles_per_cta==1")
        if self.config.prefer_block_pointer:
            # Emit a runtime scan over all operands: fall back to plain
            # offsets if any tensor is too large for block pointers.
            code.writeline("FALLBACK_BPTR = False")
            inputs = ",".join(
                [f"in{i}" for i in range(self.fx.num_input_tensors())]
            )
            outputs = ",".join(
                [f"out{i}" for i in range(self.fx.num_output_tensors())]
            )
            code.writeline(f"all_tensors = [{inputs}, {outputs}]")
            code.writeline("for t in all_tensors:")
            with code.indent():
                code.writeline("if fallback_bptr(t):")
                with code.indent():
                    code.writeline("FALLBACK_BPTR = True")
                    code.writeline("break")
        if ndim > 0 and ndim <= 4:
            # Low rank: grid size can be computed from the kernel's meta
            # parameters, so emit a lambda-style grid.
            max_grid_size0 = self.config.max_grid_size[0]
            dynamic_num_tiles = " * ".join(
                f"triton.cdiv(meta['s{i}'], meta['tile_size{i}'])" for i in range(ndim)
            )
            code.writeline(
                f"grid = lambda meta: (min({max_grid_size0} // num_warps, {dynamic_num_tiles}), )"
            )
        else:
            code.writeline("grid = (num_ctas, 1, 1)")

    def gen_task_partition_1d(self, code: IndentedBuffer):
        """Emit 1d-tile task partitioning: flat tiling over num_tasks."""
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size
            code.writeline(
                f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
            )
            code.writeline("tile_size = tile_sizes[0]")
            code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

        # NOTE(review): same rank-0 caveat as gen_task_partition — the lines
        # below reference num_tiles/tile_size that the rank-0 branch does not
        # define; confirm rank-0 never reaches this path.
        code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
        code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
        code.writeline("one_tile_per_cta = tiles_per_cta==1")
        if self.config.prefer_block_pointer:
            code.writeline("FALLBACK_BPTR = False")
            inputs = ",".join(
                [f"in{i}" for i in range(self.fx.num_input_tensors())]
            )
            outputs = ",".join(
                [f"out{i}" for i in range(self.fx.num_output_tensors())]
            )
            code.writeline(f"all_tensors = [{inputs}, {outputs}]")
            code.writeline("for t in all_tensors:")
            with code.indent():
                code.writeline("if fallback_bptr(t):")
                with code.indent():
                    code.writeline("FALLBACK_BPTR = True")
                    code.writeline("break")
        code.writeline("grid = (num_ctas, 1, 1)")

    def gen_kernel_launch(
        self,
        code: IndentedBuffer,
    ):
        """Emit the nd-tile kernel launch: stride prep + the [grid](...) call."""
        schema = self.fx
        ndim = self.ndim

        with_block_pointer = self.config.prefer_block_pointer

        code.writeline("# kernel launch")
        for i in range(schema.num_input_tensors()):
            code.writeline(f"in{i}_strides = in{i}.stride()")
            if not with_block_pointer:
                continue
            if ndim >= 2:  # where ndim is 1, we don't need to compute stride order
                code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)")
            else:
                code.writeline(f"in{i}_stride_order = (0,)")
            code.writeline(
                f"in{i}_zero_strides = [True if s == 0 else False for s in in{i}_strides]"
            )
        for i in range(schema.num_output_tensors()):
            code.writeline(f"out{i}_strides = out{i}.stride()")
            if not with_block_pointer:
                continue
            if ndim >= 2:
                code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)")
            else:
                code.writeline(f"out{i}_stride_order = (0,)")
            code.writeline(
                f"out{i}_zero_strides = [True if s == 0 else False for s in out{i}_strides]"
            )

        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                params = []
                # NOTE: WRAP
                for i in range(schema.num_inputs()):
                    if schema.is_tensor(i):
                        params.append(f"{self.input_name(i)}")
                    else:
                        params.append(self.input_name(i))
                for i in range(schema.num_output_tensors()):
                    params.append(f"{self.output_name(i)}")

                code.writeline(f"{_cs(params)},")

                if ndim > 0:
                    for i in range(schema.num_input_tensors()):
                        s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for in{i}")
                        if with_block_pointer:
                            order = ", ".join(
                                f"in{i}_stride_order[{j}]" for j in range(ndim)
                            )
                            code.writeline(f"{order}, # stride order for in{i}")
                            zero_strides = ", ".join(
                                f"in{i}_zero_strides[{j}]" for j in range(ndim)
                            )
                            code.writeline(
                                f"{zero_strides}, # zero stride flag for in{i}"
                            )

                    for i in range(schema.num_output_tensors()):
                        s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for out{i}")
                        if with_block_pointer:
                            order = ", ".join(
                                f"out{i}_stride_order[{j}]" for j in range(ndim)
                            )
                            # NOTE(review): "orderfor" is a typo in the emitted
                            # comment text; harmless in the generated source.
                            code.writeline(f"{order}, # stride orderfor out{i}")
                            zero_strides = ", ".join(
                                f"out{i}_zero_strides[{j}]" for j in range(ndim)
                            )
                            code.writeline(
                                f"{zero_strides}, # zero stride flag for out{i}"
                            )

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                    code.writeline("num_tasks, # num tasks")
                if self.config.prefer_block_pointer:
                    code.writeline("FALLBACK_BPTR=FALLBACK_BPTR,")
                if ndim > 4:
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                if ndim == 0 or ndim > 4:
                    for i in range(ndim):
                        code.writeline(f"tile_size{i}=tile_sizes[{i}],")
                if ndim > 4:
                    code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_kernel_launch_1d(
        self,
        code: IndentedBuffer,
    ):
        """Emit the 1d-tile kernel launch; no stride-order/zero-stride prep."""
        schema = self.fx
        ndim = self.ndim

        code.writeline("# kernel launch")
        for i in range(schema.num_input_tensors()):
            code.writeline(f"in{i}_strides = in{i}.stride()")
        for i in range(schema.num_output_tensors()):
            code.writeline(f"out{i}_strides = out{i}.stride()")

        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                params = []
                # NOTE: WRAP
                for i in range(schema.num_inputs()):
                    if schema.is_tensor(i):
                        params.append(f"{self.input_name(i)}")
                    else:
                        params.append(self.input_name(i))
                for i in range(schema.num_output_tensors()):
                    params.append(f"{self.output_name(i)}")

                code.writeline(f"{_cs(params)},")

                if ndim > 0:
                    for i in range(schema.num_input_tensors()):
                        s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for in{i}")
                    for i in range(schema.num_output_tensors()):
                        s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for out{i}")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                    code.writeline("num_tasks, # num tasks")
                if self.config.prefer_block_pointer:
                    code.writeline("FALLBACK_BPTR=FALLBACK_BPTR,")
                code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                code.writeline("tile_size=tile_size,")
                code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_return(self, code: IndentedBuffer):
        """Emit the wrapper's return statement (all output tensors)."""
        return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors()))
        code.writeline(f"return {return_exprs}")

    def codegen_nd_tile(self, code):
        """Generate the complete nd-tile wrapper function into ``code``."""
        if self.config.prefer_block_pointer:
            self.gen_fallback_bptr(code)
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition(code)
            self.gen_kernel_launch(code)
            self.gen_return(code)
        code.newline()
        return code

    def codegen_1d_tile(self, code):
        """Generate the complete 1d-tile wrapper function into ``code``."""
        if self.config.prefer_block_pointer:
            self.gen_fallback_bptr(code)
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition_1d(code)
            self.gen_kernel_launch_1d(code)
            self.gen_return(code)
        code.newline()
        return code

1402 

1403 

class ModuleGenerator:
    """Assembles a complete generated module: import preamble, the python
    wrapper, and the Triton kernel."""

    # Preamble written at the top of every generated module.
    # ``None`` entries map to blank lines (code.newline()).
    _PREAMBLE = (
        "import math",
        "from typing import Union",
        "import torch",
        "import triton",
        "from triton import language as tl",
        None,
        "from flag_gems.utils.shape_utils import (",
        " heuristics_for_tile_size,",
        " heuristics_for_num_warps,",
        " stride_order,",
        ")",
        "from flag_gems.utils.tensor_wrapper import StridedBuffer",
        "from flag_gems.utils.libentry import libentry, libtuner",
        "from flag_gems.utils import triton_lang_extension as tle",
        "from flag_gems.runtime import torch_device_fn",
        None,
        None,
    )

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        # One generator for the python-side wrapper, one for the Triton kernel.
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
        """Write the import preamble of the generated module into ``code``."""
        for entry in ModuleGenerator._PREAMBLE:
            if entry is None:
                code.newline()
            else:
                code.writeline(entry)
        return code

    def codegen(self, code: IndentedBuffer):
        """Emit the whole module; the rank of the task space is the only
        runtime-determined factor, the tiling flavor comes from config."""
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            emitters = (
                self.wrapper_gen.codegen_1d_tile,
                self.kernel_gen.codegen_1d_tile,
            )
        else:
            emitters = (
                self.wrapper_gen.codegen_nd_tile,
                self.kernel_gen.codegen_nd_tile,
            )
        for emit in emitters:
            code = emit(code)
        return code

1453 

1454 

class PointwiseDynamicFunction:
    """Utility to generate function for general pointwise operation. It generates wrapper & JITFunction
    which are specialized according to the rank of the task space(the broadcasted shape of all input tensors).
    The generated code are written out to the cache directory (defaults to ~/.flaggems).
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        # op_desc: schema describing the scalar function's inputs/outputs
        # scalar_fn: the triton JITFunction applied elementwise
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads, keyed by rank + block-pointer flag
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Apply the op: inputs by position, outputs (optional) by keyword."""
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        # NOTE: overload keeps the type of outputs:
        # if a pre-defined output is a Tensor or StridedBuffer, the corresponding
        # output is also a Tensor or StridedBuffer, respectively.
        # Since prepare_args wraps all the arguments, the outputs are all
        # StridedBuffer; but if a manually instantiated overload is directly
        # called, take care of that manually.
        return self._unwrap(out)

    @staticmethod
    def use_fast_path(tensors):
        """True when all operands share shape and layout, so the task can be
        collapsed to a single contiguous 1d task space."""
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    def prepare_args(self, *args, **kwargs):
        """Normalize arguments for a generated overload.

        Performs output allocation (when needed), task simplification,
        task-rank inference and input/output reinterpretation. Returns
        ``(ndim, args, kwargs)`` with every tensor wrapped in StridedBuffer.

        Raises:
            ValueError: if tensor inputs are not passed by position as the
                schema requires.
            RuntimeError: if a pre-defined output has the wrong shape or has
                internal memory overlap.
        """
        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # a simple strategy: all the undefined tensors will follow the first
            # tensor that is not broadcast; no attempts to simplify the task,
            # no reordering, no dimension collapsing
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operation
                    # BUGFIX: the message previously said "Input arguments"
                    # although this check runs on output tensors.
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise output arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            # allocate outputs like the first full-shape operand, if any
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak: every operand is broadcast, allocate fresh
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        """Unwrap StridedBuffer(s) back into Tensor(s)."""
        if self.fx.num_output_tensors() == 1:
            item = tensors
            return item.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def instantiate(self, ndim):
        """Generate (or fetch from cache) the overload specialized for ``ndim``.

        NOTE: a manually instantiated overload does not have ``prepare_args``
        as preprocessing, so you have to manually allocate outputs and make
        sure the inputs & outputs actually fit the overload.
        """
        key = f"{ndim}_{self.config.prefer_block_pointer}"
        if key in self.overloads:
            return self.overloads[key]

        code = IndentedBuffer()

        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"
        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which
        # requires that the source code can be found by inspect. We write it
        # into a file, since inspect cannot find the source of functions
        # dynamically created via exec'ing a string. We could help inspect by
        # hacking the linecache library, but generating a module is simpler:
        # it holds both the kernel and the wrapper, and the wrapper calls the
        # kernel.
        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )

        file_path = code_cache_dir() / file_name
        write_atomic(file_path, code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # do not expose it to sys.modules

        # NOTE: [why not import the scalar function]
        # We do not re-import the scalar function, although the generated
        # kernel calls it. A function's __name__ may have been changed, so
        # importing it by name from its defining module is not reliable; the
        # name may also have been rebound. So we copy the scalar function and
        # its __globals__ into the generated module instead.
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload
        return overload

1683 

1684 

def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Decorator factory turning a scalar triton function into a dynamic
    pointwise op. Usable bare (``@pointwise_dynamic``) or with keyword
    configuration (``@pointwise_dynamic(num_outputs=2, ...)``)."""

    def decorator(fn):
        nonlocal num_inputs
        # Infer the arity from the scalar function when nothing is specified.
        if (num_inputs is None) and (is_tensor is None) and (dtypes is None):
            num_inputs = len(fn.arg_names)
        schema = FunctionSchema(
            num_inputs=num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(schema, fn, config)

    return decorator if f is None else decorator(f)