Coverage for src/flag_gems/runtime/backend/_cambricon/ops/stack.py: 0%
104 statements
(coverage.py v7.6.9, report created at 2026-03-29 04:01 +0800)
1import importlib
2import logging
3import math
4import os
5import textwrap
6from typing import Callable, List, Mapping, Tuple, Union
8import torch
10from flag_gems.utils.code_cache import cache_dir
11from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
13from ..utils import TOTAL_CORE_NUM
14from .vstack import vstack
16logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
19def get_dtype_size(dtype):
20 try:
21 return torch.finfo(dtype).bits // 8
22 except TypeError:
23 try:
24 return torch.iinfo(dtype).bits // 8
25 except TypeError:
26 if dtype == torch.bool:
27 return 1
28 else:
29 raise ValueError(f"Unsupported dtype: {dtype}")
32class StackKernelCode(IndentedBuffer):
33 """
34 Stack kernel template.
35 """
37 overloads: Mapping[str, Callable] = {}
39 def __init__(self):
40 self.pid = os.getpid()
41 self.cache = self.overloads
42 self.kernel_name = "_stack_jit_kernel"
43 self.wrapper_func_name = "_wrapper"
44 super(StackKernelCode, self).__init__()
46 def __imports(self):
47 """Generate imports for the kernel code."""
48 tpl = """\
49 import math
50 import torch
51 import triton
52 from triton import language as tl
53 from typing import List, Tuple, Union
54 from flag_gems.utils import libentry
55 from flag_gems.runtime.backend import vendor_module
56 TOTAL_CORE_NUM = vendor_module.TOTAL_CORE_NUM
57 MAX_NRAM_SIZE = vendor_module.MAX_NRAM_SIZE
59 """
60 self.tpl(textwrap.dedent(tpl))
62 def __wrapper(self):
63 """Generate wrapper function for the kernel code."""
64 self.newline()
65 tpl = """\
66 def {wrapper_name}(
67 tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
68 ) -> torch.Tensor:
69 if len(tensors) == 0:
70 raise RuntimeError("stack expected a non-empty TensorList")
72 inp_shapes = [list(_.shape) for _ in tensors]
73 inp0_shape = inp_shapes[0]
74 for i, s in enumerate(inp_shapes[1:]):
75 if (dim < -tensors[i + 1].dim() - 1) or (dim > tensors[i + 1].dim()):
76 raise IndexError(
77 "Dimension out of range (expected to be in range of [{{}}, {{}}], but got {{}})".format(
78 -tensors[i + 1].dim() - 1, tensors[i + 1].dim(), dim
79 )
80 )
81 if s != inp0_shape:
82 raise RuntimeError(
83 f"stack expects each tensor to be equal size, \
84 but got {{inp0_shape}} at entry 0 and {{s}} at entry {{i+1}}"
85 )
87 if dim < 0:
88 dim = dim + len(inp0_shape) + 1
89 out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]
90 high = int(math.prod(out_shape[:dim]))
91 low = int(math.prod(out_shape[dim+1:]))
92 tensor_num = len(tensors)
93 out0 = torch.empty(out_shape, dtype=tensors[0].dtype, device=tensors[0].device)
94 def grid(meta):
95 if meta['BLOCK_SIZE']>0:
96 task_x = high
97 task_y = tensor_num
98 task_z = triton.cdiv(low ,meta['BLOCK_SIZE'])
99 return (task_x, task_y, task_z)
100 else:
101 total_task = high * tensor_num
102 if meta['LOW_NUM']>0:
103 core_used = triton.cdiv(total_task // meta['LOW_NUM'], meta['TASK_PER_CORE'])
104 elif meta['N_LOW_NUM']>0:
105 core_used = triton.cdiv(high, meta['TASK_PER_CORE'])
106 return (core_used,)
107 {kernel_name}[grid](
108 out0,
109 *tensors,
110 high,
111 tensor_num,
112 low,
113 )
114 return out0
115 """
116 self.tpl(
117 textwrap.dedent(tpl),
118 wrapper_name=self.wrapper_func_name,
119 kernel_name=self.kernel_name,
120 )
122 def __config(self, tensor_num, high, low, dtype):
123 """Generate config for the kernel code."""
124 dtyp_bytest = get_dtype_size(dtype)
125 # Since the kernel has three branches, each branch has its own parameters,
126 # so for a certain branch, the other parameters can be directly set to zero.
128 # 1)N_LOW_NUM branch: NRAM can hold at least one set of `tensor_num * low * dtyp_bytest`.
129 # This parameter is used to indicate how many `tensor_num * low` are processed in a single core.
131 # 2) LOW_NUM branch: NRAM can hold at least one set of `low * dtyp_bytest`,
132 # but cannot hold tensor_num * low * dtyp_bytest.
133 # This parameter is used to indicate how many `low` are processed in a single core.
135 # 3) BLOCK_SIZE branch: NRAM is not enough to store a set of `low`,
136 # so it can only loop multiple times to process a set of `low`.
137 # This parameter is used to indicates how many elements to load
138 # at a time when looping over and processing low.
139 tpl = """\
140 def cfggen():
141 N_LOW_NUM = {n_low_num_options}
142 LOW_NUM = {low_num_options}
143 BLOCK_SIZE = {block_size_options}
144 warps = [1]
145 num_stages = {num_stages}
146 configs = [
147 triton.Config(
148 {{
149 "BLOCK_SIZE": block_size,
150 "N_LOW_NUM": n_low_num,
151 "LOW_NUM": low_num,
152 }},
153 num_warps=w,
154 num_stages=s)
155 for block_size in BLOCK_SIZE
156 for n_low_num in N_LOW_NUM
157 for low_num in LOW_NUM
158 for w in warps for s in num_stages
159 ]
160 return configs
161 performance_related_keys = {keys}
162 """
164 # If `tensor_num * low * dtyp_bytest` is less than `nram_threshold`,
165 # use N_LOW_NUM branch, otherwise LOW_NUM branch.
166 nram_threshold = 170000
167 # The maximum number of elements in triton is 1048576.
168 max_elements_num = 1048576
169 # after removing the overhead of pipeline and temporary variables.
170 if tensor_num * low * dtyp_bytest <= nram_threshold:
171 n_low_per_core = math.ceil(high / TOTAL_CORE_NUM)
172 limited_by_nram = nram_threshold // dtyp_bytest // (tensor_num * low)
173 limited_by_triton = max_elements_num // (tensor_num * low)
174 best_opt = min(n_low_per_core, limited_by_triton, limited_by_nram)
175 self.tpl(
176 textwrap.dedent(tpl),
177 n_low_num_options=f"{[best_opt]}",
178 low_num_options=r"[0]",
179 block_size_options=r"[0]",
180 num_stages=r"[1]",
181 keys=r'["high"]',
182 )
183 elif low * dtyp_bytest <= nram_threshold:
184 self.tpl(
185 textwrap.dedent(tpl),
186 n_low_num_options=r"[0]",
187 low_num_options=r"[1,2,3]",
188 block_size_options=r"[0]",
189 num_stages=f"{[1]}",
190 keys=r'["high", "tensor_num", "low"]',
191 )
192 else:
193 self.tpl(
194 textwrap.dedent(tpl),
195 n_low_num_options=r"[0]",
196 low_num_options=r"[0]",
197 block_size_options=r"[8192, 16384, 32768, 65536, 131072, 262144]",
198 num_stages=r"[1]",
199 keys=r'["low"]',
200 )
202 def __kernel(self, tensor_num):
203 """Generate kernel code body."""
204 tpl = """\
205 def stack_heuristics(args, need_key):
206 ret = {{
207 'TASK_PER_CORE': 0,
208 'TASK_LAST_CORE_REPEAT': 0,
209 'TASK_LAST_CORE_REMAIN': 0,
210 }}
211 total_task = args['high']*args['tensor_num']
212 if args['LOW_NUM']>0:
213 LOW_NUM = args['LOW_NUM'] if total_task > args['LOW_NUM'] else total_task
214 ret['TASK_PER_CORE'] = triton.cdiv(total_task // LOW_NUM, TOTAL_CORE_NUM)
215 assert ret['TASK_PER_CORE']>0, ret['TASK_PER_CORE']
216 core_used = triton.cdiv(total_task // LOW_NUM, ret['TASK_PER_CORE'])
217 task_last_core = total_task-(core_used-1)*ret['TASK_PER_CORE']*LOW_NUM
218 ret['TASK_LAST_CORE_REPEAT'] = task_last_core//LOW_NUM
219 ret['TASK_LAST_CORE_REMAIN'] = task_last_core%LOW_NUM
220 elif args['N_LOW_NUM']>0:
221 ret['TASK_PER_CORE'] = triton.cdiv(args['high'], TOTAL_CORE_NUM)
222 core_used = triton.cdiv(args['high'], ret['TASK_PER_CORE'])
223 ret['TASK_LAST_CORE_REPEAT'] = args['high'] -(core_used-1)*ret['TASK_PER_CORE']
224 return ret[need_key]
226 @triton.jit()
227 def load_trans_store(
228 low: tl.constexpr,
229 tensor_num: tl.constexpr,
230 {tensors},
231 offset,
232 buffer,
233 buffer_offset,
234 output,
235 out_offset,
236 ):
237 if low >64:
238 {low_gt_64_code}
239 tl.store(output+offset*tensor_num+out_offset, buffer)
240 else:
241 {low_le_64_code}
242 tl.store(output+offset*tensor_num+out_offset, tl.trans(buffer, 1, 0, 2))
244 @triton.jit()
245 def load_and_store(
246 output_ptr,
247 buffer,
248 buffer_offset,
249 task_id,
250 LOW_NUM: tl.constexpr,
251 low: tl.constexpr,
252 LOW_OFFSET: tl.constexpr,
253 tensor_num: tl.constexpr,
254 {tensors}
255 ):
256 for low_idx in tl.range(LOW_NUM):
257 cur_low_id = task_id + low_idx
258 tensor_idx = cur_low_id%tensor_num
259 high_idx = cur_low_id//tensor_num
260 load_start = high_idx *low
261 {load_and_store_code}
262 tl.store(output_ptr+buffer_offset, buffer)
264 @libentry()
265 @triton.autotune(configs=cfggen(), key=performance_related_keys)
266 @triton.heuristics(
267 {{
268 "TASK_PER_CORE": lambda args: stack_heuristics(args, "TASK_PER_CORE"),
269 "TASK_LAST_CORE_REPEAT": lambda args: stack_heuristics(args, "TASK_LAST_CORE_REPEAT"),
270 "TASK_LAST_CORE_REMAIN": lambda args: stack_heuristics(args, "TASK_LAST_CORE_REMAIN"),
271 }}
272 )
273 @triton.jit()
274 def {kernel_name}(
275 output,
276 {tensors},
277 high: tl.constexpr,
278 tensor_num: tl.constexpr,
279 low: tl.constexpr,
280 N_LOW_NUM: tl.constexpr,
281 LOW_NUM: tl.constexpr,
282 TASK_PER_CORE: tl.constexpr,
283 TASK_LAST_CORE_REPEAT: tl.constexpr,
284 TASK_LAST_CORE_REMAIN: tl.constexpr,
285 BLOCK_SIZE: tl.constexpr):
286 if N_LOW_NUM>0:
287 # The memory space is sufficient to hold at least one set of "tensor_num* low * type_bytes"
288 core_idx = tl.program_id(0)
289 core_used = tl.num_programs(0)
290 if core_idx>=core_used:
291 return
292 in_offset = core_idx*TASK_PER_CORE*low
293 if low >64:
294 buffer_repeat = tl.empty(shape=[N_LOW_NUM, tensor_num, low], dtype=output.dtype.element_ty)
295 else:
296 buffer_repeat = tl.empty(shape=[tensor_num, N_LOW_NUM, low], dtype=output.dtype.element_ty)
297 buffer_repeat_offset = tl.arange(0, N_LOW_NUM)[:, None]*low+tl.arange(0, low)[None,:]
298 out_repeat_offset= \\
299 tl.arange(0, N_LOW_NUM)[:,None,None]*low*tensor_num+\\
300 tl.arange(0, tensor_num)[None,:,None]*low+\\
301 tl.arange(0, low)[None, None,:]
302 if core_idx !=core_used -1:
303 for repeat_idx in range(TASK_PER_CORE//N_LOW_NUM):
304 repeat_offset = in_offset + repeat_idx*N_LOW_NUM*low
305 load_trans_store(low, tensor_num, {tensors},repeat_offset, buffer_repeat,\\
306 buffer_repeat_offset,output,out_repeat_offset)
307 if (TASK_PER_CORE%N_LOW_NUM) > 0:
308 normal_remain_offset = in_offset + (TASK_PER_CORE//N_LOW_NUM)*N_LOW_NUM*low
309 if low >64:
310 buffer_normal_remain = tl.empty(shape=[TASK_PER_CORE%N_LOW_NUM,tensor_num, low], \\
311 dtype=output.dtype.element_ty)
312 else:
313 buffer_normal_remain = tl.empty(shape=[tensor_num,TASK_PER_CORE%N_LOW_NUM, low], \\
314 dtype=output.dtype.element_ty)
315 buffer_normal_remain_offset = tl.arange(0, TASK_PER_CORE%N_LOW_NUM)[:, None]*low + \\
316 tl.arange(0, low)[None,:]
317 out_normal_remain_offset= \\
318 tl.arange(0, TASK_PER_CORE%N_LOW_NUM)[:,None,None]*low*tensor_num+\\
319 tl.arange(0, tensor_num)[None,:,None]*low+\\
320 tl.arange(0, low)[None, None,:]
321 load_trans_store(low, tensor_num, {tensors},normal_remain_offset, buffer_normal_remain,\\
322 buffer_normal_remain_offset,output,out_normal_remain_offset)
323 else:
324 for repeat_idx in range(TASK_LAST_CORE_REPEAT//N_LOW_NUM):
325 repeat_offset = in_offset + repeat_idx*N_LOW_NUM*low
326 load_trans_store(low, tensor_num, {tensors},repeat_offset, buffer_repeat,\\
327 buffer_repeat_offset,output,out_repeat_offset)
328 if (TASK_LAST_CORE_REPEAT%N_LOW_NUM) >0 :
329 last_core_remain_offset = in_offset + (TASK_LAST_CORE_REPEAT//N_LOW_NUM)*N_LOW_NUM*low
330 if low >64:
331 buffer_last_core_remain = \\
332 tl.empty(shape=[TASK_LAST_CORE_REPEAT%N_LOW_NUM,tensor_num, low], \\
333 dtype=output.dtype.element_ty)
334 else:
335 buffer_last_core_remain = \\
336 tl.empty(shape=[tensor_num,TASK_LAST_CORE_REPEAT%N_LOW_NUM, low], \\
337 dtype=output.dtype.element_ty)
338 buffer_last_core_remain_offset = \\
339 tl.arange(0, TASK_LAST_CORE_REPEAT%N_LOW_NUM)[:, None]*low + \\
340 tl.arange(0, low)[None,:]
341 out_last_core_remain_offset= \\
342 tl.arange(0, TASK_LAST_CORE_REPEAT%N_LOW_NUM)[:,None,None]*low*tensor_num+\\
343 tl.arange(0, tensor_num)[None,:,None]*low+\\
344 tl.arange(0, low)[None, None,:]
345 load_trans_store(low, tensor_num, {tensors},last_core_remain_offset, \\
346 buffer_last_core_remain, buffer_last_core_remain_offset, \\
347 output,out_last_core_remain_offset)
348 elif LOW_NUM>0:
349 # The memory space is sufficient to hold at least one set of "low * type_bytes"
350 core_idx = tl.program_id(0)
351 core_used = tl.num_programs(0)
352 if core_idx>=core_used:
353 return
354 dtype = output.dtype.element_ty
355 buffer = tl.empty(shape=[LOW_NUM,low], dtype=dtype)
356 buffer_offset = tl.arange(0, LOW_NUM)[:,None]*low+tl.arange(0, low)[None,:]
357 LOW_OFFSET = tl.arange(0, low)
358 if core_idx != core_used-1:
359 for cycles_idx in range(TASK_PER_CORE):
360 task_id = core_idx*TASK_PER_CORE*LOW_NUM+cycles_idx*LOW_NUM
361 out_ptr = output + task_id*low
362 load_and_store(
363 out_ptr,
364 buffer,
365 buffer_offset,
366 task_id,
367 LOW_NUM,
368 low,
369 LOW_OFFSET,
370 tensor_num,
371 {tensors}
372 )
373 else:
374 base_task_id = core_idx*TASK_PER_CORE*LOW_NUM
375 for cycles_idx in range(TASK_LAST_CORE_REPEAT):
376 task_id= base_task_id+cycles_idx*LOW_NUM
377 out_ptr = output + task_id*low
378 load_and_store(
379 out_ptr,
380 buffer,
381 buffer_offset,
382 task_id,
383 LOW_NUM,
384 low,
385 LOW_OFFSET,
386 tensor_num,
387 {tensors}
388 )
389 task_id = base_task_id+TASK_LAST_CORE_REPEAT*LOW_NUM
390 output_ptr = output + task_id*low
391 for low_idx in tl.range(TASK_LAST_CORE_REMAIN):
392 cur_low_id = task_id + low_idx
393 tensor_idx = cur_low_id%tensor_num
394 high_idx = cur_low_id//tensor_num
395 load_start = high_idx *low
396 {low_num_gt_0_last_core_code}
397 tl.store(output_ptr+buffer_offset, buffer, mask=buffer_offset<TASK_LAST_CORE_REMAIN*low)
398 elif BLOCK_SIZE>0:
399 # Insufficient memory space to hold a set of "low* type_bytes"
400 high_idx = tl.program_id(0)
401 tensor_idx = tl.program_id(1)
402 output_ptr = output + high_idx*(low*tensor_num)+tensor_idx*low
403 offset_in_loop = tl.program_id(2)*BLOCK_SIZE+tl.arange(0, BLOCK_SIZE)
404 x = tl.empty(shape=[BLOCK_SIZE,],dtype=output.dtype.element_ty)
405 {block_size_gt_0_code}
406 tl.store(output_ptr+offset_in_loop, x, mask=offset_in_loop<low)
407 """
409 def add_indent(cleaned_str, indent_size):
410 return "\n".join(
411 [f"{' ' * indent_size}{line}" for line in cleaned_str.split("\n")]
412 )
414 tensors = ", ".join([f"in_{idx}" for idx in range(tensor_num)])
415 load_form_inputs = textwrap.dedent(
416 """\
417 if tensor_idx == 0:
418 buffer[low_idx,:] = tl.load(in_0+load_start+LOW_OFFSET)\n"""
419 + "\n".join(
420 [
421 f"""\
422 elif tensor_idx == {idx}:
423 buffer[low_idx,:] = tl.load(in_{idx}+load_start+LOW_OFFSET)"""
424 for idx in range(1, tensor_num - 1)
425 ]
426 )
427 + "\n"
428 + f"""\
429 else:
430 buffer[low_idx,:] = tl.load(in_{tensor_num - 1}+load_start+LOW_OFFSET)"""
431 )
432 self.tpl(
433 textwrap.dedent(tpl),
434 kernel_name=self.kernel_name,
435 tensors=tensors,
436 low_gt_64_code="\n".join(
437 [
438 f"{' ' * 8}buffer[:,{idx},:]=tl.load(in_{idx}+offset+buffer_offset)"
439 for idx in range(tensor_num)
440 ]
441 ),
442 low_le_64_code="\n".join(
443 [
444 f"{' ' * 8}buffer[{idx},:,:]=tl.load(in_{idx}+offset+buffer_offset)"
445 for idx in range(tensor_num)
446 ]
447 ),
448 load_and_store_code=add_indent(load_form_inputs, 8),
449 low_num_gt_0_last_core_code=add_indent(load_form_inputs, 16),
450 block_size_gt_0_code=add_indent(
451 textwrap.dedent(
452 """\
453 if tensor_idx == 0:
454 x = tl.load(in_0+high_idx *low+offset_in_loop,mask=offset_in_loop<low)\n"""
455 + "\n".join(
456 [
457 f"""\
458 elif tensor_idx == {idx}:
459 x = tl.load(in_{idx}+high_idx *low+offset_in_loop,mask=offset_in_loop<low)"""
460 for idx in range(1, tensor_num - 1)
461 ]
462 )
463 + "\n"
464 + f"""\
465 else:
466 x = tl.load(in_{tensor_num - 1}+high_idx *low+offset_in_loop,mask=offset_in_loop<low)"""
467 ),
468 8,
469 ),
470 )
472 def __gen_code(self, tensor_num, high, low, dtype):
473 """Entry point for code generation of stack."""
474 # generate imports.
475 self.__imports()
476 # generate config.
477 self.__config(tensor_num, high, low, dtype)
478 # generate kernel.
479 self.__kernel(tensor_num)
480 # generate wrapper function.
481 self.__wrapper()
483 def __call__(
484 self, tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
485 ) -> torch.Tensor:
486 assert dim != 0, "StackKernel template does not optimize `dim=0`."
487 tensor_num = len(tensors)
488 inp0_shape = list(tensors[0].shape)
489 out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]
490 high = int(math.prod(out_shape[:dim]))
491 low = int(math.prod(out_shape[dim + 1 :]))
492 dtype = tensors[0].dtype
493 self.kernel_name = f"{self.kernel_name}_num_{tensor_num}"
494 key = f"num_{tensor_num}_high_{high}_low_{low}_dtype_{dtype}"
495 for tensor in tensors[1:]:
496 assert tensor.dtype == dtype, f"{tensor.dtype} != {dtype}"
497 if key not in self.cache:
498 # generate code and cache.
499 self.__gen_code(tensor_num, high, low, dtype)
500 file_name = f"{cache_dir()}/stack_{key}_pid_{self.pid}.py"
501 write_atomic(file_name, self.getvalue())
502 # load
503 spec = importlib.util.spec_from_file_location(
504 f"_gen_module_{key}_pid_{self.pid}", file_name
505 )
506 m = importlib.util.module_from_spec(spec)
507 # do not expose it to sys.modules
508 # sys.modules["_add_module"] = m
509 spec.loader.exec_module(m)
510 overload = getattr(m, self.wrapper_func_name)
511 self.cache[key] = overload
512 overload = self.cache[key]
513 return overload(tensors, dim)
516def stack(
517 tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
518) -> torch.Tensor:
519 logger.debug("GEMS_CAMBRICON STACK")
521 if len(tensors) == 0:
522 raise RuntimeError("stack expected a non-empty TensorList")
524 inp_shapes = [list(_.shape) for _ in tensors]
525 inp0_shape = inp_shapes[0]
526 for i, s in enumerate(inp_shapes[1:]):
527 if (dim < -tensors[i + 1].dim() - 1) or (dim > tensors[i + 1].dim()):
528 raise IndexError(
529 "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
530 -tensors[i + 1].dim() - 1, tensors[i + 1].dim(), dim
531 )
532 )
533 if s != inp0_shape:
534 raise RuntimeError(
535 f"stack expects each tensor to be equal size, but got {inp0_shape} at entry 0 and {s} at entry {i + 1}"
536 )
538 if dim < 0:
539 dim = dim + len(inp0_shape) + 1
540 out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]
541 if dim == 0:
542 return vstack(tensors).view(out_shape)
543 tensors = [
544 tensor if tensor.is_contiguous() else tensor.contiguous() for tensor in tensors
545 ]
546 return StackKernelCode()(tensors, dim)