Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/repeat.py: 0%
251 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
import importlib
import importlib.util
import logging
import os
from typing import Callable, Dict, List, Mapping

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_utils import IndentedBuffer
11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# --------------------------- repeat wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Generate the parameter list for the functional wrapper.

    Example: "in0, sizes"
    """
    # the functional wrapper takes the input tensor and the repeat counts
    parameters: List[str] = ["in0", "sizes"]
    return ", ".join(parameters)
def parameter_for_wrapper_out() -> str:
    """Generate the parameter list for the destination-passing wrapper.

    Example: "in0, out0"
    """
    # the destination-passing wrapper takes the input and a preallocated output
    parameters: List[str] = ["in0", "out0"]
    return ", ".join(parameters)
def parameter_ref_for_wrapper() -> str:
    """Generate the argument list used when calling the destination-passing wrapper.

    Example: "in0, out0"
    """
    parameters: List[str] = ["in0", "out0"]
    return ", ".join(parameters)
def output_ref_for_wrapper() -> str:
    """Name of the output reference returned by the destination-passing wrapper."""
    return "out0"
54def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
55 code.writeline("import math")
56 code.writeline("import torch")
57 code.writeline("import triton")
58 code.writeline("from triton import language as tl")
59 code.newline()
60 code.writeline("from flag_gems.runtime import torch_device_fn")
61 code.writeline("from flag_gems.utils.shape_utils import volume")
62 code.writeline("from flag_gems.utils.libentry import libentry")
63 code.writeline("from flag_gems.utils.type_utils import type_promotion")
64 code.writeline("from flag_gems.utils import triton_lang_extension as tle")
65 code.newline()
66 code.newline()
67 return code
def generate_functional_repeat_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the functional (allocating) wrapper for repeat into ``code``.

    The generated function validates ``sizes``, left-pads the input shape
    with ones when ``sizes`` has more dimensions than the input, allocates
    the output tensor, and delegates the copy to
    ``destination_passing_func_name`` — unless some repeat count is zero,
    in which case the empty output is returned without a kernel launch.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("in0_rank = in0.dim()")
        code.writeline("sizes_rank = len(sizes)")
        code.writeline("in0_shape = list(in0.shape)")
        code.writeline("sizes_shape = list(sizes)")
        code.newline()

        # same contract as torch.Tensor.repeat: len(sizes) >= input rank
        code.writeline(
            "assert(sizes_rank >= in0_rank), \
            'Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor'"
        )
        # left-pad the input shape with ones so both ranks match
        code.writeline("if (sizes_rank > in0_rank): ")
        with code.indent():
            code.writeline("diff = sizes_rank - in0_rank")
            code.writeline("ones = [1 for _ in range(diff)]")
            code.writeline("in0_shape = ones + in0_shape")
        code.newline()
        code.writeline("is_empty = False")
        code.writeline("out_shape = []")
        code.writeline("for i in range(len(in0_shape)): ")
        with code.indent():
            code.writeline(
                "assert(sizes_shape[i] >= 0), 'the number of repetitions per dimension out of range (expected to >= 0) \
            but got {}'.format(sizes_shape[i])"
            )
            # a zero repeat count makes the whole output empty: skip the launch
            code.writeline("if sizes_shape[i] == 0: ")
            with code.indent():
                code.writeline("is_empty = True")
            code.writeline("out_shape.append(in0_shape[i] * sizes_shape[i])")
        code.newline()
        code.writeline(
            "out0 = torch.empty(out_shape, device=in0.device, dtype=in0.dtype)"
        )

        # reshape so the (possibly padded) input rank matches the output rank
        code.writeline("in0 = in0.reshape(in0_shape)")
        code.writeline("if not is_empty: ")
        with code.indent():
            # call destination_passing_func
            output_names: str = output_ref_for_wrapper()
            call_str = (
                f"{output_names} = {destination_passing_func_name}"
                f"({parameter_ref_for_wrapper()})"
            )
            code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
    code.newline()
    code.newline()

    return code
def generate_destination_passing_repeat_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the destination-passing wrapper (launch logic) for repeat.

    The generated function derives the task space from ``out0``'s shape,
    picks a fixed launch configuration (12 CTAs, grid-stride tiling), and
    launches ``kernel_name``.

    NOTE(review): for ``rank == 0`` the generated code still references
    ``num_tasks``, ``tiles_per_cta`` and ``tile_size`` in the launch even
    though they are only assigned when rank > 0 — presumably the rank-0
    path is never exercised; confirm against callers.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # task space size = number of output elements
        if rank > 0:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = volume(shape)")

        if rank > 0:
            # fixed launch configuration: 12 CTAs, one warp each
            code.writeline("num_ctas = 12")
            code.writeline("num_warps = 1")
            code.writeline(
                "tile_size = triton.next_power_of_2(triton.cdiv(num_tasks, num_ctas))"
            )
            code.writeline(
                "tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas)"
            )
        else:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        code.writeline("grid = (num_ctas, 1, 1)")
        code.newline()

        # input strides for each input tensor w.r.t. the task index space
        if rank > 0:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            code.writeline("in0_strides = in0.stride()")
            code.writeline("in0_shape = in0.shape")
            code.writeline("out0_strides = out0.stride()")
        code.newline()

        # grid
        code.writeline("# kernel launch")

        # launch kernel on the device that owns the input
        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    # strides, task-space extents and input extents, flattened
                    s = ", ".join(f"in0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for in0")

                    s = ", ".join(f"out0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out0")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(rank))
                    code.writeline(f"{shape_args}, # task indexing space")
                    in_shape_args: str = ", ".join(
                        f"in0_shape[{i}]" for i in range(rank)
                    )
                    code.writeline(
                        f"{in_shape_args}, # task indexing space used when input and ouput tensor has different shape"
                    )
                code.writeline("num_tasks, # num tasks")
                code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                code.writeline("tile_size=tile_size,")
                code.writeline("one_tile_per_cta=tiles_per_cta==1,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

        # return
        code.writeline("return out0")
    code.newline()
    code.newline()
    return code
213def generate_repeat_kernel(
214 rank: int,
215 kernel_name: str,
216 code: IndentedBuffer,
217) -> IndentedBuffer:
218 # make the inlined function visible in the context
219 code.newline()
221 # the decorators
222 code.writeline("@libentry()")
223 code.writeline("@triton.jit")
225 # signature
226 code.writeline(f"def {kernel_name}(")
227 with code.indent():
228 # signature: inputs ptrs & non tensor inputs
229 code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
231 # signature: output ptrs
232 code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")
234 # signature: strides, for each tensor arguments
235 # only add this arguments when rank > 0
236 if rank > 0:
237 # strides for inputs
238 stride_args = ", ".join(f"in0_stride{j}: tl.constexpr" for j in range(rank))
239 code.writeline(f"{stride_args}, # strides for in0")
241 # strides for outputs
242 stride_args = ", ".join(
243 f"out0_stride{j}: tl.constexpr" for j in range(rank)
244 )
245 code.writeline(f"{stride_args}, # strides for out0")
247 # task space, used to reconstruct multi index
248 task_space_args = ", ".join(f"s{i}: tl.constexpr" for i in range(rank))
249 code.writeline(f"{task_space_args}, # task_space")
251 task_space_args2 = ", ".join(f"in_s{i}: tl.constexpr" for i in range(rank))
252 code.writeline(
253 f"{task_space_args2}, # task_space2 used when input and output tensor has different shape"
254 )
256 # number of tasks, used to compute mask
257 code.writeline("num_tasks: int,")
259 # tile size & tiles_per_cta, gsl style
260 if rank > 0:
261 code.writeline("tiles_per_cta,")
263 code.writeline("tile_size: tl.constexpr,")
265 code.writeline("one_tile_per_cta: tl.constexpr,")
266 code.writeline("):")
268 with code.indent():
269 # get pid
270 code.writeline("# task id & masking")
271 pid_stmt = "pid = tle.program_id(0)"
272 code.writeline(pid_stmt)
274 code.writeline("num_ctas = tle.num_programs(0)")
276 # get tid (a.k.a task id)
277 tid_stmt = "init_tid = pid * tile_size + tl.arange(0, tile_size)"
278 code.writeline(tid_stmt)
280 # one-tile-per-cta, monolithic kernel style
281 code.writeline("if one_tile_per_cta: # monolitic kernel style")
282 with code.indent():
283 tid_stmt = "tid = init_tid"
284 code.writeline(tid_stmt)
286 # only apply masking when rank > 0
287 # since we only load a value instead of a block of values when the rank is 0
288 mask_stmt: str = "mask = tid < num_tasks"
289 code.writeline(mask_stmt)
290 code.newline()
292 # reconstruct multi index
293 code.writeline("# multi index recontruction")
294 for i in reversed(range(rank)):
295 if i > 0:
296 code.writeline(f"i{i} = tid % s{i}")
297 code.writeline(f"tid //= s{i}")
298 else:
299 code.writeline(f"i{i} = tid")
300 code.newline()
302 # loads
303 code.writeline("# loads")
304 ptrs_expr: str = " + ".join(
305 f"(i{j} % in_s{j}) * in{i}_stride{j}" for j in range(rank)
306 )
307 ptrs_expr: str = f"in0_ptr + {ptrs_expr}"
308 load_stmt: str = f"in0 = tl.load({ptrs_expr}, mask=mask)"
309 code.writeline(load_stmt)
310 code.newline()
312 # compute
313 code.writeline("# compute")
314 code.writeline("out0 = in0")
315 code.newline()
317 # stores
318 code.writeline("# stores")
319 ptrs_expr: str = " + ".join(f"i{j} * out0_stride{j}" for j in range(rank))
320 ptrs_expr: str = f"out0_ptr + {ptrs_expr}"
321 store_stmt: str = f"tl.store({ptrs_expr}, out0, mask=mask)"
322 code.writeline(store_stmt)
324 # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
325 code.writeline("else: # grid-stride-loop style kernel")
326 with code.indent():
327 code.writeline("for j in range(0, tiles_per_cta):")
328 with code.indent():
329 tid_stmt = "tid = init_tid + j * tile_size * num_ctas"
330 code.writeline(tid_stmt)
332 # only apply masking when rank > 0
333 # since we only load a value instead of a block of values when the rank is 0
334 mask_stmt: str = "mask = tid < num_tasks"
335 code.writeline(mask_stmt)
336 code.newline()
338 # reconstruct multi index
339 code.writeline("# multi index recontruction")
340 for i in reversed(range(rank)):
341 if i > 0:
342 code.writeline(f"i{i} = tid % s{i}")
343 code.writeline(f"tid //= s{i}")
344 else:
345 code.writeline(f"i{i} = tid")
346 code.newline()
348 # loads
349 code.writeline("# loads")
350 ptrs_expr: str = " + ".join(
351 f"(i{j} % in_s{j}) * in{i}_stride{j}" for j in range(rank)
352 )
353 ptrs_expr: str = f"in0_ptr + {ptrs_expr}"
354 load_stmt: str = f"in0 = tl.load({ptrs_expr}, mask=mask)"
355 code.writeline(load_stmt)
356 code.newline()
358 # compute
359 code.writeline("# compute")
360 code.writeline("out0 = in0")
361 code.newline()
363 # stores
364 code.writeline("# stores")
365 ptrs_expr: str = " + ".join(
366 f"i{j} * out0_stride{j}" for j in range(rank)
367 )
368 ptrs_expr: str = f"out0_ptr + {ptrs_expr}"
369 store_stmt: str = f"tl.store({ptrs_expr}, out0, mask=mask)"
370 code.writeline(store_stmt)
371 code.newline()
372 return code
def generate_code(
    rank: int,
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble a complete generated repeat module into ``code``.

    Writes, in order: the import header, the functional wrapper, the
    destination-passing wrapper, and the Triton kernel.  The rank of the
    task space is the only runtime-determined factor.
    """
    generate_imports(code)
    generate_functional_repeat_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    generate_destination_passing_repeat_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    generate_repeat_kernel(rank, kernel_name, code)
    return code
class RepeatFunction:
    """JIT factory for rank-specialized ``repeat`` overloads.

    For each task-space rank, a Python module (functional wrapper,
    destination-passing wrapper and Triton kernel) is generated into the
    code cache, imported, and the resulting wrapper memoized so codegen
    runs once per rank per process.
    """

    def __init__(self):
        # pid is baked into generated file/module names so concurrent
        # processes do not clobber each other's cache entries
        self.pid = os.getpid()
        # instantiated & cached overloads, keyed by str(rank)
        # (annotated as a mutable Dict — not the read-only Mapping
        # protocol — since it is mutated in __call__)
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, x, sizes):
        """Dispatch to (and lazily build) the overload for this rank."""
        # note: kwargs should not be used in JITFunction directly
        ndim = self.arg_key(x, sizes)
        key = str(ndim)
        overload = self.overloads.get(key)
        if overload is None:
            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                ndim,
                "_wrapper",
                "_wrapper_out",
                "_repeat_flaggems_jit_function",
                code,
            )

            file_name = f"repeat_rank_{key}_pid_{self.pid}.py"
            file_path = cache_dir() / file_name
            with open(file_path, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load the generated module (importlib.util is imported
            # explicitly at the top of the file; a bare `import importlib`
            # does not guarantee the submodule is available)
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                file_path,
            )
            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            spec.loader.exec_module(m)
            overload = m._wrapper
            self.overloads[key] = overload
        return overload(x, sizes)

    def arg_key(self, x, sizes):
        """Rank of the task space: max of the input rank and ``len(sizes)``."""
        return max(x.ndim, len(sizes))
441_repeat_func = RepeatFunction()
def repeat(inp: torch.Tensor, sizes) -> torch.Tensor:
    """Repeat ``inp`` along each dimension per the counts in ``sizes``.

    Dispatches to a rank-specialized JIT-generated implementation; the
    generated wrapper asserts ``len(sizes) >= inp.dim()``.
    """
    logger.debug("GEMS REPEAT")
    return _repeat_func(inp, sizes)