Coverage for src/flag_gems/utils/pointwise_dynamic_cpp_compat.py: 0%
862 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import importlib
2import os
3from dataclasses import dataclass
4from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple
6import torch
7import triton
8from triton.runtime.jit import JITFunction
10from flag_gems.utils.code_cache import code_cache_dir
11from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
12from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config
13from flag_gems.utils.device_info import get_device_capability
14from flag_gems.utils.shape_utils import (
15 MemOverlap,
16 all_c_contiguous,
17 all_the_same_shape,
18 all_the_same_stride,
19 broadcast_shapes,
20 broadcasted_stride,
21 check_tensor_attributes,
22 has_internal_overlapping,
23)
24from flag_gems.utils.tensor_wrapper import StridedBuffer
25from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion
28# ------------------ Operation Description ---------------------------
29def _type_name(type) -> str:
30 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object"
31 if type in (bool, int, float, str):
32 return type.__name__
33 if isinstance(type, torch.dtype):
34 return str(type)
35 return str(type)
38def _check_typed_list(container, type):
39 for item in container:
40 assert isinstance(item, type)
43def _check_sized_list(container, size):
44 assert len(container) == size
47def _tuple_content(strings: Sequence[str]) -> str:
48 # comma separated list
49 if len(strings) == 0:
50 return ""
51 if len(strings) == 1:
52 return f"{strings[0]},"
53 else:
54 return ", ".join(strings)
57def _cs(strings: Iterable[str]) -> str:
58 return ", ".join(strings)
def _broadcast_vec(i, ndim):
    """Return an indexing suffix that keeps axis ``i`` and inserts ``None``
    (a new axis) at every other position of an ``ndim``-axis index.

    Example: ``_broadcast_vec(1, 3) == "[None, :, None]"``. Used to broadcast
    a 1-D per-axis vector (offsets / masks) to ``ndim`` dimensions.
    """
    entries = (":" if axis == i else "None" for axis in range(ndim))
    return "[" + _cs(entries) + "]"
class FunctionSchema:
    """Describe the argument layout of a pointwise scalar function.

    For each input argument it records whether it is a tensor and, optionally,
    its declared Python type. It also records the number of outputs and one
    (canonicalized) type-promotion rule per output.
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema.

        At least one of ``num_inputs``, ``is_tensor`` or ``dtypes`` must be
        given. Missing pieces are derived as follows:

        * ``is_tensor`` defaults to all ``True`` when ``num_inputs`` is given.
          When only ``dtypes`` is given, an argument counts as a tensor iff its
          dtype is ``None``.
        * ``dtypes`` defaults to all ``None``.
        * ``num_outputs`` defaults to ``len(promotion_methods)``.

        Args:
            promotion_methods: one ``(*arg_indices, method_name)`` item per
                output; required.

        Raises:
            ValueError: if ``promotion_methods`` is missing, or if none of
                ``num_inputs`` / ``is_tensor`` / ``dtypes`` is given.
            AssertionError: if a list has the wrong element type or length,
                or if there are no inputs or no outputs.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # is_tensor is necessarily None here (handled by the branch above),
            # so infer it: an argument without a declared type is a tensor.
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Map each ``(*arg_indices, method_name)`` item to
        ``(*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method_name])``."""
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Number of input arguments (tensors and non-tensors); outputs excluded."""
        return self._num_inputs

    def num_outputs(self):
        """Number of outputs."""
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Whether input argument ``arg_id`` is a tensor."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Declared type of input argument ``arg_id`` (``None`` if undeclared)."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Canonicalized promotion rule of output ``i``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        """Number of tensor inputs."""
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        """Number of output tensors (every output is a tensor)."""
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        """Number of non-tensor inputs."""
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a human-readable signature, e.g.
        ``"Pointwise: StridedBuffer, scalar -> StridedBuffer"``.

        With ``outputs_in_arg=True`` the outputs are also listed among the
        arguments. Each output is annotated with its own alias ``a0!``,
        ``a1!``, ... (presumably modelled on PyTorch's ``Tensor(a!)`` schema
        notation for mutated arguments).
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        if outputs_in_arg:
            # One distinct alias per output (previously every output was "a1!").
            output_types = [
                f"StridedBuffer(a{i}!)" for i in range(self.num_outputs())
            ]
            input_types.extend(output_types)
        else:
            output_types = ["StridedBuffer"] * self.num_outputs()
        return f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'

    def _compute_input_id(self):
        """Map each input position to its index among inputs of the same kind.

        Example: ``is_tensor = [True, False, True]`` gives ``[0, 0, 1]``. The
        first tensor is ``in0``, the first non-tensor is ``val0``, and the
        second tensor is ``in1``.
        """
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Index of input ``idx`` among inputs of the same kind (see ``_compute_input_id``)."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)
class KernelGenerator:
    """Generate the source of a Triton kernel that applies ``scalar_fn`` pointwise.

    The generated kernel's parameters are, in order:
    * one pointer per input tensor and one value per non-tensor input, in
      schema order;
    * one pointer per output;
    * for rank > 0: per-operand strides (plus stride orders for the
      block-pointer variant), the task-space shape ``s0..``, ``num_tasks``,
      ``tiles_per_cta``, the tile size(s) and ``one_tile_per_cta``.

    Each ``gen_*`` / ``codegen_*`` method appends lines to ``code`` (an
    ``IndentedBuffer``); nothing is compiled here. There are two tiling styles:

    * nd-tile (``codegen_nd_tile``): one ``tile_size{i}`` per axis, with or
      without block pointers depending on ``config.prefer_block_pointer``.
    * 1d-tile (``codegen_1d_tile``): a single ``tile_size`` over the flattened
      task index; the per-axis index is rebuilt with ``%`` / ``//``.

    Both styles use a 1-D grid. Each program either handles one tile
    (``one_tile_per_cta``) or loops over ``tiles_per_cta`` tiles spaced
    ``num_programs(0)`` apart.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    # ------------------------------------------------------------ naming

    def input_name(self, i):
        """Name of input ``i`` in the kernel: ``in{k}`` for the k-th tensor
        input, ``val{k}`` for the k-th non-tensor input."""
        prefix = "in" if self.fx.is_tensor(i) else "val"
        return f"{prefix}{self.fx.input_index(i)}"

    def output_name(self, i):
        """Name of output value ``i`` in the kernel."""
        return f"out{i}"

    # ------------------------------------------------- header & signature

    def gen_import_function(self, code: IndentedBuffer):
        """Re-emit the scalar function's source (``self.fn.src``) under
        ``@triton.jit`` so the kernel can call it as ``self.fn_name``."""
        code.writeline("@triton.jit")
        code.writemultiline(self.fn.src)
        code.newline()

    def gen_decorators(self, code):
        """Emit ``@libentry()`` and the ``@triton.jit`` decorator of the kernel."""
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # we do not specialize non tensor args since they are passed into the inlined function
            # which means that their values may not deserve specialization
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def _gen_operand_params(self, code):
        """Emit input pointers / non-tensor values (schema order), then output pointers."""
        schema = self.fx
        for i in range(schema.num_inputs()):
            name = self.input_name(i)
            if schema.is_tensor(i):
                code.writeline(f"{name}_ptr: tl.tensor, # of tl.pointer_type")
                continue
            arg_type = schema.input_type(i)
            if arg_type is not None:
                code.writeline(f"{name}: {_type_name(arg_type)},")
            else:
                code.writeline(f"{name},")

        for i in range(schema.num_outputs()):
            code.writeline(f"{self.output_name(i)}_ptr: tl.tensor, # of tl.pointer_type")

    def _gen_indexing_params(self, code, with_stride_order):
        """Emit per-operand strides (and constexpr stride orders when
        ``with_stride_order``), the task space ``s0..`` and ``num_tasks``.
        Only valid for rank > 0."""
        schema = self.fx
        ndim = self.ndim
        operands = [f"in{i}" for i in range(schema.num_input_tensors())] + [
            f"out{i}" for i in range(schema.num_output_tensors())
        ]
        for operand in operands:
            stride_args = _cs(f"{operand}_stride{j}: int" for j in range(ndim))
            code.writeline(f"{stride_args}, # strides for {operand}")
            if with_stride_order:
                stride_order_args = _cs(
                    f"{operand}_stride_order{j}: tl.constexpr" for j in range(ndim)
                )
                code.writeline(f"{stride_order_args}, # stride order for {operand}")

        # task space, used to reconstruct multi index
        task_space_args = _cs(f"s{i}" for i in range(ndim))
        code.writeline(f"{task_space_args}, # task_space")

        # number of tasks, used to compute mask
        code.writeline("num_tasks,")

    def gen_signature(self, code, with_block_pointer=False):
        """Emit the nd-tile kernel signature (one ``tile_size{i}`` per axis)."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_operand_params(code)
            if self.ndim > 0:
                self._gen_indexing_params(code, with_stride_order=with_block_pointer)
                # tile size & tiles_per_cta, gsl style
                code.writeline("tiles_per_cta: int,")
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(self.ndim))
                code.writeline(f"{tile_sizes},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature_1d_tile(self, code):
        """Emit the 1d-tile kernel signature (a single ``tile_size``)."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_operand_params(code)
            if self.ndim > 0:
                self._gen_indexing_params(code, with_stride_order=False)
                # tile size & tiles_per_cta, gsl style
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    # ------------------------------------------------------- body pieces

    def gen_num_tiles(self, code):
        """Emit ``num_tiles{i} = tl.cdiv(s{i}, tile_size{i})`` for every axis."""
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def _gen_compute(self, code):
        """Emit ``out0, ... = <scalar fn>(in0, val0, ...)``."""
        schema = self.fx
        inputs = _cs(self.input_name(i) for i in range(schema.num_inputs()))
        outputs = _cs(self.output_name(i) for i in range(schema.num_output_tensors()))
        code.writeline("# compute")
        code.writeline(f"{outputs} = {self.fn_name}({inputs})")
        code.newline()

    def _gen_tile_multi_index(self, code):
        """Emit code that splits the flat ``tile_id`` into ``tile_id{i}`` per
        axis, last axis fastest. The emitted code overwrites ``tile_id``."""
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(self.ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

    def _gen_grid_stride_loop(self, code, gen_tile_body: Callable):
        """Emit a loop over ``tiles_per_cta`` tiles spaced ``num_programs(0)``
        apart, running ``gen_tile_body`` for each ``tile_id``."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            gen_tile_body(code)

    def gen_body_for_0d(self, code):
        """Emit the body of a rank-0 kernel: load, compute and store one element per operand."""
        schema = self.fx

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # nd tile 1d grid kernel with block pointer
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Emit the handling of tile ``tile_id`` using ``tl.make_block_ptr``."""
        ndim = self.ndim
        schema = self.fx

        # block pointer pieces shared by every operand
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        self._gen_tile_multi_index(code)

        code.writeline("# tile offsets")
        for i in range(ndim):
            # Or else: AssertionError: Block pointers only support 32 bit
            # `offsets/block_shape`, add a `.to(tl.int32)` or use regular indexing
            # for 64 bit support
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(tuple(f"in{i}_stride_order{j}" for j in range(ndim)))
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(tuple(f"out{i}_stride_order{j}" for j in range(ndim)))
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        """Grid-stride loop over ``gen_body_one_tile_per_cta_with_bptr``."""
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_with_bptr)

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Emit the handling of tile ``tile_id`` with explicit pointer arithmetic and masks."""
        ndim = self.ndim
        schema = self.fx

        self._gen_tile_multi_index(code)

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks: one per axis, broadcast against each other to the full tile
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
        code.writeline(f"mask = {' & '.join(masks)}")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores: emitted as plain statements. The previous `in{i} = tl.store(...)`
        # overwrote input variable in{i} with the store's (void) result.
        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_without_bptr(self, code):
        """Grid-stride loop over ``gen_body_one_tile_per_cta_without_bptr``."""
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_without_bptr)

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Emit the handling of flat tile ``tile_id`` (``tile_size`` consecutive task indices)."""
        ndim = self.ndim
        schema = self.fx

        # tile id; the mask is taken before `tid` is consumed by the loop below
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction (last axis fastest); overwrites `tid`
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores: plain statements, so the input variables in{i} are not overwritten
        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_1d_tile(self, code):
        """Grid-stride loop over ``gen_body_one_tile_per_cta_1d_tile``."""
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_1d_tile)

    # ------------------------------------------------------- whole kernel

    def _codegen_kernel(self, code, gen_signature, gen_prologue, gen_tile_body, gen_gsl_body):
        """Emit a full kernel.

        The kernel is the scalar function, the decorators and the signature,
        followed by either the rank-0 body or a branch on ``one_tile_per_cta``
        between a single tile and the grid-stride loop.
        ``gen_prologue`` (may be ``None``) runs right after ``pid`` is defined.
        """
        self.gen_import_function(code)
        self.gen_decorators(code)
        gen_signature(code)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            if gen_prologue is not None:
                gen_prologue(code)
            # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                gen_tile_body(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                gen_gsl_body(code)
            code.newline()
        return code

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        return self._codegen_kernel(
            code,
            gen_signature=lambda c: self.gen_signature(c, with_block_pointer=True),
            gen_prologue=self.gen_num_tiles,
            gen_tile_body=self.gen_body_one_tile_per_cta_with_bptr,
            gen_gsl_body=self.gen_body_gsl_with_bptr,
        )

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support, plain pointer arithmetic."""
        return self._codegen_kernel(
            code,
            gen_signature=lambda c: self.gen_signature(c, with_block_pointer=False),
            gen_prologue=self.gen_num_tiles,
            gen_tile_body=self.gen_body_one_tile_per_cta_without_bptr,
            gen_gsl_body=self.gen_body_gsl_without_bptr,
        )

    def codegen_nd_tile(self, code):
        """Generate the nd-tile kernel, choosing block pointers per ``config.prefer_block_pointer``."""
        if self.config.prefer_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        return self._codegen_kernel(
            code,
            gen_signature=self.gen_signature_1d_tile,
            gen_prologue=None,
            gen_tile_body=self.gen_body_one_tile_per_cta_1d_tile,
            gen_gsl_body=self.gen_body_gsl_1d_tile,
        )
class WrapperGenerator:
    """Generate the host-side Python wrapper that launches a pointwise kernel.

    The emitted wrapper does the following in order:
    1. checks that all tensor operands have the same shape;
    2. derives the tiling and grid from ``out0``;
    3. launches ``{jit_fn_name}[grid](...)``;
    4. returns the outputs.

    Argument names follow the kernel's: ``in{k}`` / ``val{k}`` for inputs and
    ``out{i}`` for outputs. The emitted code references names it does not
    define (``math``, ``triton``, ``torch``, ``torch_device_fn``,
    ``heuristics_for_tile_size``, ``heuristics_for_num_warps``,
    ``stride_order``, ``Union``, ``StridedBuffer``). These must be in scope in
    the generated module.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        jit_fn_name: str,
        ndim: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.jit_fn_name = jit_fn_name
        self.ndim = ndim
        self.name = name
        self.config = config

    def input_name(self, i):
        """Name of input ``i``: ``in{k}`` for the k-th tensor, ``val{k}`` for the k-th non-tensor."""
        prefix = "in" if self.fx.is_tensor(i) else "val"
        return f"{prefix}{self.fx.input_index(i)}"

    def output_name(self, i):
        """Name of output ``i``."""
        return f"out{i}"

    def _tensor_operands(self):
        """Names of all tensor operands: ``in0..`` then ``out0..``."""
        schema = self.fx
        return [f"in{i}" for i in range(schema.num_input_tensors())] + [
            f"out{i}" for i in range(schema.num_output_tensors())
        ]

    def gen_signature(self, code: IndentedBuffer):
        """Emit ``def <name>(in0, val0, ..., /, *, out0, ...):``."""
        # TODO: check if triton handles constexprs transitively
        schema = self.fx
        params: List[str] = []
        for i in range(schema.num_inputs()):
            name = self.input_name(i)
            if schema.is_tensor(i):
                params.append(f"{name}: Union[torch.Tensor, StridedBuffer]")
                continue
            arg_type = schema.input_type(i)
            params.append(name if arg_type is None else f"{name}: {_type_name(arg_type)}")
        # NOTE: [the wrapper's signature and rules for passing parameters]
        # Input params must be passed by position. They are renamed to
        # in0, in1, val0, val1, ..., so passing them by keyword would be weird.
        # This may be relaxed later.
        # Output params must be passed by keyword. The scalar function has no
        # output parameters (for a scalar function they would make no sense).
        # They are added to allow a destination-passing-style API, which is
        # convenient when pre-defined outputs should be used (especially views
        # of other tensors). Since they are added in addition to the scalar
        # function's parameters, we require them to be passed by keyword; the
        # names out0, out1, ... cannot clash with the scalar function's names.
        params.append("/")
        params.append("*")  # output params must be passed by keyword
        params.extend(
            f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]"
            for i in range(schema.num_output_tensors())
        )
        code.writeline(f"def {self.name}({_cs(params)}): ")

    def gen_docstring(self, code: IndentedBuffer):
        """Emit the wrapper's docstring (the schema signature with outputs as args)."""
        schema = self.fx
        doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""'
        code.writeline(doc)

    def gen_same_shape_check(self, code: IndentedBuffer):
        """Emit ``assert in0.shape == ... == out0.shape, 'operand shapes mismatch'``.

        With a single tensor operand there is nothing to compare, so no check
        is emitted. The old ``assert out0.shape`` form failed for 0-d outputs,
        because an empty shape is falsy.
        """
        shapes = [f"{operand}.shape" for operand in self._tensor_operands()]
        if len(shapes) < 2:
            return
        check: str = " == ".join(shapes)
        code.writeline(f"assert {check}, 'operand shapes mismatch'")

    def _use_fill_scalar_fast_path(self) -> bool:
        """Whether to use the fixed 64-element tile with an uncapped grid.

        This applies to ``fill_scalar`` wrappers on devices whose
        ``get_device_capability()`` major is >= 9. The device is only queried
        when the name matches.
        """
        if "fill_scalar" not in self.name:
            return False
        major, _ = get_device_capability()
        return major >= 9

    def _gen_cta_partition(self, code: IndentedBuffer, fill_fast_path: bool):
        """Emit ``num_ctas``, ``tiles_per_cta``, ``num_warps`` and
        ``one_tile_per_cta`` from the already-emitted ``num_tiles`` / ``tile_size``."""
        if fill_fast_path:
            code.writeline("num_ctas = num_tiles")
        else:
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

        code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
        code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
        code.writeline("one_tile_per_cta = tiles_per_cta==1")

    def gen_task_partition(self, code: IndentedBuffer):
        """Emit nd-tile launch configuration.

        Always defines ``num_warps``, ``num_ctas`` and ``grid``. For rank > 0
        it also defines ``shape``, ``num_tasks``, ``tile_sizes`` (one entry per
        axis), ``tile_size``, ``num_tiles``, ``tiles_per_cta`` and
        ``one_tile_per_cta``, and returns early when ``out0`` is empty.
        """
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            fill_fast_path = self._use_fill_scalar_fast_path()
            if fill_fast_path:
                # The launch reads tile_sizes[i] for every axis, so emit one entry
                # per axis. All 64 elements go on the last (fastest) axis. The old
                # 1-tuple (64,) raised IndexError for ndim >= 2.
                fixed_tile_sizes = (1,) * (ndim - 1) + (64,)
                code.writeline(f"tile_sizes = {fixed_tile_sizes}")
            else:
                max_tile_size = self.config.max_tile_size
                code.writeline(
                    f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
                )
            code.writeline("tile_size = math.prod(tile_sizes)")
            code.writeline(
                "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
            )
            self._gen_cta_partition(code, fill_fast_path)
        code.writeline("grid = (num_ctas, 1, 1)")

    def gen_task_partition_1d(self, code: IndentedBuffer):
        """Emit 1d-tile launch configuration (same variables as
        ``gen_task_partition``, with a single tile over ``num_tasks``)."""
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            fill_fast_path = self._use_fill_scalar_fast_path()
            if fill_fast_path:
                code.writeline("tile_sizes = tuple([64])")
            else:
                max_tile_size = self.config.max_tile_size
                code.writeline(
                    f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
                )
            code.writeline("tile_size = tile_sizes[0]")
            code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")
            self._gen_cta_partition(code, fill_fast_path)
        code.writeline("grid = (num_ctas, 1, 1)")

    def _gen_stride_locals(self, code: IndentedBuffer, with_block_pointer: bool):
        """Emit ``{operand}_strides`` (and ``{operand}_stride_order``) locals."""
        for operand in self._tensor_operands():
            code.writeline(f"{operand}_strides = {operand}.stride()")
            if not with_block_pointer:
                continue
            if self.ndim >= 2:
                code.writeline(f"{operand}_stride_order = stride_order({operand}_strides)")
            else:
                # with a single axis there is no order to compute
                code.writeline(f"{operand}_stride_order = (0,)")

    def _gen_launch(self, code: IndentedBuffer, with_block_pointer: bool, tile_size_kwargs):
        """Emit the kernel launch.

        The launch passes operands, then (for rank > 0) strides, optional
        stride orders, the task space, ``num_tasks``, the tiling keywords
        ``tile_size_kwargs`` and ``one_tile_per_cta``, and finally
        ``num_warps``.
        """
        schema = self.fx
        ndim = self.ndim

        code.writeline("# kernel launch")
        self._gen_stride_locals(code, with_block_pointer)

        # Select the device of in0, or of out0 (which always exists) when the
        # schema has no tensor inputs and in0 would be undefined.
        device_operand = "in0" if schema.num_input_tensors() > 0 else "out0"
        code.writeline(f"with torch_device_fn.device({device_operand}.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                params = [self.input_name(i) for i in range(schema.num_inputs())]
                params += [self.output_name(i) for i in range(schema.num_output_tensors())]
                code.writeline(f"{_cs(params)},")

                if ndim > 0:
                    for operand in self._tensor_operands():
                        strides = ", ".join(f"{operand}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{strides}, # stride for {operand}")
                        if with_block_pointer:
                            order = ", ".join(
                                f"{operand}_stride_order[{j}]" for j in range(ndim)
                            )
                            code.writeline(f"{order}, # stride order for {operand}")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                    code.writeline("num_tasks, # num tasks")
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                    for kwarg in tile_size_kwargs:
                        code.writeline(kwarg)
                    code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_kernel_launch(
        self,
        code: IndentedBuffer,
    ):
        """Emit the nd-tile kernel launch (one ``tile_size{i}`` keyword per axis)."""
        tile_kwargs = [f"tile_size{i}=tile_sizes[{i}]," for i in range(self.ndim)]
        self._gen_launch(code, self.config.prefer_block_pointer, tile_kwargs)

    def gen_kernel_launch_1d(
        self,
        code: IndentedBuffer,
    ):
        """Emit the 1d-tile kernel launch (a single ``tile_size`` keyword)."""
        self._gen_launch(code, False, ["tile_size=tile_size,"])

    def gen_return(self, code: IndentedBuffer):
        """Emit ``return out0, out1, ...``."""
        return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors()))
        code.writeline(f"return {return_exprs}")

    def codegen_nd_tile(self, code):
        """Emit the complete wrapper for the nd-tile kernel."""
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition(code)
            self.gen_kernel_launch(code)
            self.gen_return(code)
        code.newline()
        return code

    def codegen_1d_tile(self, code):
        """Emit the complete wrapper for the 1d-tile kernel."""
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition_1d(code)
            self.gen_kernel_launch_1d(code)
            self.gen_return(code)
        code.newline()
        return code
class ModuleGenerator:
    """Assemble the source of one generated pointwise module.

    The module consists of fixed imports, extra imports and local
    ``@triton.jit`` helpers copied from the module that defines the scalar
    function, followed by the wrapper (from ``WrapperGenerator``) and the
    kernel (from ``KernelGenerator``).
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.scalar_fn = scalar_fn
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def _collect_jit_deps(scalar_fn):
        """Collect extra imports and local @triton.jit helper sources.

        Parses the source module where ``scalar_fn`` is defined using AST.

        Returns:
            A tuple ``(extra_imports, local_sources)``:
              - extra_imports: dict of absolute module path -> set of import
                specs (``"name"`` or ``"name as alias"``), taken from the
                module's top-level ``from ... import ...`` statements.
                Relative imports are resolved against the module's package.
                ``__future__`` imports and the modules listed in
                ``ALREADY_IMPORTED`` are skipped.
              - local_sources: list of source strings for top-level functions
                with a ``jit`` decorator (``@jit``, ``@triton.jit``,
                ``@triton.jit(...)``) that are NOT decorated with
                ``@pointwise_dynamic``.
            ``({}, [])`` when the module or its source cannot be loaded or
            parsed.
        """
        import ast
        import inspect
        import io
        import tokenize
        from importlib.util import resolve_name

        py_fn = getattr(scalar_fn, "fn", scalar_fn)
        module_name = getattr(py_fn, "__module__", None)
        if not module_name:
            return {}, []
        try:
            mod = importlib.import_module(module_name)
            source_file = inspect.getfile(mod)
        except (ImportError, TypeError, OSError):
            return {}, []
        try:
            # tokenize.open decodes per PEP 263 (UTF-8 unless a coding cookie
            # says otherwise) instead of using the locale encoding.
            with tokenize.open(source_file) as f:
                module_source = f.read()
            tree = ast.parse(module_source)
        except (OSError, SyntaxError, ValueError):
            # ValueError covers UnicodeDecodeError and null bytes in source.
            return {}, []
        # Split on "\n" only, so indices agree with ast line numbers
        # (str.splitlines would also split on \f, \v, \u2028, ...).
        source_lines = io.StringIO(module_source).readlines()

        # Modules whose `from ... import` lines are not copied; most of them
        # are already imported by generate_imports.
        ALREADY_IMPORTED = {
            "math",
            "typing",
            "torch",
            "triton",
            "triton.language",
            "flag_gems.utils.shape_utils",
            "flag_gems.utils.tensor_wrapper",
            "flag_gems.utils.libentry",
            "flag_gems.utils",
            "flag_gems.runtime",
            "flag_gems.utils.pointwise_dynamic",
        }
        # Package used to resolve relative imports of the defining module.
        package = getattr(mod, "__package__", None) or module_name.rpartition(".")[0]

        extra_imports = {}
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.ImportFrom):
                continue
            if node.level:
                # The generated module is loaded from the code cache, outside
                # any package, so relative imports must be made absolute.
                try:
                    module_path = resolve_name(
                        "." * node.level + (node.module or ""), package
                    )
                except ImportError:
                    continue
            else:
                module_path = node.module
            # `from __future__ import ...` is only legal at the top of a file.
            if not module_path or module_path == "__future__":
                continue
            if module_path in ALREADY_IMPORTED:
                continue
            # Keep aliases: `from m import a as b` must still bind `b`.
            names = {
                alias.name
                if alias.asname is None
                else f"{alias.name} as {alias.asname}"
                for alias in node.names
            }
            extra_imports.setdefault(module_path, set()).update(names)

        def _decorator_name(dec):
            """Dotted name of a decorator: ``@a.b(...)`` -> ``"a.b"``, else ``""``."""
            if isinstance(dec, ast.Call):
                dec = dec.func
            parts = []
            while isinstance(dec, ast.Attribute):
                parts.append(dec.attr)
                dec = dec.value
            if not isinstance(dec, ast.Name):
                return ""
            parts.append(dec.id)
            return ".".join(reversed(parts))

        def _is_jit(func_node):
            # Match on the last dotted component so that "@jit", "@triton.jit"
            # and "@triton.jit(...)" match but e.g. "@torch.jit.script" or
            # "@numba.njit" do not.
            return any(
                _decorator_name(dec).rpartition(".")[2] == "jit"
                for dec in func_node.decorator_list
            )

        def _mentions_decorator(func_node, name):
            # Lenient text match: any decorator whose source text contains
            # `name` counts.
            for dec in func_node.decorator_list:
                src = "".join(source_lines[dec.lineno - 1 : dec.end_lineno])
                if name in src:
                    return True
            return False

        def _extract_source(func_node):
            # Include the decorator lines, which precede the `def` line.
            start = func_node.lineno - 1
            if func_node.decorator_list:
                start = func_node.decorator_list[0].lineno - 1
            end = func_node.end_lineno
            return "".join(source_lines[start:end])

        # Collect local @triton.jit functions (without @pointwise_dynamic)
        local_sources = []
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.FunctionDef):
                continue
            if not _is_jit(node):
                continue
            if _mentions_decorator(node, "pointwise_dynamic"):
                continue
            local_sources.append(_extract_source(node))

        return extra_imports, local_sources

    def generate_imports(self, code: IndentedBuffer) -> IndentedBuffer:
        """Emit the module header into ``code`` and return it.

        The header contains the fixed imports, the extra imports and the
        local jit helpers collected by ``_collect_jit_deps``.
        """
        code.writeline("import math")
        code.writeline("from typing import Union")
        code.writeline("import torch")
        code.writeline("import triton")
        code.writeline("from triton import language as tl")
        code.newline()
        code.writeline("from flag_gems.utils.shape_utils import (")
        code.writeline("    heuristics_for_tile_size,")
        code.writeline("    heuristics_for_num_warps,")
        code.writeline("    stride_order,")
        code.writeline(")")
        code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer")
        code.writeline("from flag_gems.utils.libentry import libentry")
        code.writeline("from flag_gems.utils import triton_lang_extension as tle")
        code.writeline("from flag_gems.runtime import torch_device_fn")

        # Generate extra imports and local JIT deps of the scalar function
        jit_dep_imports, local_jit_sources = self._collect_jit_deps(self.scalar_fn)
        for module_path, names in sorted(jit_dep_imports.items()):
            sorted_names = ", ".join(sorted(names))
            code.writeline(f"from {module_path} import {sorted_names}")

        code.newline()
        code.newline()

        # Emit local @triton.jit helper functions
        for source in local_jit_sources:
            for line in source.splitlines():
                code.writeline(line)
            code.newline()

        return code

    def codegen(self, code: IndentedBuffer):
        """Emit imports, then the wrapper and kernel (1d- or n-d-tiled per config)."""
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            code = self.wrapper_gen.codegen_1d_tile(code)
            code = self.kernel_gen.codegen_1d_tile(code)
        else:
            code = self.wrapper_gen.codegen_nd_tile(code)
            code = self.kernel_gen.codegen_nd_tile(code)
        return code
@dataclass()
class KernelInfo:
    """Location and names of one generated pointwise kernel module.

    Built by ``PointwiseDynamicFunction.instantiate`` and returned by
    ``get_kernel_info`` for C++ integration.
    """

    # path of the generated Python module in the code cache
    file_path: str
    # name of the generated JIT kernel inside that module
    kernel_name: str
    # name of the generated Python wrapper inside that module
    wrapper_name: str
    # rank of the task space the kernel is specialized for
    ndim: int
class PointwiseDynamicFunction:
    """Utility to generate functions for general pointwise operations.

    It generates a wrapper and a JITFunction which are specialized according
    to the rank of the task space (the broadcasted shape of all input
    tensors). The generated code is written out to the cache directory
    (defaults to ~/.flaggems) and imported from there.
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        # This object may be shared with other functions (e.g. when it comes
        # from get_codegen_config()), so it is never mutated in place here;
        # see _disable_block_pointer.
        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads, keyed by _overload_key(ndim)
        self.overloads: Mapping[str, Callable] = {}
        # cached kernel info for C++ integration, same keys as `overloads`
        self._kernel_info_cache: Mapping[str, KernelInfo] = {}

    def __call__(self, *args, **kwargs):
        """Run the op: inputs are passed by position, outputs (``outN=``) by keyword."""
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        # NOTE: overload keeps the type of outputs:
        # if a pre-defined output is a Tensor or StridedBuffer, the corresponding
        # output is also a Tensor or StridedBuffer, respectively.
        # Since prepare_args wraps all the arguments, the outputs are all
        # StridedBuffers; if a manually instantiated overload is called
        # directly, take care of that manually.
        return self._unwrap(out)

    @staticmethod
    def use_fast_path(tensors):
        """Whether all tensors can be processed as one flat 1-D task.

        True when all tensors have the same shape and are either all
        C-contiguous, or all share the same strides with the first tensor
        being non-overlapping and dense.
        """
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    def prepare_args(self, *args, **kwargs):
        """Allocate missing outputs and wrap all tensors as StridedBuffers.

        Handles output allocation (when needed), task-rank inference and the
        reinterpretation of inputs/outputs over the task space. Inputs must
        be passed by position; outputs ``out0``, ``out1``, ... by keyword.
        Missing outputs are allocated with the dtype given by the schema's
        promotion methods.

        Returns:
            ``(ndim, args, kwargs)``: the rank of the task space, and the
            arguments with every tensor wrapped in a StridedBuffer over it.

        Raises:
            ValueError: if ``check_tensor_attributes`` rejects the positional
                arguments, or no tensor argument is given at all.
            RuntimeError: if a given output has the wrong shape or has
                internal overlapping.
        """
        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        if not tensors:
            raise ValueError(
                "pointwise_dynamic requires at least one tensor argument (input or output)."
            )

        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # a simple strategy: all the undefined tensors will follow the first
            # tensor that is not broadcasted, no attempts to simplify task, no
            # reordering, no dimension collapsing
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operation
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise output arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )

        # Block pointers are not used once the task space exceeds INT32_MAX
        # elements (presumably an int32 offset limit -- confirm). Measure the
        # broadcast task space itself: tensors[0] may be a small broadcast
        # input while the task is much larger.
        INT32_MAX = torch.iinfo(torch.int32).max
        task_numel = 1
        for size in task_shape:
            task_numel *= size
        if task_numel > INT32_MAX:
            self._disable_block_pointer()

        return (ndim, args, kwargs)

    def _disable_block_pointer(self):
        """Stop using block pointers for this function only.

        ``self.config`` may be shared with other pointwise functions, so it is
        replaced by a shallow copy with ``prefer_block_pointer=False`` instead
        of being modified in place.
        """
        if not self.config.prefer_block_pointer:
            return
        import copy

        config = copy.copy(self.config)
        config.prefer_block_pointer = False
        self.config = config

    def _unwrap(self, tensors):
        """Unwrap StridedBuffer output(s) to Tensor(s); a tuple for multi-output ops."""
        if self.fx.num_output_tensors() == 1:
            item = tensors
            return item.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
        """Compute kernel name, wrapper name, and file path for a given ndim.

        This is the single source of truth for naming, used by both instantiate()
        and get_kernel_info() to ensure consistency.

        Returns:
            Tuple of (kernel_name, wrapper_name, file_path)
        """
        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"

        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )
        file_path = str(code_cache_dir() / file_name)

        return kernel_name, wrapper_name, file_path

    def _overload_key(self, ndim: int) -> str:
        """Cache key shared by `overloads` and `_kernel_info_cache`."""
        return f"{ndim}_{self.config.prefer_block_pointer}"

    def instantiate(self, ndim):
        """Generate (or fetch from cache) the wrapper specialized for rank ``ndim``."""
        # NOTE: manually instantiated overload does not have `prepare_args` as
        # preprocessing, so you have to manually allocate output and make sure that
        # the inputs & outputs actually fit the manually instantiated overload
        key = self._overload_key(ndim)
        if key in self.overloads:
            return self.overloads[key]

        code = IndentedBuffer()

        # Use helper to compute names (single source of truth)
        kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which requires
        # that the source code can be found by inspect.
        # We write it into a file, since inspect cannot find the source of functions
        # dynamically created via exec string. We could help inspect find the source by
        # hacking the linecache library, but generating a module is simpler, since we
        # generate 2 functions, the kernel and the wrapper, and the wrapper calls the kernel.
        write_atomic(file_path, code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # do not expose it to sys.modules
        # sys.modules["_add_module"] = m

        # NOTE: [why not import the scalar function]
        # we do not re-import the scalar function, although the generated kernel **calls** it.
        # Since a function's __name__ may be changed, importing its __name__ from the
        # module where it is defined may not give the same object; also the name may be
        # rebound to something else, so importing via name cannot guarantee that the
        # scalar function is imported.
        # So we copy the scalar function and its __globals__ to the generated module.
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload

        # Cache kernel info for C++ integration
        self._kernel_info_cache[key] = KernelInfo(
            file_path=file_path,
            kernel_name=kernel_name,
            wrapper_name=wrapper_name,
            ndim=ndim,
        )

        return overload

    def get_kernel_info(self, ndim: int) -> KernelInfo:
        """Get kernel information for a given ndim.

        This method is useful for C++ integration to get the file path and
        kernel name without duplicating the naming logic.

        If the kernel hasn't been instantiated yet, this will instantiate it first.

        Args:
            ndim: The rank of the task space

        Returns:
            KernelInfo with file_path, kernel_name, wrapper_name, and ndim
        """
        key = self._overload_key(ndim)

        # Ensure the kernel is instantiated
        if key not in self._kernel_info_cache:
            self.instantiate(ndim)

        return self._kernel_info_cache[key]
def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Turn a scalar JIT function into a PointwiseDynamicFunction.

    Usable bare (``@pointwise_dynamic``) or with keyword arguments
    (``@pointwise_dynamic(num_outputs=..., promotion_methods=...)``). The
    keyword arguments are forwarded to ``FunctionSchema``. If none of
    ``num_inputs``, ``is_tensor`` and ``dtypes`` is given, the number of
    inputs is taken from the decorated function's ``arg_names``.

    Returns:
        A PointwiseDynamicFunction when ``f`` is given, otherwise a decorator
        that builds one.
    """

    def decorator(fn):
        # Resolve into a local instead of rebinding the enclosing `num_inputs`,
        # so one decorator object can be applied to several functions without
        # the first function's inferred input count leaking into the next.
        n_inputs = num_inputs
        if (n_inputs is None) and (is_tensor is None) and (dtypes is None):
            n_inputs = len(fn.arg_names)
        op_desc = FunctionSchema(
            num_inputs=n_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(op_desc, fn, config)

    if f is not None:
        return decorator(f)
    return decorator