Coverage for src/flag_gems/runtime/backend/_cambricon/ops/flip.py: 0%
275 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import importlib
2import logging
3import os
4from typing import Callable, Mapping
6import torch
8from flag_gems.utils.code_cache import cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer
11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
class FlipKernelCode(IndentedBuffer):
    """
    Flip kernel template.

    Builds, caches, and invokes a JIT-generated Triton kernel that implements
    ``torch.flip`` for the Cambricon backend.  Calling an instance with a
    (contiguous) tensor and a list of dims:

      1. merges adjacent dimensions that share the same flip/no-flip flag
         (``__init``),
      2. generates Python source for a wrapper + Triton kernel specialized on
         the merged flip pattern (``__gen_code``),
      3. writes the source to the flag_gems code cache, imports it, and caches
         the wrapper callable keyed by the merged pattern (``__call__``).

    NOTE(review): ``overloads`` is a class-level dict shared by all instances;
    ``self.cache = self.overloads`` aliases it so generated wrappers persist
    across the fresh instance created per ``flip`` call.
    """

    # Process-wide cache: "{merge_dim_size}_{flip_flags}" -> generated wrapper.
    overloads: Mapping[str, Callable] = {}

    def __init__(self):
        # pid is baked into generated file/module names to avoid collisions
        # between processes sharing the same cache directory.
        self.pid = os.getpid()
        self.cache = self.overloads
        self.kernel_name = "_flip_jit_kernel"
        self.wrapper_func_name = "_wrapper"
        super(FlipKernelCode, self).__init__()

    def __init(self, x, dims):
        """Initialize the flip kernel.

        Validates ``dims`` and merges adjacent dimensions of ``x`` that have
        the same flip flag (or extent 1) into single logical dimensions.
        Produces:
          - self.merge_shapes / self.merge_strides: merged view of ``x``
          - self.merge_flip_dims_flags: per-merged-dim flip flag
          - self.merge_flip_dim: number of merged dims actually flipped
          - self.merge_dim_size: number of merged dims

        NOTE(review): this name-mangled helper (``_FlipKernelCode__init``) is
        distinct from ``__init__``; it is called once per ``__call__``.
        """
        dim_size = x.dim()

        flip_dims = list(dims)
        flip_dims_flags = [False for _ in x.stride()]
        for i in range(len(flip_dims)):
            dim = flip_dims[i]
            assert (
                dim >= -dim_size and dim < dim_size
            ), "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -dim_size, dim_size - 1, dim
            )
            if dim < 0:
                flip_dims[i] = dim_size + dim
            # A negative `dim` here still indexes the intended flag thanks to
            # Python's negative indexing (dim == dim_size + dim position).
            assert not flip_dims_flags[
                dim
            ], "dim {} appears multiple times in the list of dims".format(dim)
            flip_dims_flags[dim] = True

        # merge shapes and flip_dims_flags by flip flags: consecutive dims with
        # the same flag (or size-1 dims) collapse into one merged dim.
        self.merge_shapes = []
        self.merge_strides = []
        flag = flip_dims_flags[0]
        self.merge_flip_dims_flags = []
        self.merge_flip_dim = 0
        shape = 1
        for i in range(dim_size):
            if (flag == flip_dims_flags[i]) or x.shape[i] == 1:
                shape *= x.shape[i]
            else:
                # Flag changed: close out the previous merged dim.
                # NOTE(review): stride bookkeeping assumes `x` is contiguous
                # (the public `flip` densifies before calling) — confirm.
                self.merge_shapes.append(shape)
                self.merge_strides.append(x.stride(i - 1))
                self.merge_flip_dims_flags.append(flag)
                if flag:
                    self.merge_flip_dim += 1
                flag = flip_dims_flags[i]
                shape = x.shape[i]
        # Close out the innermost merged dim; innermost stride is 1.
        self.merge_shapes.append(shape)
        self.merge_strides.append(1)
        self.merge_flip_dims_flags.append(flag)
        if flag:
            self.merge_flip_dim += 1

        self.merge_dim_size = len(self.merge_shapes)

    def __imports(self):
        """Generate imports for the kernel code."""
        self.tpl(
            """
import math
import torch
import triton
from triton import language as tl

from flag_gems.utils import libentry
from flag_gems.runtime.backend import vendor_module

TOTAL_CORE_NUM = vendor_module.utils.TOTAL_CORE_NUM
MAX_NRAM_SIZE = vendor_module.utils.MAX_NRAM_SIZE
            """
        )

    def __wrapper(self):
        """Generate wrapper function for the kernel code.

        The emitted wrapper partitions work into `high_task` (product of all
        merged dims above the low 2-D tile) for grid dim 0 and a divisor `y`
        of the sub dimension for grid dim 1, then launches the JIT kernel.
        """
        self.newline()
        self.tpl(
            """
def {wrapper_name}(x, merge_shapes, merge_strides, merge_dim_size):
    if merge_dim_size == 0 or x.numel() <= 1:
        return x.clone()

    low_task = merge_shapes[merge_dim_size - 1]
    sub_dim = 1

    high_task = 1
    if merge_dim_size > 1:
        sub_dim = merge_shapes[merge_dim_size - 2]
        low_task *= sub_dim
    for i in range(merge_dim_size - 2):
        high_task *= merge_shapes[i]
    y = 1
    if high_task < TOTAL_CORE_NUM:
        for i in range(1, sub_dim + 1):
            if sub_dim % i == 0:
                y = i
                if y * high_task >= TOTAL_CORE_NUM:
                    break

    grid = lambda meta: (min(high_task, TOTAL_CORE_NUM), y, )

    # in case of one-dim.
    if (high_task == 1) and (y == 1) and (merge_dim_size == 1):
        if low_task <= 1024:
            grid = lambda meta: (1, 1, )
        else:
            grid = lambda meta: (1, TOTAL_CORE_NUM, )

    out = torch.empty_like(x)
    with torch.cuda.device(x.device):
        {kernel_name}[grid]({args})
    return out
            """,
            wrapper_name=self.wrapper_func_name,
            kernel_name=self.kernel_name,
            args=self.__kernel_args(is_declare=False),
        )

    def __config(self):
        """Generate config for the kernel code.

        Emits the heuristics helpers (tile sizes H_DIM / W_DIM) plus the
        @libentry/@triton.autotune/@triton.heuristics/@triton.jit decorator
        stack that precedes the kernel definition.
        """
        # generate config key.
        merge_shapes_args_str = ", ".join(
            [f"'merge_shape_{i}'" for i in range(self.merge_dim_size)]
        )
        merge_strides_args_str = ", ".join(
            [f"'merge_stride_{i}'" for i in range(self.merge_dim_size)]
        )

        self.newline()
        # NOTE(review): the generated helpers hardcode max_nram_size = 3072
        # even though the generated imports expose MAX_NRAM_SIZE — confirm
        # whether this is an intentional tile cap or a leftover constant.
        self.tpl(
            """
def get_h_dim(args):
    merge_dim_size = args['merge_dim_size'];
    high = 0
    if merge_dim_size > 1:
        high = args['merge_shape_{merge_dim_size_2}']
    width = args['merge_shape_{merge_dim_size_1}']
    max_nram_size = 3072
    if max_nram_size >= width:
        tmp_h = max_nram_size // width
        if tmp_h < high:
            return tmp_h
        return high
    return 0

def get_w_dim(args):
    merge_dim_size = args['merge_dim_size'];
    width = args['merge_shape_{merge_dim_size_1}']
    max_nram_size = 3072
    if max_nram_size >= width:
        return width
    return max_nram_size

@libentry()
@triton.autotune(
    configs=[
        triton.Config({{}}, num_stages=3, num_warps=1),
    ],
    key = [{config_keys}],
)
@triton.heuristics(
    values={{
        "H_DIM": get_h_dim,
        "W_DIM": get_w_dim,
    }},
)
@triton.jit
            """,
            merge_dim_size_2=str(self.merge_dim_size - 2),
            merge_dim_size_1=str(self.merge_dim_size - 1),
            config_keys=f"'x_ptr', {merge_shapes_args_str}, {merge_strides_args_str}",
        )

    def __kernel_flip_2d(self):
        """Generate kernel for 2d buffer flip.

        Emits the low-2-D body for either pattern:
          - [flip, no-flip]: rows are reversed (tl.flip axis 0),
          - [no-flip, flip]: elements within each row are reversed.
        When H_DIM != 0 the whole 2-D tile fits NRAM and is flipped per tile;
        otherwise rows are processed W_DIM elements at a time.
        """
        self.writeline(f"step = merge_shape_{self.merge_dim_size - 2} // num_y")
        self.writeline(
            f"src_offset += pid_y * step * merge_shape_{self.merge_dim_size - 1}"
        )
        if self.merge_flip_dims_flags[self.merge_dim_size - 2]:
            # [flip, no-flip]
            self.writeline("# flip low-2d [flip, no-flip]")
            self.writeline(
                f"dst_offset += (num_y - pid_y - 1) * step * merge_shape_{self.merge_dim_size - 1}"
            )
            self.writeline("if H_DIM != 0:")
            with self.indent():
                self.writeline(
                    "offset = tl.arange(0, H_DIM)[:,None]*W_DIM + tl.arange(0, W_DIM)[None,:]"
                )
                self.writeline("tail = step % H_DIM")
                self.writeline("iter = step // H_DIM")
                self.writeline("for i in range(0, iter):")
                with self.indent():
                    self.writeline("in_offset = src_offset + i * H_DIM*W_DIM")
                    self.writeline(
                        "out_offset = dst_offset + tail * W_DIM + (iter - i - 1) * H_DIM*W_DIM"
                    )
                    self.writeline(
                        "src = tl.load(x_ptr + offset + in_offset, cache_modifier='.cg')"
                    )
                    self.writeline("src = tl.flip(src, [0])")
                    self.writeline(
                        "tl.store(out_ptr + offset + out_offset, src, cache_modifier='.cg')"
                    )
                self.writeline("if tail > 0:")
                with self.indent():
                    self.writeline("# process tail.")
                    self.writeline("in_offset = src_offset + iter * H_DIM*W_DIM")
                    self.writeline("out_offset = dst_offset - (H_DIM-tail)*W_DIM")
                    self.writeline("mask = offset < tail*W_DIM")
                    self.writeline(
                        "src = tl.load(x_ptr + offset + in_offset, mask=mask, other=0.0, cache_modifier='.cg')"
                    )
                    self.writeline("src = tl.flip(src, [0])")
                    self.writeline("mask = offset >= (H_DIM - tail) * W_DIM")
                    self.writeline(
                        "tl.store(out_ptr + offset + out_offset, src, mask=mask, cache_modifier='.cg')"
                    )
            self.writeline("else:")
            with self.indent():
                # Rows do not need internal reversal: copy rows in reverse order.
                self.writeline("offset = tl.arange(0, W_DIM)")
                self.writeline(f"iter = merge_shape_{self.merge_dim_size - 1} // W_DIM")
                self.writeline(f"tail = merge_shape_{self.merge_dim_size - 1} % W_DIM")
                self.writeline("src = tl.zeros((W_DIM,), dtype=x_ptr.dtype.element_ty)")
                self.writeline("for i in range(0, step):")
                with self.indent():
                    self.writeline(
                        f"in_offset = src_offset + i * merge_shape_{self.merge_dim_size - 1}"
                    )
                    self.writeline(
                        f"out_offset = dst_offset + (step - i - 1) * merge_shape_{self.merge_dim_size - 1}"
                    )
                    self.writeline("for j in range(0, iter):")
                    with self.indent():
                        self.writeline("new_offset = offset + j*W_DIM")
                        self.writeline(
                            "src = tl.load(x_ptr + in_offset + new_offset, cache_modifier='.cg')"
                        )
                        self.writeline(
                            "tl.store(out_ptr + out_offset + new_offset, src, cache_modifier='.cg')"
                        )
                    self.writeline("if tail > 0:")
                    with self.indent():
                        self.writeline("new_offset = offset + iter*W_DIM")
                        self.writeline("mask = offset < tail")
                        self.writeline(
                            "src = tl.load(x_ptr + in_offset + new_offset, mask=mask, cache_modifier='.cg')"
                        )
                        self.writeline(
                            "tl.store(out_ptr + out_offset + new_offset, src, mask=mask, cache_modifier='.cg')"
                        )
        else:
            # [no-flip, flip]
            self.writeline("# flip low-2d [no-flip, flip]")
            self.writeline(
                f"dst_offset += pid_y * step * merge_shape_{self.merge_dim_size - 1}"
            )
            self.writeline("if H_DIM != 0:")
            with self.indent():
                self.writeline(
                    "offset = tl.arange(0, H_DIM)[:,None]*W_DIM + tl.arange(0, W_DIM)[None,:]"
                )
                self.writeline("tail = step % H_DIM")
                self.writeline("iter = step // H_DIM")
                self.writeline("for i in range(0, iter):")
                with self.indent():
                    self.writeline("in_offset = src_offset + i * H_DIM*W_DIM")
                    self.writeline("out_offset = dst_offset + i * H_DIM*W_DIM")
                    self.writeline(
                        "src = tl.load(x_ptr + offset + in_offset, cache_modifier='.cg')"
                    )
                    self.writeline("src = tl.flip(src, [1])")
                    self.writeline(
                        "tl.store(out_ptr + offset + out_offset, src, cache_modifier='.cg')"
                    )
                self.writeline("if tail > 0:")
                with self.indent():
                    self.writeline("# process tail.")
                    self.writeline("in_offset = src_offset + iter * H_DIM*W_DIM")
                    self.writeline("out_offset = dst_offset + iter * H_DIM*W_DIM")
                    self.writeline("mask = offset < tail*W_DIM")
                    self.writeline(
                        "src = tl.load(x_ptr + offset + in_offset, mask=mask, other=0.0, cache_modifier='.cg')"
                    )
                    self.writeline("src = tl.flip(src, [1])")
                    self.writeline(
                        "tl.store(out_ptr + offset + out_offset, src, mask=mask, cache_modifier='.cg')"
                    )
            self.writeline("else:")
            with self.indent():
                # Row longer than a tile: reverse it chunk-by-chunk.
                self.writeline("offset = tl.arange(0, W_DIM)")
                self.writeline("src = tl.zeros((W_DIM,), dtype=x_ptr.dtype.element_ty)")
                self.writeline(f"tail = merge_shape_{self.merge_dim_size - 1} % W_DIM")
                self.writeline(f"iter = merge_shape_{self.merge_dim_size - 1} // W_DIM")
                self.writeline("for i in range(0, step):")
                with self.indent():
                    self.writeline(
                        f"in_offset = src_offset + i * merge_shape_{self.merge_dim_size - 1}"
                    )
                    self.writeline(
                        f"out_offset = dst_offset + i * merge_shape_{self.merge_dim_size - 1}"
                    )
                    self.writeline("if tail > 0:")
                    with self.indent():
                        self.writeline("new_offset = in_offset + iter * W_DIM")
                        self.writeline("mask = offset < tail")
                        self.writeline(
                            "src = tl.load(x_ptr + new_offset + offset, mask=mask, cache_modifier='.cg')"
                        )
                        self.writeline("src = tl.flip(src, [0])")
                        self.writeline("mask = offset >= (W_DIM-tail)")
                        self.writeline(
                            "tl.store(out_ptr + out_offset - (W_DIM - tail) + offset, \
                            src, mask=mask, cache_modifier='.cg')"
                        )
                    self.writeline("for j in range(0, iter):")
                    with self.indent():
                        self.writeline("new_in_offset = in_offset + j * W_DIM")
                        self.writeline(
                            "new_out_offset = tail + out_offset + (iter - j - 1) * W_DIM"
                        )
                        self.writeline(
                            "src = tl.load(x_ptr + new_in_offset + offset, cache_modifier='.cg')"
                        )
                        self.writeline("src = tl.flip(src, [0])")
                        self.writeline(
                            "tl.store(out_ptr + new_out_offset + offset, src, cache_modifier='.cg')"
                        )

    def __kernel(self):
        """Generate kernel code body.

        Emits the decorated kernel definition.  Grid dim 0 iterates the high
        (non-low-2D) merged dims; per high index it decomposes the linear id
        into per-dim indices to build src/dst offsets, then delegates to the
        2-D flip body.  A pure 1-D flip gets its own specialized body.
        """
        # configuration.
        self.__config()
        kernel_signature = f"def {self.kernel_name}({self.__kernel_args()}):"
        self.writeline(kernel_signature)
        with self.indent():
            self.writeline("pid_x = tl.program_id(0)")
            self.writeline("num_x = tl.num_programs(0)")
            self.writeline("pid_y = tl.program_id(1)")
            self.writeline("num_y = tl.num_programs(1)")
            # iteration on high dimension.
            self.writeline("for high_id in range(pid_x, high_task, num_x):")
            with self.indent():
                self.writeline("src_offset = 0")
                self.writeline("dst_offset = 0")
                self.writeline("temp_high_id = high_id")
                # get src_offset and dst_offset; flipped dims index from the end.
                if self.merge_dim_size > 2:
                    for i in range(self.merge_dim_size - 2):
                        self.writeline(f"tmp_stride = merge_stride_{i} // low_task")
                        self.writeline(f"id_{i} = temp_high_id // tmp_stride")
                        self.writeline("temp_high_id = temp_high_id % tmp_stride")
                        self.writeline(f"src_offset += id_{i} * merge_stride_{i}")
                        if not self.merge_flip_dims_flags[i]:
                            self.writeline(f"dst_offset += id_{i} * merge_stride_{i}")
                        else:
                            self.writeline(
                                f"dst_offset += (merge_shape_{i} - id_{i} -1) * merge_stride_{i}"
                            )
                    self.__kernel_flip_2d()
                elif self.merge_dim_size == 2:
                    self.__kernel_flip_2d()
                elif self.merge_dim_size == 1:
                    # A single merged dim can only occur with its flip flag set
                    # (otherwise __call__ returned a clone early).
                    assert self.merge_flip_dims_flags[0]
                    self.writeline("offset = tl.arange(0, W_DIM)")
                    self.writeline(
                        f"step = merge_shape_{self.merge_dim_size - 1} // num_y"
                    )
                    self.writeline(
                        f"tail = merge_shape_{self.merge_dim_size - 1} % num_y"
                    )
                    self.writeline("# process step.")
                    self.writeline("src_offset = pid_y * step")
                    self.writeline("dst_offset = tail + (num_y - pid_y - 1) * step")
                    self.writeline("step_iter = step // W_DIM")
                    self.writeline("step_tail = step % W_DIM")
                    self.writeline("for i in range(0, step_iter):")
                    with self.indent():
                        self.writeline("in_offset = src_offset + i * W_DIM")
                        self.writeline(
                            "out_offset = dst_offset + step_tail + (step_iter - i - 1) * W_DIM"
                        )
                        self.writeline(
                            "src = tl.load(x_ptr + offset + in_offset, cache_modifier='.cg')"
                        )
                        self.writeline("src = tl.flip(src, [0])")
                        self.writeline(
                            "tl.store(out_ptr + offset + out_offset, src, cache_modifier='.cg')"
                        )
                    self.writeline("if step_tail > 0:")
                    with self.indent():
                        self.writeline("in_offset = src_offset + step_iter * W_DIM")
                        self.writeline("out_offset = dst_offset")
                        self.writeline("mask = offset < step_tail")
                        self.writeline(
                            "src = tl.load(x_ptr + offset + in_offset, mask=mask, cache_modifier='.cg')"
                        )
                        self.writeline("src = tl.flip(src, [0])")
                        self.writeline("mask = offset >= (W_DIM - step_tail)")
                        self.writeline(
                            "tl.store(out_ptr + offset + out_offset - (W_DIM - step_tail), \
                            src, mask=mask, cache_modifier='.cg')"
                        )
                    # The last program additionally handles the remainder that
                    # did not divide evenly across num_y programs.
                    self.writeline("if pid_y == num_y - 1:")
                    with self.indent():
                        self.writeline("# process tail.")
                        self.writeline("src_offset = num_y * step")
                        self.writeline("dst_offset = 0")
                        self.writeline("tail_iter = tail // W_DIM")
                        self.writeline("tail_remain = tail % W_DIM")
                        self.writeline("for i in range(0, tail_iter):")
                        with self.indent():
                            self.writeline("in_offset = src_offset + i * W_DIM")
                            self.writeline(
                                "out_offset = dst_offset + tail_remain + (tail_iter - i - 1) * W_DIM"
                            )
                            self.writeline(
                                "src = tl.load(x_ptr + offset + in_offset, cache_modifier='.cg')"
                            )
                            self.writeline("src = tl.flip(src, [0])")
                            self.writeline(
                                "tl.store(out_ptr + offset + out_offset, src, cache_modifier='.cg')"
                            )
                        self.writeline("if tail_remain > 0:")
                        with self.indent():
                            self.writeline("in_offset = src_offset + tail_iter * W_DIM")
                            self.writeline("out_offset = dst_offset")
                            self.writeline("mask = offset < tail_remain")
                            self.writeline(
                                "src = tl.load(x_ptr + offset + in_offset, mask=mask, cache_modifier='.cg')"
                            )
                            self.writeline("src = tl.flip(src, [0])")
                            self.writeline("mask = offset >= (W_DIM-tail_remain)")
                            self.writeline(
                                "tl.store(out_ptr + offset + out_offset - (W_DIM - tail_remain), \
                                src, mask=mask, cache_modifier='.cg')"
                            )
                else:
                    raise RuntimeError(f"merge dim size error({self.merge_dim_size})")

    def __gen_code(self):
        """Entry point for code generation of flip."""
        # generate imports.
        self.__imports()
        # generate wrapper function.
        self.__wrapper()

        # generate kernel.
        self.__kernel()

    def __kernel_args(self, is_declare=True):
        """Generate string type of jit kernel arguments.

        is_declare=True  -> parameter list for the kernel `def` (includes the
                            H_DIM/W_DIM constexprs filled by @triton.heuristics).
        is_declare=False -> argument list used at the launch site inside the
                            generated wrapper.
        """
        merge_shapes_args = []
        merge_strides_args = []
        for i in range(self.merge_dim_size):
            if is_declare:
                merge_shapes_args.append(f"merge_shape_{i}")
                merge_strides_args.append(f"merge_stride_{i}")
            else:
                merge_shapes_args.append(f"merge_shapes[{i}]")
                merge_strides_args.append(f"merge_strides[{i}]")
        merge_shapes_args_str = ", ".join(merge_shapes_args)
        merge_strides_args_str = ", ".join(merge_strides_args)

        extra_args_str = f"{merge_shapes_args_str}, {merge_strides_args_str}"
        if is_declare:
            return f"x_ptr, out_ptr, {extra_args_str}, merge_dim_size, high_task: tl.constexpr, \
                low_task: tl.constexpr, H_DIM: tl.constexpr, W_DIM: tl.constexpr"
        else:
            return f"x, out, {extra_args_str}, merge_dim_size, high_task, low_task"

    def __call__(self, x: torch.Tensor, dims) -> torch.Tensor:
        """Call flip kernel.

        Merges dims, then fetches (or generates, writes, and imports) the
        wrapper specialized for the merged flip pattern and invokes it.
        """
        # Initialize merged-dim state.
        # note:
        # - This must be called first and exactly once per __call__.
        self.__init(x, dims)
        # Nothing to flip (no flipped dim survives merging, or trivial tensor).
        if (self.merge_flip_dim == 0) or (self.merge_dim_size == 0 or x.numel() <= 1):
            return x.clone()
        # get overload kernel.
        flip_dim_str = "_".join([str(i) for i in self.merge_flip_dims_flags])
        self.kernel_name = self.kernel_name + "_flip_" + flip_dim_str
        key = f"{self.merge_dim_size}_{flip_dim_str}"
        if key not in self.cache:
            # generate code and cache.
            self.__gen_code()

            file_name = f"flip_{key}_pid_{self.pid}.py"
            with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(self.getvalue())
            # load the just-written module from the cache directory.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_{key}_pid_{self.pid}", f.name
            )
            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, self.wrapper_func_name)
            self.cache[key] = overload

        overload = self.cache[key]
        return overload(x, self.merge_shapes, self.merge_strides, self.merge_dim_size)
def flip(A: torch.Tensor, dims) -> torch.Tensor:
    """Return a new tensor with the entries of `A` reversed along `dims`.

    Densifies non-contiguous input first, then dispatches to the JIT-generated
    Cambricon flip kernel.
    """
    logger.debug("GEMS_CAMBRICON FLIP")
    # The generated kernel's stride bookkeeping assumes a contiguous layout.
    tensor = A if A.is_contiguous() else A.contiguous()
    return FlipKernelCode()(tensor, dims)