# Coverage for src/flag_gems/runtime/backend/_cambricon/utils/pointwise_dynamic.py: 0% (1003 statements)
import importlib
import os
from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple

import torch
import triton
from triton.runtime.jit import JITFunction

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config
from flag_gems.utils.shape_utils import (
    MemOverlap,
    all_c_contiguous,
    all_the_same_shape,
    all_the_same_stride,
    broadcast_shapes,
    broadcasted_stride,
    check_tensor_attributes,
    has_internal_overlapping,
)
from flag_gems.utils.tensor_wrapper import StridedBuffer
from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion


# ------------------ Operation Description ---------------------------
def _type_name(type) -> str:
    "Render a type name as a string; works for (bool, int, float, str) and torch.dtype objects."
    if type in (bool, int, float, str):
        return type.__name__
    # torch.dtype objects (and anything else) render via str()
    return str(type)


def _check_typed_list(container, type):
    for item in container:
        assert isinstance(item, type)


def _check_sized_list(container, size):
    assert len(container) == size


def _tuple_content(strings: Sequence[str]) -> str:
    # comma-separated list
    if len(strings) == 0:
        return ""
    if len(strings) == 1:
        return f"{strings[0]},"
    else:
        return ", ".join(strings)
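
# Examples:
#   _tuple_content([])          -> ""
#   _tuple_content(["a"])       -> "a,"   (the trailing comma keeps it a one-tuple)
#   _tuple_content(["a", "b"])  -> "a, b"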


def _cs(strings: Iterable[str]) -> str:
    return ", ".join(strings)


def _broadcast_vec(i, ndim):
    axes = [":" if j == i else "None" for j in range(ndim)]
    return f"[{_cs(axes)}]"
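
# Examples:
#   _broadcast_vec(0, 2) -> "[:, None]"
#   _broadcast_vec(1, 3) -> "[None, :, None]"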


class FunctionSchema:
    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide a type promotion method for each output!"
            )
        else:
            self._promotion_methods = self.canonicalize_promotion_methods(
                promotion_methods
            )
        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            if is_tensor is not None:
                _check_sized_list(is_tensor, self._num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create a FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized
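
    # Example: [(0, 1, "DEFAULT")] becomes
    # [(0, 1, ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)], i.e. output 0 takes
    # the dtype promoted from input arguments 0 and 1 with the default rule.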

    def num_inputs(self):
        # num of arguments, outputs not included
        return self._num_inputs

    def num_outputs(self):
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        return self._dtypes[arg_id]

    def output_type(self, i):
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            else:
                if dtype is None:
                    input_types.append("scalar")
                else:
                    input_types.append(_type_name(dtype))

        output_types = []

        if outputs_in_arg:
            for i in range(self.num_outputs()):
                output_types.append(f"StridedBuffer(a{i}!)")
            input_types.extend(output_types)
        else:
            for _ in range(self.num_outputs()):
                output_types.append("StridedBuffer")
        sig = f"Pointwise: {', '.join(input_types)} -> {', '.join(output_types)}"
        return sig

    def _compute_input_id(self):
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        return self._input_id[idx]
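
    # Example: with is_tensor=[True, False, True], _compute_input_id returns
    # [0, 0, 1]; tensor and non-tensor inputs are numbered independently, so
    # the rendered argument names are in0, val0, in1.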

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)
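

# Example (illustrative): a schema for an op taking one tensor and one float
# scalar, with one output whose dtype is promoted from argument 0:
#   fx = FunctionSchema(
#       is_tensor=[True, False],
#       dtypes=[None, float],
#       promotion_methods=[(0, "DEFAULT")],
#   )
#   str(fx) -> "Pointwise: StridedBuffer, float -> StridedBuffer"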


class KernelGenerator:
    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    def gen_import_function(self, code: IndentedBuffer):
        code.writeline(f'"""Quoted source of {self.fn_name}:')
        code.writemultiline(self.fn.src)
        code.writeline('"""')
        code.newline()

    def gen_config_prune(self, code):
        code.writeline("def config_prune(configs, named_args, **kwargs):")
        with code.indent():
            code.writeline("new_configs = []")
            code.writeline("elem_sizes = []")
            for i in range(self.fx.num_input_tensors()):
                code.writeline(
                    f"elem_sizes.append(named_args['in{i}_ptr'].dtype.itemsize)"
                )
            for i in range(self.fx.num_output_tensors()):
                code.writeline(
                    f"elem_sizes.append(named_args['out{i}_ptr'].dtype.itemsize)"
                )

            code.writeline("max_elem_size = max(elem_sizes)")
            shape = ", ".join(f"s{i}" for i in range(self.ndim))
            named_shape = ", ".join(f"named_args['s{i}']" for i in range(self.ndim))
            code.writeline(f"{shape} = {named_shape}")
            tile_sizes = ", ".join(f"tile_size{i}" for i in range(self.ndim))
            tile_size_dict = ", ".join(
                f"'tile_size{i}': tile_size{i}" for i in range(self.ndim)
            )

            code.writeline("if max_elem_size < 8:")
            with code.indent():
                code.writeline("max_tile_sizes = [1024, 2048, 4096, 8192, 16000]")
                code.writeline("for max_tile_size in max_tile_sizes:")
                with code.indent():
                    code.writeline(
                        f"({tile_sizes}, ) = heuristics_for_tile_size(max_tile_size, {shape})"
                    )
                    code.writeline(
                        f"new_configs.append(triton.Config({{{tile_size_dict}}}, num_stages=3, num_warps=1))"
                    )
            code.writeline("else:")
            with code.indent():
                code.writeline("max_tile_sizes = [1024, 2048, 4096, 8000]")
                code.writeline("for max_tile_size in max_tile_sizes:")
                with code.indent():
                    code.writeline(
                        f"({tile_sizes}, ) = heuristics_for_tile_size(max_tile_size, {shape})"
                    )
                    code.writeline(
                        f"new_configs.append(triton.Config({{{tile_size_dict}}}, num_stages=3, num_warps=1))"
                    )

            code.writeline("return new_configs")
        code.newline()
        code.newline()

    def gen_decorators(self, code):
        if self.ndim in [1, 2, 3, 4] and (not self.config.prefer_1d_tile):
            self.gen_config_prune(code)

        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            non_tensor_arg_names = ", ".join(
                f"'val{i}'" for i in range(num_non_tensor_args)
            )

        shapes = ", ".join(f"'s{i}'" for i in range(self.ndim))
        stride_args = []
        for i in range(self.fx.num_input_tensors()):
            stride_args.append(_cs(f"'in{i}_stride{j}'" for j in range(self.ndim)))
        for i in range(self.fx.num_output_tensors()):
            stride_args.append(_cs(f"'out{i}_stride{j}'" for j in range(self.ndim)))

        code.writeline("@libentry()")
        if self.ndim == 1 and (not self.config.prefer_1d_tile):
            code.writeline("@libtuner(")
            with code.indent():
                code.writeline("configs=[")
                with code.indent():
                    code.writeline(
                        "triton.Config({'tile_size0': 1024}, num_stages=3, num_warps=1),"
                    )
                    code.writeline(
                        "triton.Config({'tile_size0': 2048}, num_stages=3, num_warps=1),"
                    )
                code.writeline("],")
                if num_non_tensor_args > 0:
                    code.writeline(
                        f"key=['num_tasks', {_cs(stride_args)}, {non_tensor_arg_names}],"
                    )
                else:
                    code.writeline(f"key=['num_tasks', {_cs(stride_args)}],")
                code.writeline("prune_configs_by={'early_config_prune': config_prune},")
                output_params = [
                    f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
                ]
                output_elements = ", ".join(f"'{name}'" for name in output_params)
                code.writeline(f"restore_value=[{output_elements}],")
            code.writeline(")")

        if self.ndim == 2 and (not self.config.prefer_1d_tile):
            code.writeline("@libtuner(")
            with code.indent():
                code.writeline("configs=[")
                with code.indent():
                    code.writeline(
                        "triton.Config({'tile_size0': 1, 'tile_size1': 1024}, num_stages=3, num_warps=1),"
                    )
                    code.writeline(
                        "triton.Config({'tile_size0': 1, 'tile_size1': 2048}, num_stages=3, num_warps=1),"
                    )
                code.writeline("],")
                if num_non_tensor_args > 0:
                    code.writeline(
                        f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}],"
                    )
                else:
                    code.writeline(f"key=[{shapes}, {_cs(stride_args)}],")
                code.writeline("prune_configs_by={'early_config_prune': config_prune},")
                output_params = [
                    f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
                ]
                output_elements = ", ".join(f"'{name}'" for name in output_params)
                code.writeline(f"restore_value=[{output_elements}],")
            code.writeline(")")

        if self.ndim == 3 and (not self.config.prefer_1d_tile):
            code.writeline("@libtuner(")
            with code.indent():
                code.writeline("configs=[")
                with code.indent():
                    code.writeline(
                        "triton.Config({'tile_size0': 1, 'tile_size1': 1, 'tile_size2': 1024}, num_stages=3, num_warps=1),"
                    )
                    code.writeline(
                        "triton.Config({'tile_size0': 1, 'tile_size1': 1, 'tile_size2': 2048}, num_stages=3, num_warps=1),"
                    )
                code.writeline("],")
                if num_non_tensor_args > 0:
                    code.writeline(
                        f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}],"
                    )
                else:
                    code.writeline(f"key=[{shapes}, {_cs(stride_args)}],")
                code.writeline("prune_configs_by={'early_config_prune': config_prune},")
                output_params = [
                    f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
                ]
                output_elements = ", ".join(f"'{name}'" for name in output_params)
                code.writeline(f"restore_value=[{output_elements}],")
            code.writeline(")")

        if self.ndim == 4 and (not self.config.prefer_1d_tile):
            code.writeline("@libtuner(")
            with code.indent():
                code.writeline("configs=[")
                with code.indent():
                    code.writeline(
                        "triton.Config({'tile_size0': 1, 'tile_size1': 1, 'tile_size2': 1, 'tile_size3': 1024}, num_stages=3, num_warps=1),"
                    )
                    code.writeline(
                        "triton.Config({'tile_size0': 1, 'tile_size1': 1, 'tile_size2': 1, 'tile_size3': 2048}, num_stages=3, num_warps=1),"
                    )
                code.writeline("],")
                if num_non_tensor_args > 0:
                    code.writeline(
                        f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}],"
                    )
                else:
                    code.writeline(f"key=[{shapes}, {_cs(stride_args)}],")
                code.writeline("prune_configs_by={'early_config_prune': config_prune},")
                output_params = [
                    f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
                ]
                output_elements = ", ".join(f"'{name}'" for name in output_params)
                code.writeline(f"restore_value=[{output_elements}],")
            code.writeline(")")

        if num_non_tensor_args > 0:
            # we do not specialize non-tensor args since they are passed into
            # the inlined function, which means their values may not deserve
            # specialization
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")
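
    # For example, with ndim == 2 and one non-tensor argument the emitted
    # decorator stack is roughly:
    #   @libentry()
    #   @libtuner(
    #       configs=[...],
    #       key=['s0', 's1', 'in0_stride0', ..., 'val0'],
    #       prune_configs_by={'early_config_prune': config_prune},
    #       restore_value=['out0_ptr'],
    #   )
    #   @triton.jit(do_not_specialize=['val0'])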

    def input_name(self, i):
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        return f"out{i}"

    def gen_signature(self, code, with_block_pointer=False):
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: input pointers & non-tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output pointers
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor argument
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")
                    if with_block_pointer:
                        stride_order_args = _cs(
                            f"in{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(f"{stride_order_args}, # stride order for in{i}")
                        zero_stride_args = _cs(
                            f"in{i}_zero_stride{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{zero_stride_args}, # zero stride flag for in{i}"
                        )

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")
                    if with_block_pointer:
                        stride_order_args = _cs(
                            f"out{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{stride_order_args}, # stride order for out{i}"
                        )
                        zero_stride_args = _cs(
                            f"out{i}_zero_stride{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{zero_stride_args}, # zero stride flag for out{i}"
                        )

                # task space, used to reconstruct the multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

            # number of tasks, used to compute the mask
            code.writeline("num_tasks,")
            if self.config.prefer_block_pointer:
                code.writeline("FALLBACK_BPTR: tl.constexpr,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                code.writeline(f"{tile_sizes},")
            if ndim > 4:
                code.writeline("tiles_per_cta: int,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature_1d_tile(self, code):
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: input pointers & non-tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output pointers
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor argument
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")

                # task space, used to reconstruct the multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

            # number of tasks, used to compute the mask
            code.writeline("num_tasks,")

            if self.config.prefer_block_pointer:
                code.writeline("FALLBACK_BPTR: tl.constexpr,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_num_tiles(self, code):
        # tile-grid size
        ndim = self.ndim
        for i in range(ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def gen_body_for_0d(self, code):
        schema = self.fx
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround for a bug on bool: load via the pointer's element type"
            )
        code.newline()

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # nd-tile, 1d-grid kernel with block pointers
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        ndim = self.ndim
        schema = self.fx

        # block pointer for each operand
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        # reconstruct pid multi index
        code.writeline(
            "# pid multi index reconstruction: we use c ordering, the rightmost axis changes fastest"
        )
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

        # cta_offsets
        code.writeline("# tile offsets")

        # Because block pointers only support `tl.int32` indexing, when the max
        # offset of a pointer exceeds 2^31 we must fall back to the normal
        # indexing method.
        code.writeline("if not FALLBACK_BPTR:")
        with code.indent():
            for i in range(ndim):
                # Or else: AssertionError: Block pointers only support 32 bit
                # `offsets/block_shape`, add a `.to(tl.int32)` or use regular
                # indexing for 64 bit support
                code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

            # loads
            code.writeline("# loads")
            for i in range(schema.num_input_tensors()):
                strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
                order = _tuple_content(
                    tuple(f"in{i}_stride_order{j}" for j in range(ndim))
                )

                for j in range(ndim):
                    code.writeline(f"if in{i}_zero_stride{j}:")
                    with code.indent():
                        code.writeline(f"in{i}_stride{j} = 0")

                code.writeline(
                    f"in{i}_bptr = tl.make_block_ptr("
                    f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
                )

                code.writeline(
                    f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                )
            code.newline()

            # compute
            # TODO: separate this part
            inputs_to_scalar_fn = [
                self.input_name(i) for i in range(schema.num_inputs())
            ]
            outputs_to_scalar_fn = [
                self.output_name(i) for i in range(schema.num_output_tensors())
            ]
            inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
            outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

            code.writeline("# compute")
            code.writeline(
                f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
            )
            code.newline()

            # stores
            for i in range(schema.num_output_tensors()):
                strides = _tuple_content(
                    tuple(f"out{i}_stride{j}" for j in range(ndim))
                )
                order = _tuple_content(
                    tuple(f"out{i}_stride_order{j}" for j in range(ndim))
                )

                for j in range(ndim):
                    code.writeline(f"if out{i}_zero_stride{j}:")
                    with code.indent():
                        code.writeline(f"out{i}_stride{j} = 0")

                code.writeline(
                    f"out{i}_bptr = tl.make_block_ptr("
                    f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
                )

                code.writeline(
                    f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
                )
        code.writeline("else:")
        with code.indent():
            # offsets
            for i in range(ndim):
                code.writeline(
                    f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
                )

            # masks
            for i in range(ndim):
                code.writeline(f"mask{i} = offsets{i} < s{i}")
            masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
            mask_combine = " & ".join(masks)
            code.writeline(f"mask = {mask_combine}")

            # loads
            code.writeline("# loads")
            for i in range(schema.num_input_tensors()):
                strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
                order = _tuple_content(
                    tuple(f"in{i}_stride_order{j}" for j in range(ndim))
                )

                for j in range(ndim):
                    code.writeline(f"if in{i}_zero_stride{j}:")
                    with code.indent():
                        code.writeline(f"in{i}_stride{j} = 0")
                offsets = tuple(
                    f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                    for j in range(ndim)
                )
                offset_combine = " + ".join(offsets)
                code.writeline(
                    f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
                )

            code.newline()

            # compute
            inputs_to_scalar_fn = [
                self.input_name(i) for i in range(schema.num_inputs())
            ]
            outputs_to_scalar_fn = [
                self.output_name(i) for i in range(schema.num_output_tensors())
            ]
            inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
            outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

            code.writeline("# compute")
            code.writeline(
                f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
            )
            code.newline()

            # store
            for i in range(schema.num_output_tensors()):
                strides = _tuple_content(
                    tuple(f"out{i}_stride{j}" for j in range(ndim))
                )
                order = _tuple_content(
                    tuple(f"out{i}_stride_order{j}" for j in range(ndim))
                )

                for j in range(ndim):
                    code.writeline(f"if out{i}_zero_stride{j}:")
                    with code.indent():
                        code.writeline(f"out{i}_stride{j} = 0")

                offsets = tuple(
                    f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                    for j in range(ndim)
                )
                offset_combine = " + ".join(offsets)
                code.writeline(
                    f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
                )

    def gen_body_gsl_with_bptr(self, code):
        code.writeline("num_ctas = tle.num_programs(0)")
        if self.ndim <= 4:
            num_tiles = " * ".join([f"num_tiles{i}" for i in range(self.ndim)])
            code.writeline(
                f"tiles_per_cta = tl.cdiv({num_tiles}, num_ctas).to(tl.int32)"
            )
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_with_bptr(code)

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        ndim = self.ndim
        schema = self.fx

        # reconstruct pid multi index
        code.writeline(
            "# pid multi index reconstruction: we use c ordering, the rightmost axis changes fastest"
        )
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
        mask_combine = " & ".join(masks)
        code.writeline(f"mask = {mask_combine}")

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )

        code.newline()

        # compute
        # TODO: separate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_without_bptr(self, code):
        code.writeline("num_ctas = tle.num_programs(0)")
        if self.ndim <= 4:
            num_tiles = " * ".join([f"num_tiles{i}" for i in range(self.ndim)])
            code.writeline(f"tiles_per_cta = tl.cdiv({num_tiles}, num_ctas)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_without_bptr(code)
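
    # The emitted loop is a standard grid-stride loop over tiles, e.g. for
    # ndim == 2:
    #   num_ctas = tle.num_programs(0)
    #   tiles_per_cta = tl.cdiv(num_tiles0 * num_tiles1, num_ctas)
    #   for j in range(0, tiles_per_cta):
    #       tile_id = pid + j * num_ctas
    #       ...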

    def codegen_nd_tile_with_bptr(self, code):
        """Generate an nd-tile & 1d-grid kernel with gsl support, using block pointers."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=True)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolithic kernel: one_tile_per_cta; it may require a very large grid
            if self.ndim > 4:
                code.writeline("if one_tile_per_cta: # monolithic kernel style")
                with code.indent():
                    code.writeline("tile_id = pid")
                    self.gen_body_one_tile_per_cta_with_bptr(code)
                # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
                code.writeline("else: # grid-stride-loop style kernel")
                with code.indent():
                    self.gen_body_gsl_with_bptr(code)
            else:
                self.gen_body_gsl_with_bptr(code)
        code.newline()
        return code

    def codegen_nd_tile_without_bptr(self, code):
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=False)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            self.gen_num_tiles(code)
            # monolithic kernel: one_tile_per_cta; it may require a very large grid
            if self.ndim > 4:
                code.writeline("if one_tile_per_cta: # monolithic kernel style")
                with code.indent():
                    code.writeline("tile_id = pid")
                    self.gen_body_one_tile_per_cta_without_bptr(code)
                # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
                code.writeline("else: # grid-stride-loop style kernel")
                with code.indent():
                    self.gen_body_gsl_without_bptr(code)
            else:
                self.gen_body_gsl_without_bptr(code)
        code.newline()
        return code

    def codegen_nd_tile(self, code):
        use_block_pointer = self.config.prefer_block_pointer
        if use_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        ndim = self.ndim
        schema = self.fx

        # tile id
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )

        code.newline()

        # compute
        # TODO: separate this part
        inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
        outputs_to_scalar_fn = [
            self.output_name(i) for i in range(schema.num_output_tensors())
        ]
        inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
        outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_1d_tile(self, code):
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            self.gen_body_one_tile_per_cta_1d_tile(code)

    def codegen_1d_tile(self, code):
        """Generate a 1d-tile & 1d-grid kernel with gsl support."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            # monolithic kernel: one_tile_per_cta; it may require a very large grid
            code.writeline("if one_tile_per_cta: # monolithic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_1d_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_1d_tile(code)
        code.newline()
        return code
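
# For a rank-1 task with one tensor input and one output (no block pointers),
# codegen_1d_tile emits a kernel shaped roughly like:
#   @libentry()
#   @triton.jit
#   def add_kernel_rank_1(in0_ptr, out0_ptr, in0_stride0, out0_stride0,
#                         s0, num_tasks, tiles_per_cta, tile_size,
#                         one_tile_per_cta):
#       pid = tle.program_id(0)
#       if one_tile_per_cta:  # monolithic kernel style
#           ...
#       else:  # grid-stride-loop style kernel
#           ...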


class WrapperGenerator:
    def __init__(
        self,
        function_schema: FunctionSchema,
        jit_fn_name: str,
        ndim: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.jit_fn_name = jit_fn_name
        self.ndim = ndim
        self.name = name
        self.config = config

    def input_name(self, i):
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        return f"out{i}"

    def gen_signature(self, code: IndentedBuffer):
        # TODO: check if triton handles constexprs transitively
        schema = self.fx
        params: List[str] = []
        for i in range(schema.num_inputs()):
            if schema.is_tensor(i):
                params.append(
                    f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]"
                )
            else:
                arg_type = schema.input_type(i)
                if arg_type is not None:
                    params.append(f"{self.input_name(i)}: {_type_name(arg_type)}")
                else:
                    params.append(f"{self.input_name(i)}")
        # NOTE: [the wrapper's signature and rules for passing parameters]
        # Input parameters must be passed by position, since they are renamed
        # to in0, in1, val0, val1, ..., and passing them by keyword would be
        # weird. So we enforce that they be passed by position; maybe we can
        # fix this later.
        # Output parameters must be passed by keyword: the scalar function has
        # no output parameters (think of it as a scalar function; output
        # parameters make no sense there). They are added to allow a
        # destination-passing-style API, which is convenient when we want to
        # use pre-defined outputs (especially views of other tensors). Since
        # they are added in addition to the scalar function's parameters, we
        # enforce that they be passed by keyword. After all, out0, out1, ...
        # cannot clash with names from the scalar function, which has no
        # output parameters.
        params.append("/")
        params.append("*")  # output params must be passed by keyword

        for i in range(schema.num_output_tensors()):
            params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]")
        code.writeline(f"def {self.name}({_cs(params)}): ")
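
    # Emitted signature, e.g. for two tensor inputs and one output at rank 2:
    #   def add_wrapper_rank_2(
    #       in0: Union[torch.Tensor, StridedBuffer],
    #       in1: Union[torch.Tensor, StridedBuffer], /, *,
    #       out0: Union[torch.Tensor, StridedBuffer]):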

    def gen_docstring(self, code: IndentedBuffer):
        schema = self.fx
        doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""'
        code.writeline(doc)

    def gen_same_shape_check(self, code: IndentedBuffer):
        schema: FunctionSchema = self.fx
        params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [
            f"out{i}.shape" for i in range(schema.num_output_tensors())
        ]
        check: str = " == ".join(params)
        code.writeline(f"assert {check}, 'operand shapes mismatch'")

    def gen_fallback_bptr(self, code: IndentedBuffer):
        code.writeline("def fallback_bptr(t):")
        with code.indent():
            code.writeline("ndim = t.dim()")
            code.writeline("sizes = t.size()")
            code.writeline("if t.numel() == 0:")
            with code.indent():
                code.writeline("return False")
            code.writeline("for i in range(ndim):")
            with code.indent():
                # 2147483648 == 2**31, beyond tl.int32 block-pointer indexing
                code.writeline("if sizes[i] >= 2147483648:")
                with code.indent():
                    code.writeline("return True")
            code.writeline("return False")
        code.newline()
        code.newline()

    def gen_task_partition(self, code: IndentedBuffer):
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size
            code.writeline(
                f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
            )
            code.writeline("tile_size = math.prod(tile_sizes)")
            code.writeline(
                "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
            )
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0} // num_warps, num_tiles)")

            code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
            code.writeline("one_tile_per_cta = tiles_per_cta==1")
        if self.config.prefer_block_pointer:
            code.writeline("FALLBACK_BPTR = False")
            inputs = ",".join(
                [f"in{i}" for i in range(self.fx.num_input_tensors())]
            )
            outputs = ",".join(
                [f"out{i}" for i in range(self.fx.num_output_tensors())]
            )
            code.writeline(f"all_tensors = [{inputs}, {outputs}]")
            code.writeline("for t in all_tensors:")
            with code.indent():
                code.writeline("if fallback_bptr(t):")
                with code.indent():
                    code.writeline("FALLBACK_BPTR = True")
                    code.writeline("break")
        if ndim > 0 and ndim <= 4:
            max_grid_size0 = self.config.max_grid_size[0]
            dynamic_num_tiles = " * ".join(
                f"triton.cdiv(meta['s{i}'], meta['tile_size{i}'])" for i in range(ndim)
            )
            code.writeline(
                f"grid = lambda meta: (min({max_grid_size0} // num_warps, {dynamic_num_tiles}), )"
            )
        else:
            code.writeline("grid = (num_ctas, 1, 1)")
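
    # With 1 <= ndim <= 4 the grid is a lambda over the autotuner's meta dict,
    # so the launched CTA count tracks whichever tile sizes the chosen config
    # uses; otherwise a fixed (num_ctas, 1, 1) grid is emitted.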

    def gen_task_partition_1d(self, code: IndentedBuffer):
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size
            code.writeline(
                f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
            )
            code.writeline("tile_size = tile_sizes[0]")
            code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

            code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            code.writeline("one_tile_per_cta = tiles_per_cta==1")
        if self.config.prefer_block_pointer:
            code.writeline("FALLBACK_BPTR = False")
            inputs = ",".join(
                [f"in{i}" for i in range(self.fx.num_input_tensors())]
            )
            outputs = ",".join(
                [f"out{i}" for i in range(self.fx.num_output_tensors())]
            )
            code.writeline(f"all_tensors = [{inputs}, {outputs}]")
            code.writeline("for t in all_tensors:")
            with code.indent():
                code.writeline("if fallback_bptr(t):")
                with code.indent():
                    code.writeline("FALLBACK_BPTR = True")
                    code.writeline("break")
        code.writeline("grid = (num_ctas, 1, 1)")

    def gen_kernel_launch(
        self,
        code: IndentedBuffer,
    ):
        schema = self.fx
        ndim = self.ndim

        with_block_pointer = self.config.prefer_block_pointer

        code.writeline("# kernel launch")
        for i in range(schema.num_input_tensors()):
            code.writeline(f"in{i}_strides = in{i}.stride()")
            if not with_block_pointer:
                continue
            if ndim >= 2:  # when ndim is 1 we don't need to compute the stride order
                code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)")
            else:
                code.writeline(f"in{i}_stride_order = (0,)")
            code.writeline(
                f"in{i}_zero_strides = [True if s == 0 else False for s in in{i}_strides]"
            )
        for i in range(schema.num_output_tensors()):
            code.writeline(f"out{i}_strides = out{i}.stride()")
            if not with_block_pointer:
                continue
            if ndim >= 2:
                code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)")
            else:
                code.writeline(f"out{i}_stride_order = (0,)")
            code.writeline(
                f"out{i}_zero_strides = [True if s == 0 else False for s in out{i}_strides]"
            )

        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                params = []
                # NOTE: WRAP
                for i in range(schema.num_inputs()):
                    params.append(self.input_name(i))
                for i in range(schema.num_output_tensors()):
                    params.append(self.output_name(i))

                code.writeline(f"{_cs(params)},")

                if ndim > 0:
                    for i in range(schema.num_input_tensors()):
                        s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # strides for in{i}")
                        if with_block_pointer:
                            order = ", ".join(
                                f"in{i}_stride_order[{j}]" for j in range(ndim)
                            )
                            code.writeline(f"{order}, # stride order for in{i}")
                            zero_strides = ", ".join(
                                f"in{i}_zero_strides[{j}]" for j in range(ndim)
                            )
                            code.writeline(
                                f"{zero_strides}, # zero stride flag for in{i}"
                            )

                    for i in range(schema.num_output_tensors()):
                        s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # strides for out{i}")
                        if with_block_pointer:
                            order = ", ".join(
                                f"out{i}_stride_order[{j}]" for j in range(ndim)
                            )
                            code.writeline(f"{order}, # stride order for out{i}")
                            zero_strides = ", ".join(
                                f"out{i}_zero_strides[{j}]" for j in range(ndim)
                            )
                            code.writeline(
                                f"{zero_strides}, # zero stride flag for out{i}"
                            )

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                code.writeline("num_tasks, # num tasks")
                if self.config.prefer_block_pointer:
                    code.writeline("FALLBACK_BPTR=FALLBACK_BPTR,")
                if ndim > 4:
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                if ndim == 0 or ndim > 4:
                    for i in range(ndim):
                        code.writeline(f"tile_size{i}=tile_sizes[{i}],")
                if ndim > 4:
                    code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_kernel_launch_1d(
        self,
        code: IndentedBuffer,
    ):
        schema = self.fx
        ndim = self.ndim

        code.writeline("# kernel launch")
        for i in range(schema.num_input_tensors()):
            code.writeline(f"in{i}_strides = in{i}.stride()")
        for i in range(schema.num_output_tensors()):
            code.writeline(f"out{i}_strides = out{i}.stride()")

        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                params = []
                # NOTE: WRAP
                for i in range(schema.num_inputs()):
                    params.append(self.input_name(i))
                for i in range(schema.num_output_tensors()):
                    params.append(self.output_name(i))

                code.writeline(f"{_cs(params)},")

                if ndim > 0:
                    for i in range(schema.num_input_tensors()):
                        s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # strides for in{i}")
                    for i in range(schema.num_output_tensors()):
                        s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # strides for out{i}")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                code.writeline("num_tasks, # num tasks")
                if self.config.prefer_block_pointer:
                    code.writeline("FALLBACK_BPTR=FALLBACK_BPTR,")
                code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                code.writeline("tile_size=tile_size,")
                code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_return(self, code: IndentedBuffer):
        return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors()))
        code.writeline(f"return {return_exprs}")

    def codegen_nd_tile(self, code):
        if self.config.prefer_block_pointer:
            self.gen_fallback_bptr(code)
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition(code)
            self.gen_kernel_launch(code)
            self.gen_return(code)
        code.newline()
        return code

    def codegen_1d_tile(self, code):
        if self.config.prefer_block_pointer:
            self.gen_fallback_bptr(code)
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition_1d(code)
            self.gen_kernel_launch_1d(code)
            self.gen_return(code)
        code.newline()
        return code


class ModuleGenerator:
    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
        code.writeline("import math")
        code.writeline("from typing import Union")
        code.writeline("import torch")
        code.writeline("import triton")
        code.writeline("from triton import language as tl")
        code.newline()
        code.writeline("from flag_gems.utils.shape_utils import (")
        code.writeline("    heuristics_for_tile_size,")
        code.writeline("    heuristics_for_num_warps,")
        code.writeline("    stride_order,")
        code.writeline(")")
        code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer")
        code.writeline("from flag_gems.utils.libentry import libentry, libtuner")
        code.writeline("from flag_gems.utils import triton_lang_extension as tle")
        code.writeline("from flag_gems.runtime import torch_device_fn")
        code.newline()
        code.newline()
        return code

    def codegen(self, code: IndentedBuffer):
        # the only runtime-determined factor is the rank of the task space
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            code = self.wrapper_gen.codegen_1d_tile(code)
            code = self.kernel_gen.codegen_1d_tile(code)
        else:
            code = self.wrapper_gen.codegen_nd_tile(code)
            code = self.kernel_gen.codegen_nd_tile(code)
        return code


class PointwiseDynamicFunction:
    """Utility to generate functions for general pointwise operations. It generates
    a wrapper & a JITFunction which are specialized according to the rank of the task
    space (the broadcast shape of all input tensors). The generated code is written
    to the cache directory (defaults to ~/.flaggems).
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # inputs must be passed by position, outputs must be passed by keyword
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        # NOTE: the overload keeps the type of the outputs: if a pre-defined
        # output is a Tensor or a StridedBuffer, the corresponding output is a
        # Tensor or a StridedBuffer, respectively. Since prepare_args wraps all
        # the arguments, the outputs here are all StridedBuffers; but if a
        # manually instantiated overload is called directly, take care of that
        # manually.
        return self._unwrap(out)

    @staticmethod
    def use_fast_path(tensors):
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )
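
    # The fast path flattens the task to 1D: when all operands share a shape
    # and are all C-contiguous (or all share one dense, non-overlapping
    # layout), broadcasting and per-dimension strides are unnecessary.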

    def prepare_args(self, *args, **kwargs):
        # output allocation (when needed)
        # task simplification & task-rank inference & input-output reinterpretation
        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # a simple strategy: all the undefined tensors will follow the first
            # tensor that is not broadcasted; no attempt to simplify the task,
            # no reordering, no dimension collapsing
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operations
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise output arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # no break
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        # unwrap StridedBuffer to get the underlying Tensor
        if self.fx.num_output_tensors() == 1:
            item = tensors
            return item.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def instantiate(self, ndim):
        # NOTE: a manually instantiated overload does not have `prepare_args`
        # as preprocessing, so you have to allocate outputs manually and make
        # sure that the inputs & outputs actually fit the instantiated overload
        key = f"{ndim}_{self.config.prefer_block_pointer}"
        if key in self.overloads:
            return self.overloads[key]

        code = IndentedBuffer()

        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"
        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which
        # requires that the source can be found by inspect. inspect cannot find
        # the source of functions dynamically created via exec'ing a string. We
        # could help inspect find the source by hacking the linecache library,
        # but we find generating a module simpler: we generate two functions,
        # the kernel and the wrapper, and the wrapper calls the kernel.
        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )

        file_path = code_cache_dir() / file_name
        write_atomic(file_path, code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # do not expose it to sys.modules
        # sys.modules["_add_module"] = m

        # NOTE: [why not import the scalar function]
        # we do not re-import the scalar function, although the generated kernel
        # **calls** it. A function's __name__ may differ from the name it is
        # bound to in its defining module, and that name may later be rebound to
        # something else, so importing by name cannot guarantee that the scalar
        # function is imported. Instead we copy the scalar function and its
        # __globals__ into the generated module.
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload
        return overload
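
    # Manual use (illustrative): `fn.instantiate(2)` returns the rank-2
    # wrapper. Callers must then pre-allocate outputs and pass operands whose
    # shapes and strides already fit the rank-2 task space, since prepare_args
    # is skipped.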


def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[List[Tuple]] = None,
    config: Optional[CodeGenConfig] = None,
):
    def decorator(fn):
        nonlocal num_inputs
        if (num_inputs is None) and (is_tensor is None) and (dtypes is None):
            num_inputs = len(fn.arg_names)
        op_desc = FunctionSchema(
            num_inputs=num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(op_desc, fn, config)

    if f is not None:
        return decorator(f)
    return decorator
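

# Minimal usage sketch (illustrative; `add_fn` is a hypothetical scalar
# function, not part of this module):
#
#   @pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
#   @triton.jit
#   def add_fn(x, y):
#       return x + y
#
#   # add_fn(a, b) then broadcasts a and b, allocates an output with the
#   # promoted dtype, and dispatches a kernel generated for the task rank.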