Coverage for src/flag_gems/runtime/backend/_cambricon/ops/pad.py: 0%
344 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
7import triton
8import triton.language as tl
10from flag_gems.utils import libentry
11from flag_gems.utils.code_cache import code_cache_dir
12from flag_gems.utils.code_utils import IndentedBuffer
14from ..utils import TOTAL_CORE_NUM
16logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
19# --------------------------- padding wrapper genration -----------------------------------
def parameter_for_wrapper() -> str:
    """Build the parameter list of the generated functional wrapper.

    Returns:
        str: comma-separated declaration, e.g. ``"in0, pad, mode, value=0"``.
    """
    return ", ".join(("in0", "pad", "mode", "value=0"))
def parameter_for_wrapper_out() -> str:
    """Build the parameter list of the generated destination-passing wrapper.

    Returns:
        str: comma-separated declaration, e.g.
        ``"in0, out0, dst_shape, pad_before, pad_after, mode, value=0"``.
    """
    names = (
        "in0",
        "out0",
        "dst_shape",
        "pad_before",
        "pad_after",
        "mode",
        "value=0",
    )
    return ", ".join(names)
def parameter_ref_for_wrapper() -> str:
    """Build the argument list used when the functional wrapper calls the
    destination-passing wrapper.

    Returns:
        str: comma-separated references, e.g.
        ``"in0, out0, dst_shape, pad_before, pad_after, mode, value"``.
    """
    refs = (
        "in0",
        "out0",
        "dst_shape",
        "pad_before",
        "pad_after",
        "mode",
        "value",
    )
    return ", ".join(refs)
def output_ref_for_wrapper() -> str:
    """Name of the output variable produced by the generated wrappers."""
    return "out0"
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of a generated pad module into ``code``.

    Emits the third-party imports, a blank line, the flag_gems helper
    imports, then two trailing blank lines. Returns the same buffer.
    """
    for line in (
        "import math",
        "import torch",
        "import triton",
        "from triton import language as tl",
    ):
        code.writeline(line)
    code.newline()
    for line in (
        "from flag_gems.utils.libentry import libentry",
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils import triton_lang_extension as tle",
        "from flag_gems.utils.type_utils import type_promotion",
    ):
        code.writeline(line)
    code.newline()
    code.newline()
    return code
def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the functional (allocating) pad wrapper into ``code``.

    The generated function translates the flat ``pad`` list (last dim
    first, ``(before, after)`` pairs, as in ``torch.nn.functional.pad``)
    into per-dim ``pad_before``/``pad_after`` lists, allocates the padded
    output tensor, then delegates to the destination-passing wrapper.
    Returns the same buffer.
    """
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper()}):")
    with code.indent():
        for line in (
            "ndim = in0.ndim",
            "pad_size = len(pad)",
            "assert pad_size % 2 == 0",
        ):
            code.writeline(line)
        code.newline()
        code.writeline("pad_before = [0 for _ in range(ndim)]")
        code.writeline("pad_after = [0 for _ in range(ndim)]")
        code.newline()
        # pad[0:2] applies to the last dim, pad[2:4] to the previous one, ...
        code.writeline("pad_pair = pad_size // 2 ")
        code.writeline("for i in range(pad_pair): ")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")
        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim): ")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")
        code.writeline(
            "out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)"
        )
        # Hand off to the destination-passing wrapper and return its output.
        out_name = output_ref_for_wrapper()
        code.writeline(
            f"{out_name} = {destination_passing_func_name}"
            f"({parameter_ref_for_wrapper()})"
        )
        code.writeline("return out0")
    code.newline()
    code.newline()
    return code
def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the destination-passing pad wrapper into ``code``.

    The generated function computes the valid (unpadded) index interval of
    every dimension, lowers the string ``mode`` into four constexpr flags,
    and launches the rank-specialized Triton kernel on ``in0``'s device.
    Returns the same buffer.
    """
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper_out()}):")
    with code.indent():
        code.writeline("BLOCK_SIZE = 2048")
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()
        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")
        if rank > 0:
            # Per-dim interval [start, end) holding original (non-pad) data.
            code.writeline("# strides of each tensor argument w.r.t the task space")
            for axis in range(rank):
                code.writeline(f"valid_dim{axis}_start = pad_before[{axis}]")
                code.writeline(
                    f"valid_dim{axis}_end = dst_shape[{axis}] - pad_after[{axis}]"
                )
        code.newline()
        for flag, tag in (
            ("IS_CONSTANT", "constant"),
            ("IS_REFLECT", "reflect"),
            ("IS_REPLICATE", "replicate"),
            ("IS_CIRCULAR", "circular"),
        ):
            code.writeline(f"{flag} = mode == '{tag}'")
        code.newline()
        code.writeline("# kernel launch")
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            code.writeline(f"{kernel_name}[grid](")
            with code.indent():
                code.writeline("in0, out0, ")
                if rank > 0:
                    arg_groups = (
                        ("x_shape[{j}]", "# shape for x"),
                        ("in_strides0[{j}]", "# stride for x"),
                        ("out_strides[{j}]", "# stride for out"),
                        ("valid_dim{j}_start", "# valid dim start"),
                        ("valid_dim{j}_end", "# valid dim end"),
                    )
                    for template, note in arg_groups:
                        joined = ", ".join(
                            template.format(j=j) for j in range(rank)
                        )
                        code.writeline(f"{joined}, {note}")
                for tail in (
                    "in0.numel(), ",
                    "out0.numel(), ",
                    "value, ",
                    "IS_CONSTANT, ",
                    "IS_REFLECT, ",
                    "IS_REPLICATE, ",
                    "IS_CIRCULAR, ",
                    "BLOCK_SIZE, ",
                ):
                    code.writeline(tail)
                code.writeline(")")
        code.writeline("return out0")
    code.newline()
    code.newline()
    return code
def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: "IndentedBuffer",
) -> "IndentedBuffer":
    """Emit the rank-specialized Triton pad kernel into ``code``.

    The generated kernel maps each flat output offset to per-dim output
    indices, decides whether the element lies in the padded region, remaps
    indices for reflect/replicate/circular modes, and loads/stores with
    the appropriate masks. Returns the same buffer.

    FIX: ``load_cond`` was previously seeded with the leftover loop
    variable ``i`` (``src_index_{i} < x_shape{i}``), so for rank > 1 the
    dim-0 bound was never checked and the last dim was checked twice; it
    is now seeded from dim 0 explicitly.

    NOTE(review): the generated body references ``dst_index_0`` and
    ``src_index_0`` unconditionally, so ``rank == 0`` would produce
    invalid code — presumably callers always pass rank >= 1; confirm.
    """
    code.newline()
    # Decorators: libentry caching + JIT; `value` stays unspecialized so one
    # compiled kernel serves any fill value.
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
    # Signature.
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")
        if rank > 0:
            shape_args = ", ".join(f"x_shape{j}: int" for j in range(rank))
            code.writeline(f"{shape_args}, # shape for x")
            stride_args = ", ".join(f"in_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for x")
            stride_args = ", ".join(f"out_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for out")
            stride_args = ", ".join(f"valid_dim{j}_start: int" for j in range(rank))
            code.writeline(f"{stride_args}, # valid dim start")
            stride_args = ", ".join(f"valid_dim{j}_end: int" for j in range(rank))
            code.writeline(f"{stride_args}, # valid dim end")
        code.writeline("in_elem_cnt: tl.constexpr, ")
        code.writeline("out_elem_cnt: tl.constexpr, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")
    code.writeline("):")
    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.newline()
        # Unravel the flat output offset into per-dim indices via the
        # output strides (row-major decomposition).
        code.writeline("remaining = offset ")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()
        # if_pad == 1 marks elements in the padded border region.
        code.writeline("if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)")
        code.writeline("if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)")
        code.writeline(
            "cond = ((dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end))"
        )
        for i in range(1, rank):
            code.writeline(
                f"cond &= ((dst_index_{i} >= valid_dim{i}_start) & (dst_index_{i} < valid_dim{i}_end))"
            )
        code.writeline(
            "if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)"
        )
        # Default source index: shift back by the leading pad, clamped at 0.
        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start ")
        for i in range(rank):
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )
        code.newline()
        # Reflect mode: mirror indices about each border (border excluded).
        code.writeline("if IS_REFLECT: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
                    valid_dim{i}_start - dst_index_{i}, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
                    (x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, src_index_{i})"""
                )
        code.newline()
        # Replicate mode: clamp indices to the first/last input element.
        code.writeline("if IS_REPLICATE: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start, 0, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end, x_shape{i} - 1, src_index_{i})"
                )
        code.newline()
        # Circular mode: wrap indices around the input extent.
        code.writeline("if IS_CIRCULAR: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
                    dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
                    dst_index_{i} - valid_dim{i}_end, src_index_{i})"""
                )
        code.newline()
        # Flatten the source index through the input strides.
        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")
        # Bounds check every dimension, starting from dim 0 (bug fix).
        code.writeline("load_cond = src_index_0 < x_shape0")
        for i in range(1, rank):
            code.writeline(f"load_cond &= src_index_{i} < x_shape{i}")
        # Constant mode fills pad positions with `value`; other modes always
        # read a (remapped) in-bounds element.
        code.writeline("if IS_CONSTANT: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=((if_pad == 0) & load_cond), other=value)"
            )
        code.writeline("else: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)")
    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble a complete generated pad module into ``code``.

    Emits, in order: the import header, the functional wrapper, the
    destination-passing wrapper, and the rank-specialized Triton kernel.
    The rank of ``inputs[0]`` is the only runtime-determined factor.
    """
    rank = len(inputs[0].shape)
    code = generate_imports(code)
    code = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_padding_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    return generate_pad_kernel(rank, kernel_name, code)
class PadFunction:
    """Generate and cache rank-specialized pad implementations.

    For each distinct input rank a wrapper + Triton JIT function is code
    generated, written to the code cache directory (defaults to
    ~/.flaggems), imported, and memoized for subsequent calls.
    """

    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # kwargs must not reach the JITFunction directly; they are only
        # forwarded to the generated Python wrapper.
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            # Generate the specialized module source and persist it.
            buffer = IndentedBuffer()
            buffer = generate_code(
                args,
                "_pad_wrapper",
                "_pad_wrapper_out",
                "_pad_jit_function",
                buffer,
            )
            file_name = f"constant_pad_rank_{key}_pid_{self.pid}.py"
            file_path = code_cache_dir() / file_name
            with open(file_path, "wt", encoding="utf-8") as f:
                f.write(buffer.getvalue())
            # Import the generated module without registering it in
            # sys.modules.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            overload = module._pad_wrapper
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def arg_key(self, *args):
        # Cache key: the maximum rank among tensor arguments.
        return max(item.ndim for item in args if torch.is_tensor(item))
# Module-level singleton: caches one generated pad wrapper per input rank.
_pad_func = PadFunction()
@libentry()
@triton.autotune(
    configs=[
        # BLOCK_SIZE in {1024, 4096, 16384}, 1 or 3 pipeline stages.
        triton.Config({"BLOCK_SIZE": 2**n}, num_stages=s)
        for n in range(10, 16, 2)
        for s in [1, 3]
    ],
    key=["inp_elements"],
)
@triton.jit
def pad_1d_constant_kernel(
    inp_ptr,  # pointer to the (contiguous) 1-D input
    out_ptr,  # pointer to the output of size pad_left + inp_elements + pad_right
    inp_elements,  # number of input elements
    pad_value,  # constant fill value for out-of-range positions
    pad_left,  # elements prepended before the input
    pad_right,  # elements appended after the input
    BLOCK_SIZE: tl.constexpr,
):
    # Constant-pad a 1-D tensor. Each program walks the output in a
    # grid-stride loop of BLOCK_SIZE chunks.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    out_elements = pad_left + inp_elements + pad_right
    for off in range(start, out_elements, step):
        # Shift output positions back by pad_left to get input positions;
        # negative / past-the-end lanes read pad_value via `other=`.
        inp_offset = off + tl.arange(0, BLOCK_SIZE) - pad_left
        # NOTE(review): `and` between tensor comparisons relies on Triton
        # treating it as an elementwise logical and; `&` is the
        # conventional spelling — confirm on the target Triton version.
        inp_mask = inp_offset >= 0 and inp_offset < inp_elements
        inp = tl.load(inp_ptr + inp_offset, mask=inp_mask, other=pad_value)
        out_offset = off + tl.arange(0, BLOCK_SIZE)
        out_mask = out_offset < out_elements
        tl.store(out_ptr + out_offset, inp, mask=out_mask)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": n}, num_stages=s)
        for n in [1, 4, 8, 12, 16, 24]
        for s in [1, 3]
    ],
    key=["H", "W"],
)
@triton.jit
def pad_2d_constant_kernel(
    inp_ptr,  # pointer to the (contiguous) H x W input
    out_ptr,  # pointer to the (pad_top+H+pad_bottom) x (pad_left+W+pad_right) output
    H,  # input rows
    W: tl.constexpr,  # input cols; constexpr so it can shape tl.arange-derived tiles
    pad_value,  # constant fill value
    pad_left: tl.constexpr,
    pad_right: tl.constexpr,
    pad_top,
    pad_bottom,
    BLOCK_H: tl.constexpr,  # output rows handled per loop iteration
):
    # Constant-pad a 2-D tensor. Each program processes BLOCK_H full output
    # rows at a time in a grid-stride loop.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    block_start = pid * BLOCK_H
    step = num_jobs * BLOCK_H
    # out_W must be constexpr so tl.arange(0, out_W) below is legal.
    out_W: tl.constexpr = pad_left + W + pad_right
    out_H = pad_top + H + pad_bottom
    for batch_idx in range(block_start, out_H, step):
        # Map output coordinates back to input coordinates.
        offset_h = tl.arange(0, BLOCK_H) + batch_idx - pad_top
        offset_w = tl.arange(0, out_W) - pad_left
        offsets = offset_h[:, None] * W + offset_w[None, :]
        # Out-of-range lanes (the pad border) read pad_value via `other=`.
        # NOTE(review): `and` between tensor comparisons — `&` is the
        # conventional elementwise form; confirm on the target Triton version.
        mask = (offset_h[:, None] >= 0 and offset_h[:, None] < H) and (
            offset_w[None, :] >= 0 and offset_w[None, :] < W
        )
        inp = tl.load(inp_ptr + offsets, mask=mask, other=pad_value)
        out_offset_c = tl.arange(0, out_W)
        out_offset_n = tl.arange(0, BLOCK_H) + batch_idx
        out_offsets = out_offset_n[:, None] * out_W + out_offset_c[None, :]
        out_mask = out_offset_n[:, None] < out_H and out_offset_c[None, :] < out_W
        tl.store(out_ptr + out_offsets, inp, mask=out_mask)
def pad(self, pad, mode="constant", value=None):
    """Pad tensor ``self`` with widths ``pad``.

    ``pad`` uses the ``torch.nn.functional.pad`` layout: a flat sequence of
    (before, after) pairs, last dimension first. ``mode`` is one of
    'constant', 'reflect', 'replicate', 'circular'; ``value`` (constant
    mode fill) defaults to 0.0.

    Constant mode with ndim <= 3 is served by the hand-written 1-D/2-D
    kernels above; every other case falls through to the rank-specialized
    generated kernel (``_pad_func``).
    """
    logger.debug("GEMS_CAMBRICON CONSTANT PAD ND")
    ndim = self.ndim
    pad_size = len(pad)
    assert pad_size % 2 == 0
    if value is None:
        value = 0.0
    if mode == "constant":
        # Scatter the flat pad list into per-dim before/after amounts:
        # pad[0:2] applies to the last dim, pad[2:4] to the one before, ...
        pad_before = [0 for _ in range(ndim)]
        pad_after = [0 for _ in range(ndim)]
        pad_pair = pad_size // 2
        for i in range(pad_pair):
            pad_before[ndim - i - 1] = pad[2 * i]
            pad_after[ndim - i - 1] = pad[2 * i + 1]
        inp_shape = list(self.shape)
        out_shape = list(self.shape)
        for i in range(ndim):
            out_shape[i] += pad_before[i] + pad_after[i]
        out = torch.empty(out_shape, dtype=self.dtype, device=self.device)
        if ndim == 1:
            # Cap the grid at the core count; the kernel grid-strides.
            grid = lambda meta: (
                min(triton.cdiv(out_shape[0], meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            pad_1d_constant_kernel[grid](
                self.contiguous(),
                out,
                inp_shape[0],
                value,
                pad_before[-1],
                pad_after[-1],
            )
            return out
        if ndim == 2:
            grid = lambda meta: (
                min(triton.cdiv(out_shape[0], meta["BLOCK_H"]), TOTAL_CORE_NUM),
            )
            pad_2d_constant_kernel[grid](
                self.contiguous(),
                out,
                inp_shape[0],
                inp_shape[1],
                value,
                pad_before[-1],
                pad_after[-1],
                pad_before[-2],
                pad_after[-2],
            )
            return out
        if ndim == 3:
            # Fill the fully-padded leading/trailing slices along dim 0 with
            # the constant value...
            out[: pad_before[0]] = torch.full(
                out[0 : pad_before[0]].shape,
                value,
                dtype=self.dtype,
                device=self.device,
            )
            out[pad_before[0] + inp_shape[0] :] = torch.full(
                out[pad_before[0] + inp_shape[0] :].shape,
                value,
                dtype=self.dtype,
                device=self.device,
            )
            # ...then 2-D-pad each interior slice independently.
            for i in range(pad_before[0], pad_before[0] + inp_shape[0]):
                grid = lambda meta: (
                    min(triton.cdiv(out_shape[1], meta["BLOCK_H"]), TOTAL_CORE_NUM),
                )
                pad_2d_constant_kernel[grid](
                    self[i - pad_before[0]].contiguous(),
                    out[i],
                    inp_shape[1],
                    inp_shape[2],
                    value,
                    pad_before[-1],
                    pad_after[-1],
                    pad_before[-2],
                    pad_after[-2],
                )
            return out
        # constant mode with ndim > 3 falls through to the generated kernel.
    if mode == "reflect":
        # NOTE(review): `ndim //= 2` reuses the tensor rank as the number of
        # padded dims (e.g. a 4-D input with 4 pad values for 2-D reflect);
        # confirm this matches the intended callers.
        ndim //= 2
        assert (
            len(pad) == 2 * ndim
        ), f"padding size is expected to be {2 * ndim}, but got {len(pad)}"
        for i in range(ndim):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            # NOTE(review): these shape indices look inconsistent with the
            # dims actually being padded (pad widths index from the last
            # dim) — verify against reflection_pad semantics.
            input_l, input_r = (
                self.shape[ndim - (2 * i + 1) - 1],
                self.shape[ndim - (2 * i + 1)],
            )
            # Reflect padding must be strictly smaller than the input dim.
            assert (
                pad_l < input_l and pad_r < input_r
            ), \
                f"padding size should be less than the corresponding input dimension, \
                but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
    if mode == "circular":
        # Same rank-halving convention as the reflect branch above.
        ndim //= 2
        assert (
            len(pad) == 2 * ndim
        ), f"padding size is expected to be {2 * ndim}, but got {len(pad)}"
        for i in range(ndim):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - i - 1]
            # Circular padding may wrap around the input at most once per side.
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."
    # Generic path: rank-specialized generated wrapper/kernel.
    out = _pad_func(self, pad, mode, float(value))
    return out
def constant_pad_nd(self, pad_list, value=0):
    """Entry point for aten::constant_pad_nd: constant-mode padding only."""
    result = pad(self, pad_list, mode="constant", value=value)
    return result