Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/pad.py: 0%
282 statements

import importlib
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_utils import IndentedBuffer

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


# --------------------------- padding wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Generate the parameter declaration for the functional wrapper.
    Example: in0, pad, mode, value=0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("pad")
    parameters.append("mode")
    parameters.append("value=0")
    return ", ".join(parameters)


def parameter_for_wrapper_out() -> str:
    """Generate the parameter declaration for the destination-passing wrapper.
    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value=0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")
    parameters.append("dst_shape")
    parameters.append("pad_before")
    parameters.append("pad_after")
    parameters.append("mode")
    parameters.append("value=0")

    return ", ".join(parameters)


def parameter_ref_for_wrapper() -> str:
    """Generate the argument list used when calling the destination-passing wrapper.
    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")
    parameters.append("dst_shape")
    parameters.append("pad_before")
    parameters.append("pad_after")
    parameters.append("mode")
    parameters.append("value")

    return ", ".join(parameters)


def output_ref_for_wrapper() -> str:
    return "out0"
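
# For quick reference, the signature helpers above assemble the following strings
# (illustrative, read directly from the append calls; nothing is emitted at import time):
#   parameter_for_wrapper()     -> "in0, pad, mode, value=0"
#   parameter_for_wrapper_out() -> "in0, out0, dst_shape, pad_before, pad_after, mode, value=0"
#   parameter_ref_for_wrapper() -> "in0, out0, dst_shape, pad_before, pad_after, mode, value"
#   output_ref_for_wrapper()    -> "out0"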


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import math")
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("from triton import language as tl")
    code.newline()
    code.writeline("from flag_gems.utils.libentry import libentry")
    code.writeline("from flag_gems.runtime import torch_device_fn")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.writeline("from flag_gems.utils.type_utils import type_promotion")
    code.newline()
    code.newline()
    return code


def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("ndim = in0.ndim")
        code.writeline("pad_size = len(pad)")
        code.writeline("assert pad_size % 2 == 0")
        code.newline()
        code.writeline("pad_before = [0 for _ in range(ndim)]")
        code.writeline("pad_after = [0 for _ in range(ndim)]")
        code.newline()
        code.writeline("pad_pair = pad_size // 2")
        code.writeline("for i in range(pad_pair):")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")
        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim):")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")

        code.writeline("out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)")

        # call destination_passing_func
        output_names: str = output_ref_for_wrapper()
        call_str = (
            f"{output_names} = {destination_passing_func_name}"
            f"({parameter_ref_for_wrapper()})"
        )
        code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
        code.newline()
        code.newline()

    return code
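
# A minimal sketch of what the functional wrapper above emits (assuming the default
# names "_wrapper" / "_wrapper_out" passed in by PadFunction below); the real text is
# produced line by line through IndentedBuffer, this is only for orientation:
#
#   def _wrapper(in0, pad, mode, value=0):
#       ndim = in0.ndim
#       pad_size = len(pad)
#       assert pad_size % 2 == 0
#       pad_before = [0 for _ in range(ndim)]
#       pad_after = [0 for _ in range(ndim)]
#       for i in range(pad_size // 2):
#           pad_before[ndim - i - 1] = pad[2 * i]
#           pad_after[ndim - i - 1] = pad[2 * i + 1]
#       dst_shape = [s + b + a for s, b, a in zip(in0.shape, pad_before, pad_after)]
#       out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)
#       out0 = _wrapper_out(in0, out0, dst_shape, pad_before, pad_after, mode, value)
#       return out0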


def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # launch configuration
        code.writeline("num_ctas = 12")
        code.writeline(
            "BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(out0.numel(), num_ctas))"
        )
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")

        # valid (unpadded) index range of each output dimension
        if rank > 0:
            code.writeline("# valid (unpadded) range of each output dimension")
            for i in range(rank):
                code.writeline(f"valid_dim{i}_start = pad_before[{i}]")
                code.writeline(f"valid_dim{i}_end = dst_shape[{i}] - pad_after[{i}]")

            code.newline()

        code.writeline("IS_CONSTANT = mode == 'constant'")
        code.writeline("IS_REFLECT = mode == 'reflect'")
        code.writeline("IS_REPLICATE = mode == 'replicate'")
        code.writeline("IS_CIRCULAR = mode == 'circular'")

        code.newline()

        # kernel launch
        code.writeline("# kernel launch")
        code.writeline("import os")
        code.writeline('os.environ["TRITONXPU_OTHER_SIM"] = "1"')
        code.writeline('os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"')
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"x_shape[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # shape for x")

                    s = ", ".join(f"in_strides0[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for x")

                    s = ", ".join(f"out_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out")

                    s = ", ".join(f"valid_dim{j}_start" for j in range(rank))
                    code.writeline(f"{s}, # valid dim start")

                    s = ", ".join(f"valid_dim{j}_end" for j in range(rank))
                    code.writeline(f"{s}, # valid dim end")

                code.writeline("in0.numel(), ")
                code.writeline("out0.numel(), ")
                code.writeline("value, ")
                code.writeline("IS_CONSTANT, ")
                code.writeline("IS_REFLECT, ")
                code.writeline("IS_REPLICATE, ")
                code.writeline("IS_CIRCULAR, ")
                code.writeline("BLOCK_SIZE, ")
                code.writeline("buffer_size_limit=512, ")
                code.writeline(")")

        code.writeline('if "TRITONXPU_OTHER_SIM" in os.environ:')
        with code.indent():
            code.writeline('del os.environ["TRITONXPU_OTHER_SIM"]')

        code.writeline('if "TRITONXPU_STORE_MASK_SIM" in os.environ:')
        with code.indent():
            code.writeline('del os.environ["TRITONXPU_STORE_MASK_SIM"]')

        code.writeline("return out0")
        code.newline()
        code.newline()

    return code
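
# A minimal sketch of the emitted destination-passing wrapper (illustrative; the
# TRITONXPU_* environment toggles and the buffer_size_limit=512 launch argument are
# kunlunxin-specific and are emitted exactly as written above):
#
#   def _wrapper_out(in0, out0, dst_shape, pad_before, pad_after, mode, value=0):
#       num_ctas = 12
#       BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(out0.numel(), num_ctas))
#       grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)
#       ...
#       with torch_device_fn.device(in0.device):
#           _jit_function[grid](in0, out0, ..., BLOCK_SIZE, buffer_size_limit=512)
#       return out0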


def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # separate the kernel from the preceding definitions
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        if rank > 0:
            # shape of the input
            shape_args = ", ".join(f"x_shape{j}: tl.constexpr" for j in range(rank))
            code.writeline(f"{shape_args}, # shape for x")

            # strides of the input
            stride_args = ", ".join(f"in_strides{j}: tl.constexpr" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for x")

            # strides of the output
            stride_args = ", ".join(
                f"out_strides{j}: tl.constexpr" for j in range(rank)
            )
            code.writeline(f"{stride_args}, # stride for out")

            # start of the valid (unpadded) range per dimension
            stride_args = ", ".join(
                f"valid_dim{j}_start: tl.constexpr" for j in range(rank)
            )
            code.writeline(f"{stride_args}, # valid dim start")

            # end of the valid (unpadded) range per dimension
            stride_args = ", ".join(
                f"valid_dim{j}_end: tl.constexpr" for j in range(rank)
            )
            code.writeline(f"{stride_args}, # valid dim end")

        code.writeline("in_elem_cnt: tl.constexpr, ")
        code.writeline("out_elem_cnt: tl.constexpr, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")

    code.writeline("):")

    with code.indent():
        code.writeline("pid = tle.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.newline()

        code.writeline("remaining = offset")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()

        code.writeline("if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)")
        code.writeline("if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)")

        code.writeline(
            "cond = ((dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end))"
        )
        for i in range(1, rank):
            code.writeline(
                f"cond &= ((dst_index_{i} >= valid_dim{i}_start) & (dst_index_{i} < valid_dim{i}_end))"
            )

        code.writeline(
            "if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)"
        )

        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start")

        for i in range(rank):
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )

        code.newline()
        code.writeline("if IS_REFLECT:")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
                    valid_dim{i}_start - dst_index_{i}, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
                    (x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, src_index_{i})"""
                )

        code.newline()
        code.writeline("if IS_REPLICATE:")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start, 0, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end, x_shape{i} - 1, src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_CIRCULAR:")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
                    dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
                    dst_index_{i} - valid_dim{i}_end, src_index_{i})"""
                )

        code.newline()
        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")

        # guard loads against out-of-range source indices, starting from dimension 0
        code.writeline("load_cond = src_index_0 < x_shape0")
        for i in range(1, rank):
            code.writeline(f"load_cond &= src_index_{i} < x_shape{i}")

        code.writeline("if IS_CONSTANT:")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=((if_pad == 0) & load_cond), other=value)"
            )
        code.writeline("else:")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)")

    return code
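
# Worked example of the index mapping the generated kernel implements (illustrative,
# derived from the tl.where expressions above): a 1-D input [a, b, c, d] (x_shape0 = 4)
# padded with pad_before = pad_after = 2 gives dst_shape = 8 and valid range [2, 6).
#   replicate: src indices 0,0,0,1,2,3,3,3  ->  a a a b c d d d
#   reflect:   src indices 2,1,0,1,2,3,2,1  ->  c b a b c d c b
#   circular:  src indices 2,3,0,1,2,3,0,1  ->  c d a b c d a b
#   constant:  padded positions are filled with `value` via the load mask.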


def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    shape = inputs[0].shape
    rank = len(shape)

    # the only runtime determined factor is the rank of the task space
    code = generate_imports(code)
    code = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_padding_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    code = generate_pad_kernel(rank, kernel_name, code)
    return code
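
# The generated module therefore contains three definitions, all specialized for one
# rank: the functional wrapper (allocates out0), the destination-passing wrapper
# (computes the launch grid and calls the kernel), and the @triton.jit kernel itself.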


class PadFunction:
    """Utility to generate functions for the pad operation. It generates a wrapper &
    JITFunction which are specialized according to the rank of the task space (the rank
    of the input tensor). The generated code is written to the cache directory
    (defaults to ~/.flaggems).
    """

    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # note: kwargs should not be used in JITFunction directly
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_wrapper",
                "_wrapper_out",
                "_jit_function",
                code,
            )

            file_name = f"constant_pad_rank_{key}_pid_{self.pid}.py"

            with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, "_wrapper")
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank
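
# Illustrative cache behaviour (assuming the default cache_dir of ~/.flaggems): the first
# call with, say, a rank-3 tensor writes and imports constant_pad_rank_3_pid_<pid>.py;
# later rank-3 calls in the same process reuse the cached "_wrapper" overload without
# touching the filesystem.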


_pad_func = PadFunction()


def pad(self, pad, mode="constant", value=None):
    logger.debug("GEMS CONSTANT PAD ND")

    ndim = self.ndim

    if value is None:
        value = 0.0

    pad_pairs = len(pad) // 2

    if mode == "reflect":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert pad_l < input_size and pad_r < input_size, (
                "padding size should be less than the corresponding input dimension, "
                f"but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
            )

    if mode == "circular":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."

    out = _pad_func(self, pad, mode, float(value))
    return out


def constant_pad_nd(self, pad_list, value=0):
    return pad(self, pad_list, mode="constant", value=value)
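
# Illustrative usage (for exposition only; not executed at import time, and the "cuda"
# device string is a stand-in for whatever device this backend actually targets):
#   x = torch.randn(2, 3, device="cuda")
#   y = pad(x, (1, 1, 1, 1), mode="reflect")     # y.shape == (4, 5)
#   z = constant_pad_nd(x, (1, 1), value=3.0)    # z.shape == (2, 5)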