Coverage for src/flag_gems/utils/pointwise_dynamic_backup.py: 0%
805 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import importlib
2import os
3from dataclasses import dataclass
4from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple
6import torch
7import triton
8from triton.runtime.jit import JITFunction
10from flag_gems.utils.code_cache import code_cache_dir
11from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
12from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config
13from flag_gems.utils.device_info import get_device_capability
14from flag_gems.utils.shape_utils import (
15 MemOverlap,
16 all_c_contiguous,
17 all_the_same_shape,
18 all_the_same_stride,
19 broadcast_shapes,
20 broadcasted_stride,
21 check_tensor_attributes,
22 has_internal_overlapping,
23)
24from flag_gems.utils.tensor_wrapper import StridedBuffer
25from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion
28# ------------------ Operation Description ---------------------------
29def _type_name(type) -> str:
30 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object"
31 if type in (bool, int, float, str):
32 return type.__name__
33 if isinstance(type, torch.dtype):
34 return str(type)
35 return str(type)
38def _check_typed_list(container, type):
39 for item in container:
40 assert isinstance(item, type)
43def _check_sized_list(container, size):
44 assert len(container) == size
47def _tuple_content(strings: Sequence[str]) -> str:
48 # comma separated list
49 if len(strings) == 0:
50 return ""
51 if len(strings) == 1:
52 return f"{strings[0]},"
53 else:
54 return ", ".join(strings)
57def _cs(strings: Iterable[str]) -> str:
58 return ", ".join(strings)
def _broadcast_vec(i, ndim):
    """Indexing suffix that keeps axis *i* (":") and inserts None on every
    other axis, e.g. ``_broadcast_vec(1, 3) -> "[None, :, None]"``."""
    axes = (":" if axis == i else "None" for axis in range(ndim))
    return f"[{', '.join(axes)}]"
class FunctionSchema:
    """Description of a pointwise operation.

    Records the arity of the scalar function, which inputs are tensors,
    optional dtypes for non-tensor inputs, the number of outputs, and the
    type-promotion rule used to derive each output's dtype. At least one of
    (num_inputs, is_tensor, dtypes) must be provided; the others are derived.
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        else:
            self._promotion_methods = self.canonicalize_promotion_methods(
                promotion_methods
            )

        # Reconcile num_inputs / is_tensor / dtypes: whichever is given first
        # fixes the arity; the missing descriptions get sensible defaults.
        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            if is_tensor is not None:
                _check_sized_list(is_tensor, self._num_inputs)
                self._is_tensor = is_tensor
            else:
                # a None dtype means "tensor input"; concrete dtypes are scalars
                self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace the trailing method name of each rule with its
        ELEMENTWISE_TYPE_PROMOTION_KIND enum member."""
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        # num of arguments, outputs not included
        return self._num_inputs

    def num_outputs(self):
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        return self._dtypes[arg_id]

    def output_type(self, i):
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a human-readable signature string for this schema.

        With outputs_in_arg=True, outputs are also shown as aliased input
        arguments (destination-passing style), e.g. ``StridedBuffer(a0!)``.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            else:
                if dtype is None:
                    input_types.append("scalar")
                else:
                    input_types.append(_type_name(dtype))

        output_types = []
        if outputs_in_arg:
            for i in range(self.num_outputs()):
                # BUGFIX: was f"StridedBuffer(a{1}!)" — the literal 1 rendered
                # every output as a1; use the loop index instead.
                output_types.append(f"StridedBuffer(a{i}!)")
            input_types.extend(output_types)
        else:
            for _ in range(self.num_outputs()):
                output_types.append("StridedBuffer")
        sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'
        return sig

    def _compute_input_id(self):
        # Map overall argument position -> index within its own group
        # (tensor inputs and non-tensor inputs are numbered separately).
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)
class KernelGenerator:
    """Generates triton JIT kernel source for a pointwise operation.

    The scalar function is inlined into a rank-`ndim` kernel. Every generated
    kernel supports two launch styles, selected at runtime via the
    `one_tile_per_cta` constexpr: a monolithic launch (one tile per CTA) and a
    grid-stride loop (GSL). Two tiling schemes are available: nd-tile
    (optionally with block pointers) and 1d-tile.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    def gen_import_function(self, code: IndentedBuffer):
        # Quote the scalar function's source into the generated file for
        # readability/debugging; the quoted copy is not executed.
        code.writeline(f'"""Quoted source of {self.fn_name}:')
        code.writemultiline(self.fn.src)
        code.writeline('"""')
        code.newline()

    def gen_decorators(self, code):
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # we do not specialize non tensor args since they are passed into
            # the inlined function, which means that their values may not
            # deserve specialization
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def input_name(self, i):
        # Rendered name of the i-th argument inside the generated kernel:
        # tensor inputs are in<k>, non-tensor inputs are val<k>, where k
        # counts each group separately.
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        return f"out{i}"

    def gen_signature(self, code, with_block_pointer=False):
        """Emit the kernel def-line and parameter list for nd-tile kernels."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: inputs ptrs & non tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output ptrs
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor argument (rank-0 kernels
            # take neither strides nor task-space parameters)
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")
                    if with_block_pointer:
                        stride_order_args = _cs(
                            f"in{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(f"{stride_order_args}, # stride order for in{i}")

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")
                    if with_block_pointer:
                        stride_order_args = _cs(
                            f"out{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{stride_order_args}, # stride order for out{i}"
                        )

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

                # number of tasks, used to compute mask
                code.writeline("num_tasks,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                code.writeline("tiles_per_cta: int,")
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                code.writeline(f"{tile_sizes},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature_1d_tile(self, code):
        """Emit the kernel def-line and parameter list for 1d-tile kernels."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: inputs ptrs & non tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output ptrs
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor argument
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

                # number of tasks, used to compute mask
                code.writeline("num_tasks,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_num_tiles(self, code):
        # tile-grid size along each task-space axis
        # (dropped an always-true `if i < ndim` guard the original carried)
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def gen_body_for_0d(self, code):
        """Body for a rank-0 task: single load / compute / store, no indexing."""
        schema = self.fx
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # nd tile 1d grid kernel with block pointer
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Body that processes one nd tile, addressed via tl.make_block_ptr."""
        ndim = self.ndim
        schema = self.fx

        # block pointer for each operand
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        # reconstruct pid multi index
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

        # cta_offsets
        code.writeline("# tile offsets")
        for i in range(ndim):
            # Or else: AssertionError: Block pointers only support 32 bit
            # `offsets/block_shape`, add a `.to(tl.int32)` or use regular indexing
            # for 64 bit support
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(tuple(f"in{i}_stride_order{j}" for j in range(ndim)))
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        # compute
        # TODO: separate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        # grid-stride loop: each CTA walks tiles_per_cta tiles
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_with_bptr(code)

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Body that processes one nd tile using manual offset arithmetic."""
        ndim = self.ndim
        schema = self.fx

        # reconstruct pid multi index
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
        mask_combine = " & ".join(masks)
        code.writeline(f"mask = {mask_combine}")

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        # compute
        # TODO: separate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            # BUGFIX: the original emitted `in{i} = tl.store(...)`, rebinding
            # the loaded input to the (valueless) result of tl.store; a store
            # is a statement, not something to assign.
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_without_bptr(self, code):
        # grid-stride loop: each CTA walks tiles_per_cta tiles
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_without_bptr(code)

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=True)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolithic kernel: one_tile_per_cta, it may require a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_with_bptr(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_with_bptr(code)
        code.newline()
        return code

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support, no block pointer."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=False)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolithic kernel: one_tile_per_cta, it may require a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_without_bptr(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_without_bptr(code)
        code.newline()
        return code

    def codegen_nd_tile(self, code):
        # dispatch on config: block-pointer addressing vs manual offsets
        use_block_pointer = self.config.prefer_block_pointer
        if use_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Body that processes one flat (1d) tile of the linearized task space."""
        ndim = self.ndim
        schema = self.fx

        # tile id
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        # compute
        # TODO: separate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            offset_combine = " + ".join(offsets)
            # BUGFIX: the original emitted `in{i} = tl.store(...)`; see
            # gen_body_one_tile_per_cta_without_bptr — emit a bare store.
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_1d_tile(self, code):
        # grid-stride loop: each CTA walks tiles_per_cta tiles
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_1d_tile(code)

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            # monolithic kernel: one_tile_per_cta, it may require a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_1d_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_1d_tile(code)
        code.newline()
        return code
class WrapperGenerator:
    """Generates the python wrapper that partitions the task space and
    launches the generated triton kernel (`jit_fn_name`)."""

    def __init__(
        self,
        function_schema: FunctionSchema,
        jit_fn_name: str,
        ndim: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.jit_fn_name = jit_fn_name
        self.ndim = ndim
        self.name = name
        self.config = config

    def input_name(self, i):
        # Same renaming convention as KernelGenerator.input_name: tensor
        # inputs are in<k>, non-tensor inputs are val<k>.
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        return f"out{i}"

    def gen_signature(self, code: IndentedBuffer):
        # TODO: check if triton handles constexprs transitively
        schema = self.fx
        params: List[str] = []
        for i in range(schema.num_inputs()):
            if schema.is_tensor(i):
                params.append(
                    f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]"
                )
            else:
                arg_type = schema.input_type(i)
                if arg_type is not None:
                    params.append(f"{self.input_name(i)}: {_type_name(arg_type)}")
                else:
                    params.append(f"{self.input_name(i)}")
        # NOTE: [the wrapper's signature and rules for passing parameters]
        # input params: must be passed by position, since the names are renamed to
        # in0, in1, val0, val1, ..., so passing these parameters by keyword is weird.
        # So we enforce that these parameters must be passed by position.
        # maybe we can fix it later
        # output parameters: must be passed by keyword, since the scalar function
        # does not have output parameters (think of it as some scalar function; output
        # parameters do not make sense in this case). They are added to allow destination
        # passing style API. Output parameters are convenient in cases where we want
        # to use some pre-defined outputs (especially when they are some views of other
        # tensors). We emphasize that these parameters are added in-addition; we enforce
        # that they be passed by keyword. After all, out0, out1, ... do not mismatch
        # names from the scalar function, since it does not have output parameters.
        params.append("/")
        params.append("*")  # output params must be passed by keyword

        for i in range(schema.num_output_tensors()):
            params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]")
        code.writeline(f"def {self.name}({_cs(params)}): ")

    def gen_docstring(self, code: IndentedBuffer):
        schema = self.fx
        doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""'
        code.writeline(doc)

    def gen_same_shape_check(self, code: IndentedBuffer):
        # all inputs and outputs must already share one broadcasted shape here
        schema: FunctionSchema = self.fx
        params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [
            f"out{i}.shape" for i in range(schema.num_output_tensors())
        ]
        check: str = " == ".join(params)
        code.writeline(f"assert {check}, 'operand shapes mismatch'")

    def gen_task_partition(self, code: IndentedBuffer):
        """Emit nd-tile task partitioning: tile sizes, tile count, grid."""
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            # early-return on empty outputs: nothing to launch
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size
            major, _ = get_device_capability()
            # special-cased fixed tile size for fill_scalar on capability >= 9
            if self.name.find("fill_scalar") != -1 and major >= 9:
                code.writeline("tile_sizes = tuple([64])")
            else:
                code.writeline(
                    f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
                )
            code.writeline("tile_size = math.prod(tile_sizes)")
            code.writeline(
                "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
            )

            if self.name.find("fill_scalar") != -1 and major >= 9:
                code.writeline("num_ctas = num_tiles")
            else:
                max_grid_size0 = self.config.max_grid_size[0]
                code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

            code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            code.writeline("one_tile_per_cta = tiles_per_cta==1")
        code.writeline("grid = (num_ctas, 1, 1)")

    def gen_task_partition_1d(self, code: IndentedBuffer):
        """Emit 1d-tile task partitioning over the linearized task space."""
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            # early-return on empty outputs: nothing to launch
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size

            major, _ = get_device_capability()
            # special-cased fixed tile size for fill_scalar on capability >= 9
            if self.name.find("fill_scalar") != -1 and major >= 9:
                code.writeline("tile_sizes = tuple([64])")
            else:
                code.writeline(
                    f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
                )

            code.writeline("tile_size = tile_sizes[0]")
            code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")

            if self.name.find("fill_scalar") != -1 and major >= 9:
                code.writeline("num_ctas = num_tiles")
            else:
                max_grid_size0 = self.config.max_grid_size[0]
                code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

            code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            code.writeline("one_tile_per_cta = tiles_per_cta==1")
        code.writeline("grid = (num_ctas, 1, 1)")

    def gen_kernel_launch(
        self,
        code: IndentedBuffer,
    ):
        """Emit the nd-tile kernel launch (strides, stride orders, meta-args)."""
        schema = self.fx
        ndim = self.ndim

        with_block_pointer = self.config.prefer_block_pointer

        code.writeline("# kernel launch")
        for i in range(schema.num_input_tensors()):
            code.writeline(f"in{i}_strides = in{i}.stride()")
            if not with_block_pointer:
                continue
            if ndim >= 2:  # where ndim is 1, we don't need to compute stride order
                code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)")
            else:
                code.writeline(f"in{i}_stride_order = (0,)")
        for i in range(schema.num_output_tensors()):
            code.writeline(f"out{i}_strides = out{i}.stride()")
            if not with_block_pointer:
                continue
            if ndim >= 2:
                code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)")
            else:
                code.writeline(f"out{i}_stride_order = (0,)")

        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                params = []
                # NOTE: WRAP
                # the original branched on schema.is_tensor(i) here only to
                # append the identical value in both branches; simplified
                for i in range(schema.num_inputs()):
                    params.append(self.input_name(i))
                for i in range(schema.num_output_tensors()):
                    params.append(self.output_name(i))

                code.writeline(f"{_cs(params)},")

                if ndim > 0:
                    for i in range(schema.num_input_tensors()):
                        s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for in{i}")
                        if not with_block_pointer:
                            continue
                        order = ", ".join(
                            f"in{i}_stride_order[{j}]" for j in range(ndim)
                        )
                        code.writeline(f"{order}, # stride order for in{i}")

                    for i in range(schema.num_output_tensors()):
                        s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for out{i}")
                        if not with_block_pointer:
                            continue
                        order = ", ".join(
                            f"out{i}_stride_order[{j}]" for j in range(ndim)
                        )
                        code.writeline(f"{order}, # stride orderfor out{i}")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                    code.writeline("num_tasks, # num tasks")
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                    for i in range(ndim):
                        code.writeline(f"tile_size{i}=tile_sizes[{i}],")
                    code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_kernel_launch_1d(
        self,
        code: IndentedBuffer,
    ):
        """Emit the 1d-tile kernel launch (strides and meta-args, no orders)."""
        schema = self.fx
        ndim = self.ndim

        code.writeline("# kernel launch")
        for i in range(schema.num_input_tensors()):
            code.writeline(f"in{i}_strides = in{i}.stride()")
        for i in range(schema.num_output_tensors()):
            code.writeline(f"out{i}_strides = out{i}.stride()")

        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                params = []
                # NOTE: WRAP
                # the original branched on schema.is_tensor(i) here only to
                # append the identical value in both branches; simplified
                for i in range(schema.num_inputs()):
                    params.append(self.input_name(i))
                for i in range(schema.num_output_tensors()):
                    params.append(self.output_name(i))

                code.writeline(f"{_cs(params)},")

                if ndim > 0:
                    for i in range(schema.num_input_tensors()):
                        s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for in{i}")
                    for i in range(schema.num_output_tensors()):
                        s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for out{i}")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                    code.writeline("num_tasks, # num tasks")
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                    code.writeline("tile_size=tile_size,")
                    code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_return(self, code: IndentedBuffer):
        return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors()))
        code.writeline(f"return {return_exprs}")

    def codegen_nd_tile(self, code):
        """Generate the complete wrapper using nd-tile partitioning."""
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition(code)
            self.gen_kernel_launch(code)
            self.gen_return(code)
        code.newline()
        return code

    def codegen_1d_tile(self, code):
        """Generate the complete wrapper using 1d-tile partitioning."""
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition_1d(code)
            self.gen_kernel_launch_1d(code)
            self.gen_return(code)
        code.newline()
        return code
class ModuleGenerator:
    """Drives code generation for one module: a wrapper plus its jit kernel."""

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
        """Write the fixed import header shared by every generated module."""
        for line in (
            "import math",
            "from typing import Union",
            "import torch",
            "import triton",
            "from triton import language as tl",
        ):
            code.writeline(line)
        code.newline()
        for line in (
            "from flag_gems.utils.shape_utils import (",
            "    heuristics_for_tile_size,",
            "    heuristics_for_num_warps,",
            "    stride_order,",
            ")",
            "from flag_gems.utils.tensor_wrapper import StridedBuffer",
            "from flag_gems.utils.libentry import libentry",
            "from flag_gems.utils import triton_lang_extension as tle",
            "from flag_gems.runtime import torch_device_fn",
        ):
            code.writeline(line)
        code.newline()
        code.newline()
        return code

    def codegen(self, code: IndentedBuffer):
        """Emit imports, then the wrapper and kernel for the configured tiling."""
        # the only runtime determined factor is the rank of the task space
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            generators = (
                self.wrapper_gen.codegen_1d_tile,
                self.kernel_gen.codegen_1d_tile,
            )
        else:
            generators = (
                self.wrapper_gen.codegen_nd_tile,
                self.kernel_gen.codegen_nd_tile,
            )
        for gen in generators:
            code = gen(code)
        return code
@dataclass
class KernelInfo:
    """Information about a generated kernel for C++ integration."""

    # path of the generated Python module file (inside the code cache dir)
    file_path: str
    # name of the jit kernel function generated into that module
    kernel_name: str
    # name of the wrapper function generated into that module
    wrapper_name: str
    # rank of the task space the kernel was specialized for
    ndim: int
class PointwiseDynamicFunction:
    """Utility to generate functions for general pointwise operations.

    It generates a wrapper & a JITFunction which are specialized according to the
    rank of the task space (the broadcasted shape of all input tensors).
    The generated code is written out to the cache directory (defaults to ~/.flaggems).
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads, keyed by f"{ndim}_{prefer_block_pointer}"
        self.overloads: Mapping[str, Callable] = {}
        # cached kernel info for C++ integration (same keys as `overloads`)
        self._kernel_info_cache: Mapping[str, KernelInfo] = {}

    def __call__(self, *args, **kwargs):
        # inputs must be passed by position, outputs must be passed by keyword
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        # NOTE: the overload keeps the type of outputs:
        # if a pre-defined output is a Tensor or StridedBuffer, the corresponding
        # output is also a Tensor or StridedBuffer, respectively.
        # Since prepare_args wraps all the arguments, the outputs here are all
        # StridedBuffers; but if a manually instantiated overload is called
        # directly, take care of that manually.
        return self._unwrap(out)

    @staticmethod
    def use_fast_path(tensors):
        """Whether the task can be collapsed into a flat 1D iteration.

        True when all tensors share one shape and are either all C-contiguous,
        or share identical strides over a dense, non-overlapping layout.
        """
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    def prepare_args(self, *args, **kwargs):
        """Prepare arguments for a specialized overload.

        Performs output allocation (when needed), task simplification,
        task-rank inference and input-output reinterpretation (wrapping
        tensors as StridedBuffers over the broadcasted task space).

        Returns:
            (ndim, args, kwargs): rank of the task space, positional inputs
            and keyword outputs, with all tensors wrapped as StridedBuffers.

        Raises:
            ValueError: if tensor inputs are not passed by position as declared.
            RuntimeError: if a provided output has the wrong shape or has
                internal memory overlapping.
        """
        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # dtype promotion for the outputs that need allocation
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        INT32_MAX = torch.iinfo(torch.int32).max
        # NOTE: mutates the shared config; block pointers are not used when the
        # element count exceeds INT32_MAX
        if tensors[0].numel() > INT32_MAX:
            self.config.prefer_block_pointer = False
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            # view every tensor as a flat, unit-stride 1D task
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # a simple strategy: all the undefined tensors will follow the first
            # tensor that is not broadcast; no attempts to simplify the task,
            # no reordering, no dimension collapsing
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operation
                    # (the original message said "Input" here although the check
                    # runs over output tensors — fixed)
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise output arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            for item in tensors:
                if item.shape == task_shape:
                    # reuse the layout of the first tensor that is not broadcast
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak: every tensor is broadcast; allocate contiguously
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        # unwrap StridedBuffer(s) to get the underlying Tensor(s)
        if self.fx.num_output_tensors() == 1:
            item = tensors
            return item.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
        """Compute kernel name, wrapper name, and file path for a given ndim.

        This is the single source of truth for naming, used by both instantiate()
        and get_kernel_info() to ensure consistency.

        Returns:
            Tuple of (kernel_name, wrapper_name, file_path)
        """
        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"

        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )
        file_path = str(code_cache_dir() / file_name)

        return kernel_name, wrapper_name, file_path

    def instantiate(self, ndim):
        """Generate (or fetch from cache) the overload specialized for ``ndim``.

        NOTE: a manually instantiated overload does not have `prepare_args` as
        preprocessing, so you have to manually allocate outputs and make sure
        that the inputs & outputs actually fit the manually instantiated overload.
        """
        key = f"{ndim}_{self.config.prefer_block_pointer}"
        if key in self.overloads:
            return self.overloads[key]

        code = IndentedBuffer()

        # Use helper to compute names (single source of truth)
        kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which
        # requires that the source code can be found by inspect.
        # We write it into a file, since inspect cannot find the source of
        # functions dynamically created via exec'ing a string. We could help
        # inspect find the source by hacking the linecache library, but we find
        # generating a module simpler, since we generate 2 functions, the kernel
        # and the wrapper, and the wrapper calls the kernel.
        write_atomic(file_path, code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # do not expose it to sys.modules
        # sys.modules["_add_module"] = m

        # NOTE: [why not import the scalar function]
        # we do not re-import the scalar function, although the generated kernel
        # **calls** it. A function's __name__ may have been changed, so importing
        # it by __name__ from its defining module may not find it; the name may
        # also have been rebound to something else, so importing via name cannot
        # guarantee that the scalar function is imported.
        # So we copy the scalar function and its __globals__ into the generated
        # module instead.
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload

        # Cache kernel info for C++ integration
        self._kernel_info_cache[key] = KernelInfo(
            file_path=file_path,
            kernel_name=kernel_name,
            wrapper_name=wrapper_name,
            ndim=ndim,
        )

        return overload

    def get_kernel_info(self, ndim: int) -> KernelInfo:
        """Get kernel information for a given ndim.

        This method is useful for C++ integration to get the file path and
        kernel name without duplicating the naming logic.

        If the kernel hasn't been instantiated yet, this will instantiate it first.

        Args:
            ndim: The rank of the task space

        Returns:
            KernelInfo with file_path, kernel_name, wrapper_name, and ndim
        """
        key = f"{ndim}_{self.config.prefer_block_pointer}"

        # Ensure the kernel is instantiated
        if key not in self._kernel_info_cache:
            self.instantiate(ndim)

        return self._kernel_info_cache[key]
def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Decorator turning a triton scalar function into a PointwiseDynamicFunction.

    Usable both bare (``@pointwise_dynamic``) and parameterized
    (``@pointwise_dynamic(num_inputs=2, ...)``).
    """

    def decorator(fn):
        nonlocal num_inputs
        # infer input arity from the scalar function when nothing is declared
        if num_inputs is None and is_tensor is None and dtypes is None:
            num_inputs = len(fn.arg_names)
        schema = FunctionSchema(
            num_inputs=num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(schema, fn, config)

    # support both decoration styles
    return decorator if f is None else decorator(f)