Coverage for src/flag_gems/runtime/backend/_cambricon/ops/repeat.py: 0%
303 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import importlib
2import logging
3import os
4from typing import Callable, List, Mapping
6import torch
7import triton
8import triton.language as tl
10from flag_gems.utils import libentry
11from flag_gems.utils.code_cache import code_cache_dir
12from flag_gems.utils.code_utils import IndentedBuffer
14from ..utils import TOTAL_CORE_NUM
# Child logger of the "flag_gems" root logger; lstrip(".") strips any leading
# dots from __name__ so the child name is always a valid dotted suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# --------------------------- repeat wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Generate the parameter list for the functional repeat wrapper.

    Returns:
        The comma-separated parameter declaration, i.e. ``"in0, sizes"``.
        (The previous docstring example mentioned ``val0``/``out0`` which
        this function never emits.)
    """
    return ", ".join(["in0", "sizes"])
def parameter_for_wrapper_out() -> str:
    """Generate the parameter list for the destination-passing wrapper.

    Returns:
        The comma-separated parameter declaration, i.e. ``"in0, out0"``.
        (The previous docstring example mentioned ``val0``/type annotations
        which this function never emits.)
    """
    return ", ".join(["in0", "out0"])
def parameter_ref_for_wrapper() -> str:
    """Generate the argument reference list used when calling the
    destination-passing wrapper.

    Returns:
        The comma-separated argument references, i.e. ``"in0, out0"``.
        (The previous docstring example mentioned ``val0``/``out0_offset``
        which this function never emits.)
    """
    return ", ".join(["in0", "out0"])
def output_ref_for_wrapper() -> str:
    """Return the name of the output variable produced by the
    destination-passing wrapper call."""
    return "out0"
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Emit the import preamble shared by every generated repeat module.

    Writes the same sequence of import/definition lines as before;
    ``None`` entries in the table stand for blank lines.
    """
    preamble = (
        "import math",
        "import torch",
        "import triton",
        "from triton import language as tl",
        None,
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils import libentry",
        "from flag_gems.runtime.backend import vendor_module",
        "MAX_GRID_SIZE_X = vendor_module.MAX_GRID_SIZE_X",
        "from flag_gems.utils.type_utils import type_promotion",
        "from flag_gems.utils import triton_lang_extension as tle",
        None,
        None,
    )
    for entry in preamble:
        if entry is None:
            code.newline()
        else:
            code.writeline(entry)
    return code
def generate_functional_repeat_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the functional entry point ``def {wrapper_name}(in0, sizes)``.

    The generated wrapper validates ``sizes`` against the input rank,
    left-pads the input shape with ones so both ranks match, computes the
    output shape as the elementwise product, allocates ``out0``, and calls
    the destination-passing wrapper -- unless any repeat factor is zero,
    in which case the (empty) ``out0`` is returned directly.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # gather ranks/shapes of the input tensor and the repeat factors
        code.writeline("in0_rank = in0.dim()")
        code.writeline("sizes_rank = len(sizes)")
        code.writeline("in0_shape = list(in0.shape)")
        code.writeline("sizes_shape = list(sizes)")
        code.newline()

        # torch.repeat semantics: len(sizes) must be >= tensor rank
        code.writeline(
            "assert(sizes_rank >= in0_rank), \
            'Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor'"
        )
        # left-pad the input shape with 1s so in0_shape aligns with sizes
        code.writeline("if (sizes_rank > in0_rank): ")
        with code.indent():
            code.writeline("diff = sizes_rank - in0_rank")
            code.writeline("ones = [1 for _ in range(diff)]")
            code.writeline("in0_shape = ones + in0_shape")
        code.newline()

        # out_shape[i] = in0_shape[i] * sizes[i]; a zero repeat factor
        # yields an empty output and skips the kernel launch
        code.writeline("is_empty = False")
        code.writeline("out_shape = []")
        code.writeline("for i in range(len(in0_shape)): ")
        with code.indent():
            code.writeline(
                "assert(sizes_shape[i] >= 0), 'the number of repetitions per dimension out of range (expected to >= 0) \
                but got {}'.format(sizes_shape[i])"
            )
            code.writeline("if sizes_shape[i] == 0: ")
            with code.indent():
                code.writeline("is_empty = True")
            code.writeline("out_shape.append(in0_shape[i] * sizes_shape[i])")
        code.newline()

        code.writeline(
            "out0 = torch.empty(out_shape, device=in0.device, dtype=in0.dtype)"
        )

        # view the input with the left-padded shape before launching
        code.writeline("in0 = in0.reshape(in0_shape)")
        code.writeline("if not is_empty: ")
        with code.indent():
            # call destination_passing_func
            output_names: str = output_ref_for_wrapper()
            call_str = (
                f"{output_names} = {destination_passing_func_name}"
                f"({parameter_ref_for_wrapper()})"
            )
            code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
        code.newline()
        code.newline()

    return code
def generate_destination_passing_repeat_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the destination-passing wrapper ``def {wrapper_name}(in0, out0)``.

    The generated wrapper computes the launch configuration (tile size,
    CTA count bounded by the device grid limit, tiles per CTA for the
    grid-stride-loop style), collects the strides/shapes of ``in0`` and
    ``out0``, and launches ``kernel_name``.  ``rank`` is the rank of the
    (already padded) task space; stride/shape arguments are only emitted
    when rank > 0.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # docstring
        if rank > 0:
            # the task space is the output tensor's index space
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = volume(shape)")

        if rank > 0:
            # cap the tile at 512 lanes; spread tiles over CTAs, bounded
            # by MAX_GRID_SIZE_X // num_warps
            code.writeline("tile_size = min(512, triton.next_power_of_2(num_tasks))")
            code.writeline("num_warps = 4")
            code.writeline(
                "num_ctas = min(MAX_GRID_SIZE_X//num_warps, triton.cdiv(num_tasks, tile_size))"
            )
            code.writeline(
                "tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas)"
            )
        else:
            # NOTE(review): for rank == 0 no tile_size/tiles_per_cta are
            # defined here, yet the generated kernel signature always
            # declares tile_size/one_tile_per_cta and the launch below
            # omits them -- the rank-0 path looks unreachable/broken;
            # confirm against callers before relying on it.
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        code.writeline("grid = (num_ctas, 1, 1)")
        code.newline()

        # input strides for each input tensor w.r.t. the task index space
        if rank > 0:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            code.writeline("in0_strides = in0.stride()")
            code.writeline("in0_shape = in0.shape")
            code.writeline("out0_strides = out0.stride()")
            code.newline()

        # grid
        code.writeline("# kernel launch")

        # launch kernel
        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    # per-dimension strides of input and output
                    s = ", ".join(f"in0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for in0")

                    s = ", ".join(f"out0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out0")

                    # task space (= output shape) plus the input extents
                    # used by the kernel's modulo wrap-around indexing
                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(rank))
                    code.writeline(f"{shape_args}, # task indexing space")
                    in_shape_args: str = ", ".join(f"in0_shape[{i}]" for i in range(rank))
                    code.writeline(
                        f"{in_shape_args}, # task indexing space used when input and ouput tensor has different shape"
                    )
                    code.writeline("num_tasks, # num tasks")
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                    code.writeline("tile_size=tile_size,")
                    code.writeline("one_tile_per_cta=tiles_per_cta==1,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

        # return
        code.writeline("return out0")
        code.newline()
        code.newline()
    return code
def _emit_repeat_tile_body(rank: int, code: "IndentedBuffer") -> None:
    """Emit the per-tile statements shared by both kernel styles.

    Writes, in order: the task mask, the multi-index reconstruction from
    the flat ``tid``, the wrap-around load from ``in0_ptr``, the (trivial)
    compute, and the store to ``out0_ptr``.  Expects ``tid`` to be defined
    by the caller-emitted code.
    """
    # only apply masking when rank > 0
    # since we only load a value instead of a block of values when the rank is 0
    code.writeline("mask = tid < num_tasks")
    code.newline()

    # reconstruct multi index from the flat task id (row-major order)
    code.writeline("# multi index recontruction")
    for i in reversed(range(rank)):
        if i > 0:
            code.writeline(f"i{i} = tid % s{i}")
            code.writeline(f"tid //= s{i}")
        else:
            code.writeline(f"i{i} = tid")
    code.newline()

    # loads: the modulo by the input extent in_s{j} maps an output index
    # back to its source input index (the repeat wrap-around).
    # Fix: reference in0_stride{j} directly instead of in{i}_stride{j},
    # which depended on the leaked loop variable `i` (undefined at rank 0).
    code.writeline("# loads")
    ptrs_expr = " + ".join(f"(i{j} % in_s{j}) * in0_stride{j}" for j in range(rank))
    # for rank == 0 there are no index terms; load the scalar directly
    # instead of emitting the invalid "in0_ptr + "
    ptrs_expr = f"in0_ptr + {ptrs_expr}" if ptrs_expr else "in0_ptr"
    code.writeline(f"in0 = tl.load({ptrs_expr}, mask=mask)")
    code.newline()

    # compute: repeat is a pure copy
    code.writeline("# compute")
    code.writeline("out0 = in0")
    code.newline()

    # stores
    code.writeline("# stores")
    ptrs_expr = " + ".join(f"i{j} * out0_stride{j}" for j in range(rank))
    ptrs_expr = f"out0_ptr + {ptrs_expr}" if ptrs_expr else "out0_ptr"
    code.writeline(f"tl.store({ptrs_expr}, out0, mask=mask)")


def generate_repeat_kernel(
    rank: int,
    kernel_name: str,
    code: "IndentedBuffer",
) -> "IndentedBuffer":
    """Emit the triton kernel source for repeat over a rank-``rank`` task space.

    The generated kernel supports two execution styles selected at launch
    by ``one_tile_per_cta``: a monolithic style (one tile per CTA) and a
    grid-stride-loop style (each CTA walks ``tiles_per_cta`` tiles).  The
    per-tile body is identical in both styles and is emitted by
    ``_emit_repeat_tile_body`` (previously duplicated inline).
    """
    # make the inlined function visible in the context
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # signature: input & output pointers
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        # signature: strides & task-space extents, only when rank > 0
        if rank > 0:
            # strides for inputs
            stride_args = ", ".join(f"in0_stride{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # strides for in0")

            # strides for outputs
            stride_args = ", ".join(f"out0_stride{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # strides for out0")

            # task space, used to reconstruct multi index
            task_space_args = ", ".join(f"s{i}: int" for i in range(rank))
            code.writeline(f"{task_space_args}, # task_space")

            task_space_args2 = ", ".join(f"in_s{i}: int" for i in range(rank))
            code.writeline(
                f"{task_space_args2}, # task_space2 used when input and output tensor has different shape"
            )

        # number of tasks, used to compute mask
        code.writeline("num_tasks: int,")

        # tile size & tiles_per_cta, gsl style
        if rank > 0:
            code.writeline("tiles_per_cta,")
        code.writeline("tile_size: tl.constexpr,")
        code.writeline("one_tile_per_cta: tl.constexpr,")
    code.writeline("):")

    with code.indent():
        # get pid & the flat task ids covered by this program
        code.writeline("# task id & masking")
        code.writeline("pid = tl.program_id(0)")
        code.writeline("num_ctas = tl.num_programs(0)")
        code.writeline("init_tid = pid * tile_size + tl.arange(0, tile_size)")

        # one-tile-per-cta, monolithic kernel style
        code.writeline("if one_tile_per_cta: # monolitic kernel style")
        with code.indent():
            code.writeline("tid = init_tid")
            _emit_repeat_tile_body(rank, code)

        # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
        code.writeline("else: # grid-stride-loop style kernel")
        with code.indent():
            code.writeline("for j in range(0, tiles_per_cta):")
            with code.indent():
                code.writeline("tid = init_tid + j * tile_size * num_ctas")
                _emit_repeat_tile_body(rank, code)
        code.newline()
    return code
def generate_code(
    rank: int,
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the full generated module for a given task-space rank.

    Pipes ``code`` through four stages: the import preamble, the
    functional wrapper, the destination-passing wrapper, and the triton
    kernel.  The rank of the task space is the only runtime-determined
    factor.
    """
    stages = (
        lambda buf: generate_imports(buf),
        lambda buf: generate_functional_repeat_wrapper(
            wrapper_name, destination_passing_func_name, buf
        ),
        lambda buf: generate_destination_passing_repeat_wrapper(
            rank, destination_passing_func_name, kernel_name, buf
        ),
        lambda buf: generate_repeat_kernel(rank, kernel_name, buf),
    )
    for stage in stages:
        code = stage(code)
    return code
class RepeatFunction:
    """Dispatcher that JIT-generates and caches a repeat implementation
    per task-space rank.

    For each rank, a dedicated Python module (functional wrapper,
    destination-passing wrapper, and triton kernel) is generated into the
    code cache directory, imported from file, and its ``_wrapper`` entry
    point cached in ``self.overloads``.
    """

    def __init__(self):
        # pid is baked into generated file/module names so concurrent
        # processes do not clobber each other's cache files
        self.pid = os.getpid()
        # instantiated & cached overloads, keyed by str(rank)
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, x, sizes):
        # note: kwargs should not be used in JITFunction directly
        ndim = self.arg_key(x, sizes)
        key = str(ndim)
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                ndim,
                "_wrapper",
                "_wrapper_out",
                "_repeat_flaggems_jit_function",
                code,
            )

            file_name = f"repeat_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load the freshly written module directly from its file path
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )
            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, "_wrapper")
            self.overloads[key] = overload
        return overload(x, sizes)

    def arg_key(self, x, sizes):
        # cache key: rank of the task space after left-padding the input
        # shape to match len(sizes)
        max_rank = max(x.ndim, len(sizes))
        return max_rank
# Module-level singleton so generated overloads are cached across calls.
_repeat_func = RepeatFunction()
@libentry()
@triton.autotune(
    configs=[
        # sweep the column tile width over powers of two (2**6 .. 2**16)
        triton.Config({"BLOCK_C": 2**n}, num_stages=s, num_warps=w)
        for n in range(6, 17, 2)
        for s in [1, 3]
        for w in [1, 4]
    ],
    key=["C"],
)
@triton.jit
def repeat_2d_kernel(
    inp_ptr,  # pointer to the (N, C) input
    out_ptr,  # pointer to the (N * repeat_N, C * repeat_C) output
    N,  # number of input rows
    C: tl.constexpr,  # number of input columns
    repeat_N: tl.constexpr,  # repeat factor along rows
    repeat_C: tl.constexpr,  # repeat factor along columns
    BLOCK_C: tl.constexpr,  # autotuned column tile width
):
    """Specialized 2-D repeat: out[n*N + i, c*C + j] = inp[i, j].

    Each program walks input rows with a stride of the program count and
    writes every (row, column) repetition of that row.  Assumes both
    tensors are contiguous (the caller passes ``inp.contiguous()`` and a
    freshly allocated output).
    """
    job_id = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    for batch_idx in range(job_id, N, num_jobs):
        if C <= BLOCK_C:
            # fast path: the whole row fits in one tile -- load it once,
            # broadcast to repeat_C copies, and store per row-repetition
            offset_c = tl.arange(0, C)
            inp_ptrs = inp_ptr + batch_idx * C + offset_c
            inp = tl.load(inp_ptrs).reshape(1, C)
            repeat_inp = inp.broadcast_to(repeat_C, C).reshape(repeat_C * C)
            out_offset_c = tl.arange(0, repeat_C * C)
            for n_idx in range(0, repeat_N):
                # output row index is n_idx * N + batch_idx; each output
                # row holds repeat_C * C elements
                out_ptrs = (
                    out_ptr
                    + N * n_idx * repeat_C * C
                    + batch_idx * repeat_C * C
                    + out_offset_c
                )
                tl.store(out_ptrs, repeat_inp)
        else:
            # slow path: tile the row in BLOCK_C chunks and store each
            # chunk once per (column repetition, row repetition) pair
            for off in range(0, C, BLOCK_C):
                offset_c = off + tl.arange(0, BLOCK_C)
                inp_ptrs = inp_ptr + batch_idx * C + offset_c
                inp_mask = offset_c < C
                inp = tl.load(inp_ptrs, mask=inp_mask, other=0)
                for c_idx in range(0, repeat_C):
                    for n_idx in range(0, repeat_N):
                        out_ptrs = (
                            out_ptr
                            + N * n_idx * repeat_C * C
                            + batch_idx * repeat_C * C
                            + c_idx * C
                            + offset_c
                        )
                        tl.store(out_ptrs, inp, mask=inp_mask)
def repeat(inp: torch.Tensor, sizes) -> torch.Tensor:
    """Repeat ``inp`` along each dimension by the factors in ``sizes``.

    A 2-D input with exactly two repeat factors takes the specialized
    ``repeat_2d_kernel`` fast path; every other shape is handled by the
    rank-generic JIT-generated implementation (``_repeat_func``).
    """
    logger.debug("GEMS_CAMBRICON REPEAT")

    if inp.dim() == 2 and len(sizes) == 2:
        n_rows, n_cols = inp.shape
        rep_rows, rep_cols = sizes[0], sizes[1]

        # output extent per dimension; a zero factor means an empty result
        empty = False
        out_dims = []
        for extent, factor in zip((n_rows, n_cols), (rep_rows, rep_cols)):
            assert factor >= 0
            if factor == 0:
                empty = True
            out_dims.append(extent * factor)

        result = torch.empty(out_dims, device=inp.device, dtype=inp.dtype)
        if empty:
            return result

        repeat_2d_kernel[(TOTAL_CORE_NUM,)](
            inp.contiguous(), result, n_rows, n_cols, rep_rows, rep_cols
        )
        return result

    return _repeat_func(inp, sizes)