Coverage for src/flag_gems/utils/pointwise_dynamic.py: 96%
788 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import importlib
2import os
3from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple
5import torch
6import triton
7from triton.runtime.jit import JITFunction
9from flag_gems.utils.code_cache import code_cache_dir
10from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
11from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config
12from flag_gems.utils.device_info import get_device_capability
13from flag_gems.utils.shape_utils import (
14 MemOverlap,
15 all_c_contiguous,
16 all_the_same_shape,
17 all_the_same_stride,
18 broadcast_shapes,
19 broadcasted_stride,
20 check_tensor_attributes,
21 has_internal_overlapping,
22)
23from flag_gems.utils.tensor_wrapper import StridedBuffer
24from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion
27# ------------------ Operation Description ---------------------------
def _type_name(type) -> str:
    """Render a type as source text for generated code.

    The builtin scalar types ``bool``, ``int``, ``float`` and ``str`` are
    rendered by their bare name (e.g. ``"int"``); anything else, including a
    ``torch.dtype`` such as ``torch.float32``, is rendered with ``str()``.
    """
    # NOTE: the parameter shadows the builtin ``type``; the name is kept so
    # existing keyword callers keep working.
    if type in (bool, int, float, str):
        return type.__name__
    # torch.dtype and every other object share the same str() rendering, so a
    # dedicated torch.dtype branch is unnecessary.
    return str(type)
def _check_typed_list(container, type):
    """Assert that every element of *container* is an instance of *type*.

    *type* may be a single class or a tuple of classes, exactly as accepted
    by ``isinstance``. An ``AssertionError`` is raised at the first element
    that does not match.
    """
    assert all(isinstance(element, type) for element in container)
def _check_sized_list(container, size):
    """Assert that *container* holds exactly *size* elements."""
    assert size == len(container)
def _tuple_content(strings: Sequence[str]) -> str:
    """Render *strings* as the body of a tuple literal, without parentheses.

    A single item gets a trailing comma so that wrapping the result in
    parentheses yields a 1-tuple rather than a parenthesized expression.
    An empty sequence yields an empty string.

    >>> _tuple_content(["a"])
    'a,'
    >>> _tuple_content(["a", "b"])
    'a, b'
    """
    if len(strings) != 1:
        # joining an empty sequence already produces ""
        return ", ".join(strings)
    return f"{strings[0]},"
def _cs(strings: Iterable[str]) -> str:
    """Join *strings* into one comma-separated list (``", "`` separator)."""
    separator = ", "
    return separator.join(strings)
def _broadcast_vec(i, ndim):
    """Return an indexing suffix that keeps axis *i* and adds new axes elsewhere.

    For example ``_broadcast_vec(1, 3) == "[None, :, None]"``. Generated
    kernels append it to a 1d vector (offsets or masks along axis *i*) to
    broadcast it over an *ndim*-d tile.
    """
    slots = ", ".join(":" if axis == i else "None" for axis in range(ndim))
    return "[" + slots + "]"
class FunctionSchema:
    """Signature of a pointwise scalar function: its inputs and outputs.

    Each input is either a tensor or a non-tensor argument and may carry an
    optional declared Python type. Each output has a type promotion rule given
    as ``(*arg_indices, method)``, where ``method`` is a key of
    ``ELEMENTWISE_TYPE_PROMOTION_KIND``.

    Inputs are described by ``num_inputs``, ``is_tensor`` or ``dtypes``,
    consulted in that order of precedence. The number of outputs defaults to
    the number of promotion rules.
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema.

        Args:
            num_inputs: number of inputs. When given, ``is_tensor`` and
                ``dtypes`` (if given) must have this length. Missing ones
                default to "all tensors" and "no declared types".
            is_tensor: per-input flag telling whether the input is a tensor.
            dtypes: per-input declared Python type, or ``None`` for none.
                When only ``dtypes`` is given, inputs without a declared type
                are treated as tensors.
            num_outputs: number of outputs. Must equal
                ``len(promotion_methods)`` when given.
            promotion_methods: one ``(*arg_indices, method)`` entry per
                output. Required.

        Raises:
            ValueError: if ``promotion_methods`` is missing, or none of
                ``num_inputs``, ``is_tensor``, ``dtypes`` is given.
            AssertionError: on inconsistent lengths or element types, or if
                there are no inputs or no outputs.
            KeyError: presumably, if a promotion method is not a key of
                ``ELEMENTWISE_TYPE_PROMOTION_KIND`` (depends on that mapping).
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # is_tensor is necessarily None here (the branch above handles it),
            # so tensor-ness is inferred: inputs without a declared type are tensors.
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Map each ``(*arg_indices, method)`` to ``(*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])``."""
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Number of inputs (tensor and non-tensor); outputs are not included."""
        return self._num_inputs

    def num_outputs(self):
        """Number of outputs."""
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Whether input ``arg_id`` is a tensor."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Declared type of input ``arg_id``, or ``None`` if none was declared."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Canonicalized promotion rule of output ``i``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        """Number of tensor inputs."""
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        """Number of output tensors (every output is a tensor)."""
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        """Number of non-tensor inputs."""
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a readable signature such as ``Pointwise: StridedBuffer, scalar -> StridedBuffer``.

        With ``outputs_in_arg=True`` the outputs are also appended to the
        argument list as ``StridedBuffer(a0!)``, ``StridedBuffer(a1!)``, ...,
        with a distinct index per output, and the same labels are used on the
        result side.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        if outputs_in_arg:
            # each output gets its own label (a0, a1, ...), not a fixed one
            output_types = [
                f"StridedBuffer(a{i}!)" for i in range(self.num_outputs())
            ]
            input_types.extend(output_types)
        else:
            output_types = ["StridedBuffer"] * self.num_outputs()
        sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'
        return sig

    def _compute_input_id(self):
        """Index of each input among inputs of its own kind (tensor / non-tensor)."""
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Index of input ``idx`` among the inputs of the same kind."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)
class KernelGenerator:
    """Generate the source code of a Triton kernel for a pointwise function.

    The generated kernel:

    * loads every tensor input;
    * calls the scalar function ``scalar_fn`` on the loaded values and the
      non-tensor inputs;
    * stores the results to the outputs.

    Operands are addressed through explicit per-axis strides over a task
    space of rank ``rank``. Three layouts are available:

    * ``codegen_nd_tile`` with ``config.prefer_block_pointer``: nd tiles
      addressed with ``tl.make_block_ptr``;
    * ``codegen_nd_tile`` otherwise: nd tiles with explicit offset arithmetic
      and masks;
    * ``codegen_1d_tile``: 1d tiles over the flattened task space, with each
      flat index decomposed back into a multi-index.

    For rank > 0 every layout has two modes, selected by the
    ``one_tile_per_cta`` kernel argument:

    * one tile per program;
    * a loop in which each program handles ``tiles_per_cta`` tiles, strided
      by the number of programs.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    def gen_import_function(self, code: IndentedBuffer):
        """Write the scalar function's source as a quoted (non-executed) string literal."""
        code.writeline(f'"""Quoted source of {self.fn_name}:')
        code.writemultiline(self.fn.src)
        code.writeline('"""')
        code.newline()

    def gen_decorators(self, code):
        """Write ``@libentry()`` and the ``@triton.jit`` decorator of the kernel."""
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # we do not specialize non tensor args since they are passed into the inlined function
            # which means that their values may not deserve specialization
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def input_name(self, i):
        """Kernel-side name of input ``i``: ``in{k}`` for tensors, ``val{k}`` otherwise."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        """Kernel-side name of output ``i``."""
        return f"out{i}"

    # ------------------------- signature helpers -------------------------
    def _gen_pointer_params(self, code):
        """Write the parameters for input pointers / non-tensor inputs, then output pointers.

        Inputs keep the schema order. A tensor input ``in{k}`` becomes
        ``in{k}_ptr``. A non-tensor input ``val{k}`` is annotated when the
        schema declares its type.
        """
        schema = self.fx
        for i in range(schema.num_inputs()):
            name = self.input_name(i)
            if schema.is_tensor(i):
                code.writeline(f"{name}_ptr: tl.tensor, # of tl.pointer_type")
                continue
            arg_type = schema.input_type(i)
            if arg_type is None:
                code.writeline(f"{name},")
            else:
                code.writeline(f"{name}: {_type_name(arg_type)},")
        for i in range(schema.num_outputs()):
            code.writeline(
                f"{self.output_name(i)}_ptr: tl.tensor, # of tl.pointer_type"
            )

    def _gen_stride_params(self, code, with_block_pointer):
        """Write per-axis stride parameters for each input and then each output tensor.

        With block pointers, each operand also gets ``tl.constexpr`` stride-order
        parameters.
        """
        ndim = self.ndim
        schema = self.fx
        operands = (
            ("in", schema.num_input_tensors()),
            ("out", schema.num_output_tensors()),
        )
        for prefix, count in operands:
            for i in range(count):
                stride_args = _cs(f"{prefix}{i}_stride{j}: int" for j in range(ndim))
                code.writeline(f"{stride_args}, # strides for {prefix}{i}")
                if with_block_pointer:
                    order_args = _cs(
                        f"{prefix}{i}_stride_order{j}: tl.constexpr"
                        for j in range(ndim)
                    )
                    code.writeline(f"{order_args}, # stride order for {prefix}{i}")

    def _gen_task_space_params(self, code):
        """Write the task-space sizes ``s0, s1, ...`` and the ``num_tasks`` parameter."""
        task_space_args = _cs(f"s{i}" for i in range(self.ndim))
        code.writeline(f"{task_space_args}, # task_space")
        code.writeline("num_tasks,")

    def gen_signature(self, code, with_block_pointer=False):
        """Write the signature of the nd-tile kernel (optionally with stride orders)."""
        ndim = self.ndim
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_pointer_params(code)
            # rank-0 kernels take only the pointers and scalars
            if ndim > 0:
                self._gen_stride_params(code, with_block_pointer)
                self._gen_task_space_params(code)
                # tile sizes and per-program tile count (loop mode)
                code.writeline("tiles_per_cta: int,")
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                code.writeline(f"{tile_sizes},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature_1d_tile(self, code):
        """Write the signature of the 1d-tile kernel (a single ``tile_size``)."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_pointer_params(code)
            if self.ndim > 0:
                self._gen_stride_params(code, with_block_pointer=False)
                self._gen_task_space_params(code)
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    # --------------------------- body helpers ----------------------------
    def gen_num_tiles(self, code):
        """Write ``num_tiles{i} = tl.cdiv(s{i}, tile_size{i})`` for every axis."""
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def _gen_compute(self, code):
        """Write the call of the scalar function that binds every output name."""
        schema = self.fx
        inputs_to_scalar_fn = _cs(
            self.input_name(i) for i in range(schema.num_inputs())
        )
        outputs_to_scalar_fn = _cs(
            self.output_name(i) for i in range(schema.num_output_tensors())
        )
        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

    def _gen_tile_multi_index(self, code):
        """Decompose the flat ``tile_id`` into ``tile_id0..tile_id{ndim-1}`` (C order)."""
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(self.ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

    def gen_body_for_0d(self, code):
        """Write the body of a rank-0 kernel: load, compute and store one element."""
        schema = self.fx
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # nd tile 1d grid kernel with block pointer
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Write the processing of tile ``tile_id`` using block pointers."""
        ndim = self.ndim
        schema = self.fx

        # block-pointer geometry shared by every operand
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        self._gen_tile_multi_index(code)

        code.writeline("# tile offsets")
        for i in range(ndim):
            # Or else: AssertionError: Block pointers only support 32 bit
            # `offsets/block_shape`, add a `.to(tl.int32)` or use regular indexing
            # for 64 bit support
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(tuple(f"in{i}_stride_order{j}" for j in range(ndim)))
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        """Write the loop over ``tiles_per_cta`` tiles, strided by the number of programs (block pointers)."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_with_bptr(code)

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Write the processing of tile ``tile_id`` with explicit offsets and masks."""
        ndim = self.ndim
        schema = self.fx

        self._gen_tile_multi_index(code)

        # per-axis offsets within the task space
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # per-axis masks, broadcast and combined into one nd mask
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
        mask_combine = " & ".join(masks)
        code.writeline(f"mask = {mask_combine}")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            # tl.store is a statement; binding its result would clobber in{i}
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_without_bptr(self, code):
        """Write the loop over ``tiles_per_cta`` tiles, strided by the number of programs (plain pointers)."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_without_bptr(code)

    def _codegen_nd_tile(self, code, with_block_pointer):
        """Write a complete nd-tile kernel; shared by the block-pointer and plain variants."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=with_block_pointer)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        if with_block_pointer:
            gen_one_tile = self.gen_body_one_tile_per_cta_with_bptr
            gen_loop = self.gen_body_gsl_with_bptr
        else:
            gen_one_tile = self.gen_body_one_tile_per_cta_without_bptr
            gen_loop = self.gen_body_gsl_without_bptr

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolithic kernel: one tile per program, may need a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                gen_one_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                gen_loop(code)
            code.newline()
        return code

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        return self._codegen_nd_tile(code, with_block_pointer=True)

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with plain pointer arithmetic."""
        return self._codegen_nd_tile(code, with_block_pointer=False)

    def codegen_nd_tile(self, code):
        """Generate the nd-tile kernel, using block pointers if ``config.prefer_block_pointer``."""
        if self.config.prefer_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Write the processing of 1d tile ``tile_id``: flat indices, multi-index, load/compute/store."""
        ndim = self.ndim
        schema = self.fx

        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # decompose the flat index into i0..i{ndim-1}, last axis fastest
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            # tl.store is a statement; binding its result would clobber in{i}
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_1d_tile(self, code):
        """Write the loop over ``tiles_per_cta`` 1d tiles, strided by the number of programs."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_1d_tile(code)

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            # monolithic kernel: one tile per program, may need a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_1d_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_1d_tile(code)
            code.newline()
        return code
741class WrapperGenerator:
742 def __init__(
743 self,
744 function_schema: FunctionSchema,
745 jit_fn_name: str,
746 ndim: int,
747 name: str,
748 config: CodeGenConfig,
749 ):
750 self.fx = function_schema
751 self.jit_fn_name = jit_fn_name
752 self.ndim = ndim
753 self.name = name
754 self.config = config
756 def input_name(self, i):
757 is_tensor = self.fx.is_tensor(i)
758 name = "in" if is_tensor else "val"
759 index = self.fx.input_index(i)
760 return f"{name}{index}"
762 def output_name(self, i):
763 return f"out{i}"
765 def gen_signature(self, code: IndentedBuffer):
766 # TODO: check if triton handles constexprs transitively
767 schema = self.fx
768 params: List[str] = []
769 for i in range(schema.num_inputs()):
770 if schema.is_tensor(i):
771 params.append(
772 f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]"
773 )
774 else:
775 arg_type = schema.input_type(i)
776 if arg_type is not None:
777 params.append(f"{self.input_name(i)}: {_type_name(arg_type)}")
778 else:
779 params.append(f"{self.input_name(i)}")
780 # NOTE: [the wrapper's signature and rules for passing parameters ]
781 # input params: must be passed by position, since the names are renamed to
782 # in0, in1, val0, val1, ..., So passing these parameters by keyword is wierd
783 # So we enforce that these parameters must be passed by position.
784 # maybe we can fix it later
785 # output parameters: must be passed by keyword, since the scalar function
786 # do not have output parameters(think of it as some scalar function, output
787 # parameter does not make sense in this case.) They are added to allow destination
788 # passing style API. Output parameter is convenient in cases where we want
789 # to use some pre-defiend outputs(especially when they are some views of other
790 # tensors). We emphasize that these parameters are added in-addition, we enforce
791 # that they be passed by keyword. After all, out0, out1, ... does not mismatch
792 # names form the scalar function, since it does not have output parameters.
793 params.append("/")
794 params.append("*") # output params must be passed by keyword
796 for i in range(schema.num_output_tensors()):
797 params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]")
798 code.writeline(f"def {self.name}({_cs(params)}): ")
800 def gen_docstring(self, code: IndentedBuffer):
801 schema = self.fx
802 doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""'
803 code.writeline(doc)
805 def gen_same_shape_check(self, code: IndentedBuffer):
806 schema: FunctionSchema = self.fx
807 params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [
808 f"out{i}.shape" for i in range(schema.num_output_tensors())
809 ]
810 check: str = " == ".join(params)
811 code.writeline(f"assert {check}, 'operand shapes mismatch'")
813 def gen_task_partition(self, code: IndentedBuffer):
814 code.writeline("# task partitioning")
815 ndim = self.ndim
816 if ndim == 0:
817 code.writeline("num_warps = 1")
818 code.writeline("num_ctas = 1")
819 else:
820 code.writeline("shape = out0.shape")
821 code.writeline("num_tasks = out0.numel()")
822 code.writeline("if num_tasks == 0:")
823 with code.indent():
824 self.gen_return(code)
825 max_tile_size = self.config.max_tile_size
826 major, _ = get_device_capability()
827 if self.name.find("fill_scalar") != -1 and major >= 9:
828 code.writeline("tile_sizes = tuple([64])")
829 else:
830 code.writeline(
831 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
832 )
833 code.writeline("tile_size = math.prod(tile_sizes)")
834 code.writeline(
835 "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
836 )
838 if self.name.find("fill_scalar") != -1 and major >= 9:
839 code.writeline("num_ctas = num_tiles")
840 else:
841 max_grid_size0 = self.config.max_grid_size[0]
842 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")
844 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
845 code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
846 code.writeline("one_tile_per_cta = tiles_per_cta==1")
847 code.writeline("grid = (num_ctas, 1, 1)")
849 def gen_task_partition_1d(self, code: IndentedBuffer):
850 code.writeline("# task partitioning")
851 ndim = self.ndim
852 if ndim == 0:
853 code.writeline("num_warps = 1")
854 code.writeline("num_ctas = 1")
855 else:
856 code.writeline("shape = out0.shape")
857 code.writeline("num_tasks = out0.numel()")
858 code.writeline("if num_tasks == 0:")
859 with code.indent():
860 self.gen_return(code)
861 max_tile_size = self.config.max_tile_size
863 major, _ = get_device_capability()
864 if self.name.find("fill_scalar") != -1 and major >= 9:
865 code.writeline("tile_sizes = tuple([64])")
866 else:
867 code.writeline(
868 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
869 )
871 code.writeline("tile_size = tile_sizes[0]")
872 code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")
874 if self.name.find("fill_scalar") != -1 and major >= 9:
875 code.writeline("num_ctas = num_tiles")
876 else:
877 max_grid_size0 = self.config.max_grid_size[0]
878 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")
880 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
881 code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
882 code.writeline("one_tile_per_cta = tiles_per_cta==1")
883 code.writeline("grid = (num_ctas, 1, 1)")
885 def gen_kernel_launch(
886 self,
887 code: IndentedBuffer,
888 ):
889 schema = self.fx
890 ndim = self.ndim
892 with_block_pointer = self.config.prefer_block_pointer
894 code.writeline("# kernel launch")
895 for i in range(schema.num_input_tensors()):
896 code.writeline(f"in{i}_strides = in{i}.stride()")
897 if not with_block_pointer:
898 continue
899 if ndim >= 2: # where ndim is 1, we don't need to compute stride order
900 code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)")
901 else:
902 code.writeline(f"in{i}_stride_order = (0,)")
903 for i in range(schema.num_output_tensors()):
904 code.writeline(f"out{i}_strides = out{i}.stride()")
905 if not with_block_pointer:
906 continue
907 if ndim >= 2:
908 code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)")
909 else:
910 code.writeline(f"out{i}_stride_order = (0,)")
912 code.writeline("with torch_device_fn.device(in0.device.index):")
913 with code.indent():
914 code.writeline(f"{self.jit_fn_name}[grid](")
915 with code.indent():
916 params = []
917 # NOTE: WRAP
918 for i in range(schema.num_inputs()):
919 if schema.is_tensor(i):
920 params.append(f"{self.input_name(i)}")
921 else:
922 params.append(self.input_name(i))
923 for i in range(schema.num_output_tensors()):
924 params.append(f"{self.output_name(i)}")
926 code.writeline(f"{_cs(params)},")
928 if ndim > 0:
929 for i in range(schema.num_input_tensors()):
930 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
931 code.writeline(f"{s}, # stride for in{i}")
932 if not with_block_pointer:
933 continue
934 order = ", ".join(
935 f"in{i}_stride_order[{j}]" for j in range(ndim)
936 )
937 code.writeline(f"{order}, # stride order for in{i}")
939 for i in range(schema.num_output_tensors()):
940 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
941 code.writeline(f"{s}, # stride for out{i}")
942 if not with_block_pointer:
943 continue
944 order = ", ".join(
945 f"out{i}_stride_order[{j}]" for j in range(ndim)
946 )
947 code.writeline(f"{order}, # stride orderfor out{i}")
949 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
950 code.writeline(f"{shape_args}, # task indexing space")
951 code.writeline("num_tasks, # num tasks")
952 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
953 for i in range(ndim):
954 code.writeline(f"tile_size{i}=tile_sizes[{i}],")
955 code.writeline("one_tile_per_cta=one_tile_per_cta,")
956 code.writeline("num_warps=num_warps,")
957 code.writeline(")")
def gen_kernel_launch_1d(
    self,
    code: "IndentedBuffer",
):
    """Emit the kernel-launch section of a 1d-tile wrapper into ``code``.

    The emitted code first binds ``in{i}_strides`` / ``out{i}_strides`` for
    every input / output tensor, then calls ``<jit_fn_name>[grid](...)``
    inside ``with torch_device_fn.device(in0.device.index):``.

    The call arguments are emitted in this order:
      1. every input name, then every output name;
      2. only when ``self.ndim > 0``: each tensor's strides, then the task shape;
      3. ``num_tasks`` and the tiling/launch keywords. ``tile_size`` is a
         single scalar in the 1d-tile variant.
    """
    schema = self.fx
    rank = self.ndim
    tensor_groups = (
        ("in", schema.num_input_tensors()),
        ("out", schema.num_output_tensors()),
    )

    code.writeline("# kernel launch")
    for prefix, count in tensor_groups:
        for idx in range(count):
            code.writeline(f"{prefix}{idx}_strides = {prefix}{idx}.stride()")

    code.writeline("with torch_device_fn.device(in0.device.index):")
    with code.indent():
        code.writeline(f"{self.jit_fn_name}[grid](")
        with code.indent():
            # Data arguments: every input (tensor or not) by its generated
            # name, followed by every output.
            arg_names = [self.input_name(k) for k in range(schema.num_inputs())]
            arg_names.extend(
                self.output_name(k) for k in range(schema.num_output_tensors())
            )
            code.writeline(", ".join(arg_names) + ",")

            if rank > 0:
                for prefix, count in tensor_groups:
                    for idx in range(count):
                        per_dim = ", ".join(
                            f"{prefix}{idx}_strides[{d}]" for d in range(rank)
                        )
                        code.writeline(f"{per_dim}, # stride for {prefix}{idx}")
                extents = ", ".join(f"shape[{d}]" for d in range(rank))
                code.writeline(f"{extents}, # task indexing space")

            trailing = (
                "num_tasks, # num tasks",
                "tiles_per_cta=tiles_per_cta, # tiles_per_cta",
                "tile_size=tile_size,",
                "one_tile_per_cta=one_tile_per_cta,",
                "num_warps=num_warps,",
            )
            for line in trailing:
                code.writeline(line)
        code.writeline(")")
def gen_return(self, code: "IndentedBuffer"):
    """Emit the wrapper's ``return`` statement listing ``out0, out1, ...``.

    One name is written per output tensor of the schema. With zero outputs the
    emitted line is a bare ``return `` (which returns None).
    """
    outputs = [f"out{k}" for k in range(self.fx.num_output_tensors())]
    code.writeline("return " + ", ".join(outputs))
def codegen_nd_tile(self, code):
    """Emit the complete n-d tile wrapper function into ``code`` and return it.

    The signature is written first. The docstring, shape check, task partition,
    kernel launch and return statement are then written one indentation level
    deeper, followed by ``code.newline()``.
    """
    self.gen_signature(code)
    body_emitters = (
        self.gen_docstring,
        self.gen_same_shape_check,
        self.gen_task_partition,
        self.gen_kernel_launch,
        self.gen_return,
    )
    with code.indent():
        for emit in body_emitters:
            emit(code)
    code.newline()
    return code
def codegen_1d_tile(self, code):
    """Emit the complete 1d-tile wrapper function into ``code`` and return it.

    This mirrors the n-d variant but uses the 1d task partition and 1d kernel
    launch emitters. The signature is written first, the body one indentation
    level deeper, followed by ``code.newline()``.
    """
    self.gen_signature(code)
    body_emitters = (
        self.gen_docstring,
        self.gen_same_shape_check,
        self.gen_task_partition_1d,
        self.gen_kernel_launch_1d,
        self.gen_return,
    )
    with code.indent():
        for emit in body_emitters:
            emit(code)
    code.newline()
    return code
class ModuleGenerator:
    """Assemble one generated module: import header, wrapper function and kernel.

    The wrapper and the kernel come from a WrapperGenerator and a
    KernelGenerator built from the same schema, rank, kernel name and config.
    The config's ``prefer_1d_tile`` flag selects which tiling variant both of
    them emit.
    """

    def __init__(
        self,
        function_schema: "FunctionSchema",
        scalar_fn: "triton.JITFunction",
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: "CodeGenConfig",
    ):
        self.config = config
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def generate_imports(code: "IndentedBuffer") -> "IndentedBuffer":
        """Write the import header used by the generated code; return ``code``."""
        library_imports = (
            "import math",
            "from typing import Union",
            "import torch",
            "import triton",
            "from triton import language as tl",
        )
        flag_gems_imports = (
            "from flag_gems.utils.shape_utils import (",
            " heuristics_for_tile_size,",
            " heuristics_for_num_warps,",
            " stride_order,",
            ")",
            "from flag_gems.utils.tensor_wrapper import StridedBuffer",
            "from flag_gems.utils.libentry import libentry",
            "from flag_gems.utils import triton_lang_extension as tle",
            "from flag_gems.runtime import torch_device_fn",
        )
        for line in library_imports:
            code.writeline(line)
        code.newline()
        for line in flag_gems_imports:
            code.writeline(line)
        code.newline()
        code.newline()
        return code

    def codegen(self, code: "IndentedBuffer"):
        """Emit imports, then the wrapper, then the kernel; return the buffer."""
        # the rank of the task space is the only runtime-determined factor
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            emitters = (
                self.wrapper_gen.codegen_1d_tile,
                self.kernel_gen.codegen_1d_tile,
            )
        else:
            emitters = (
                self.wrapper_gen.codegen_nd_tile,
                self.kernel_gen.codegen_nd_tile,
            )
        for emit in emitters:
            code = emit(code)
        return code
class PointwiseDynamicFunction:
    """Utility to generate functions for a general pointwise operation.

    It generates a wrapper & a JITFunction specialized on the rank of the task
    space (the broadcasted shape of all input tensors). The generated code is
    written to the code cache directory (``code_cache_dir()``) and imported
    back as a module; instantiated overloads are cached in ``self.overloads``.
    """

    def __init__(
        self, op_desc: "FunctionSchema", scalar_fn: "JITFunction", config=None
    ):
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        # NOTE(review): without an explicit `config` this is the object
        # returned by `get_codegen_config()`, which may be shared with other
        # functions; `prepare_args` writes `prefer_block_pointer` on it.
        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads, keyed by f"{ndim}_{prefer_block_pointer}"
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Run the op. Inputs are passed by position, outputs (``out{i}``) by keyword."""
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        # NOTE: the overload keeps the type of its outputs: a pre-defined
        # output that is a Tensor or a StridedBuffer is returned as a Tensor or
        # a StridedBuffer, respectively. Since prepare_args wraps every
        # argument, all outputs here are StridedBuffers. A manually
        # instantiated overload called directly must be unwrapped by the caller.
        return self._unwrap(out)

    @staticmethod
    def use_fast_path(tensors):
        """Return whether the task can be collapsed to 1-D in physical order.

        True when all tensors have the same shape and either are all
        C-contiguous, or all share the same strides with ``tensors[0]`` being
        non-overlapping and dense.
        """
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    @staticmethod
    def _task_space_exceeds_int32(task_shape) -> bool:
        """Return True if a task space of shape ``task_shape`` holds more
        elements than ``torch.int32`` can represent (``()`` holds one element)."""
        import math

        return math.prod(task_shape) > torch.iinfo(torch.int32).max

    def prepare_args(self, *args, **kwargs):
        """Allocate missing outputs and wrap every tensor argument as a StridedBuffer.

        This performs output allocation (when an ``out{i}`` keyword is absent),
        task simplification, task-rank inference and input/output
        reinterpretation.

        Returns:
            ``(ndim, args, kwargs)``: the rank of the task space and the
            re-wrapped positional / keyword arguments.

        Raises:
            ValueError: if positional arguments do not match the schema's
                tensor / non-tensor layout.
            RuntimeError: if a provided output has the wrong shape or has
                internal overlapping.
        """
        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # a simple strategy: all the undefined tensors follow the first
            # tensor that is not broadcasted; no attempts to simplify the task,
            # no reordering, no dimension collapsing
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operation
                    # NOTE(review): the message says "Input" although an
                    # output tensor is being checked; kept for compatibility.
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise Input arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )

        # Disable block pointers when the task space cannot be indexed with
        # int32. Outputs have exactly the task shape and inputs broadcast into
        # it, so the task-space size bounds every tensor's numel. Checking only
        # tensors[0] missed cases where it is a small broadcasted input.
        # NOTE(review): this also affects all later calls sharing self.config.
        if self._task_space_exceeds_int32(task_shape):
            self.config.prefer_block_pointer = False
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        """Unwrap StridedBuffer outputs into Tensors: one item, or a tuple of them."""
        if self.fx.num_output_tensors() == 1:
            return tensors.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def instantiate(self, ndim):
        """Return the overload for a task space of rank ``ndim``, generating it if needed.

        The generated module source is written into the code cache directory
        and imported from there. The resulting wrapper function is cached in
        ``self.overloads`` under ``f"{ndim}_{self.config.prefer_block_pointer}"``.
        """
        # NOTE: a manually instantiated overload does not get `prepare_args` as
        # preprocessing, so callers have to allocate outputs themselves and
        # make sure that the inputs & outputs actually fit the overload.
        key = f"{ndim}_{self.config.prefer_block_pointer}"
        if key in self.overloads:
            return self.overloads[key]

        # A module-level `import importlib` does not guarantee that the
        # `importlib.util` submodule is loaded, so import it explicitly.
        import importlib.util

        code = IndentedBuffer()

        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"
        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which
        # requires that the source code can be found by inspect. inspect cannot
        # find the source of functions dynamically created via exec of a
        # string. Hacking the linecache library would also work, but generating
        # a module is simpler: it holds 2 functions, the kernel and the
        # wrapper, and the wrapper calls the kernel.
        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )

        file_path = code_cache_dir() / file_name
        write_atomic(file_path, code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # do not expose it to sys.modules
        # sys.modules["_add_module"] = m

        # NOTE: [why not import the scalar function]
        # The generated kernel **calls** the scalar function, but we do not
        # re-import it: a function's __name__ may have been changed, so
        # importing that __name__ from its defining module may yield something
        # else, and the name may also have been rebound. So the scalar
        # function and its __globals__ are copied into the generated module.
        # NOTE(review): this update runs after exec_module, so colliding names
        # (including module dunders such as __name__/__file__) take the scalar
        # function module's values.
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload
        return overload
def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Turn a scalar triton JIT function into a PointwiseDynamicFunction.

    Usable bare (``@pointwise_dynamic``) or with keyword options
    (``@pointwise_dynamic(num_inputs=..., promotion_methods=...)``).

    Args:
        f: the scalar function when used without parentheses, otherwise None.
        num_inputs, is_tensor, dtypes, num_outputs, promotion_methods:
            forwarded to ``FunctionSchema``. When none of ``num_inputs``,
            ``is_tensor`` and ``dtypes`` is given, the number of inputs is
            taken from the decorated function's ``arg_names``.
        config: code generation config forwarded to PointwiseDynamicFunction.
            None lets it fall back to ``get_codegen_config()``.

    Returns:
        A PointwiseDynamicFunction when ``f`` is given, otherwise a decorator
        that produces one.
    """

    def decorator(fn):
        # Use a per-call local instead of `nonlocal num_inputs`. Otherwise a
        # decorator object reused on several functions would keep the input
        # count inferred for the first one.
        n_inputs = num_inputs
        if (n_inputs is None) and (is_tensor is None) and (dtypes is None):
            n_inputs = len(fn.arg_names)
        op_desc = FunctionSchema(
            num_inputs=n_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(op_desc, fn, config)

    if f is not None:
        return decorator(f)
    return decorator