Coverage for src/flag_gems/ops/pad.py: 99%
279 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
# Module-level logger; `pad` emits a debug record through it on every call.
logger = logging.getLogger(name=__name__)
# --------------------------- padding wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Return the parameter list of the generated functional wrapper.

    ``in0`` is the input tensor, ``pad`` the flat padding list, ``mode`` the
    padding mode string and ``value`` the fill value (defaulting to 0).

    Example: ``in0, pad, mode, value=0``
    """
    parameters: List[str] = ["in0", "pad", "mode", "value=0"]
    return ", ".join(parameters)
def parameter_for_wrapper_out() -> str:
    """Return the parameter list of the generated destination-passing wrapper.

    ``out0`` is the preallocated output, ``dst_shape`` its shape, and
    ``pad_before`` / ``pad_after`` hold one padding amount per input dimension.

    Example: ``in0, out0, dst_shape, pad_before, pad_after, mode, value=0``
    """
    parameters: List[str] = [
        "in0",
        "out0",
        "dst_shape",
        "pad_before",
        "pad_after",
        "mode",
        "value=0",
    ]
    return ", ".join(parameters)
def parameter_ref_for_wrapper() -> str:
    """Return the argument list the functional wrapper passes to the out-variant.

    The names match :func:`parameter_for_wrapper_out`, in the same order,
    without the ``value`` default.

    Example: ``in0, out0, dst_shape, pad_before, pad_after, mode, value``
    """
    parameters: List[str] = [
        "in0",
        "out0",
        "dst_shape",
        "pad_before",
        "pad_after",
        "mode",
        "value",
    ]
    return ", ".join(parameters)
def output_ref_for_wrapper() -> str:
    """Return the name the functional wrapper binds the out-variant's result to."""
    result_name = "out0"
    return result_name
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated module into ``code``.

    Emits the math/torch/triton imports, a blank line, the flag_gems helper
    imports, and then two blank lines before the first generated definition.
    Returns the same buffer so calls can be chained.
    """
    third_party_lines = (
        "import math",
        "import torch",
        "import triton",
        "from triton import language as tl",
    )
    flag_gems_lines = (
        "from flag_gems.utils.libentry import libentry",
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils import triton_lang_extension as tle",
        "from flag_gems.utils.type_utils import type_promotion",
    )

    for line in third_party_lines:
        code.writeline(line)
    code.newline()

    for line in flag_gems_lines:
        code.writeline(line)
    code.newline()
    code.newline()

    return code
def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the functional wrapper ``wrapper_name(in0, pad, mode, value=0)``.

    The generated wrapper:

    * converts the flat ``pad`` list into per-dimension ``pad_before`` /
      ``pad_after`` lists, with pair ``i`` applying to dimension ``ndim - i - 1``;
    * computes the padded output shape and allocates it with ``torch.empty``;
    * delegates to ``destination_passing_func_name`` and returns its result.

    Args:
        wrapper_name: name of the generated functional wrapper.
        destination_passing_func_name: name of the generated out-variant it calls.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer.
    """
    parameters: str = parameter_for_wrapper()
    code.writeline(f"def {wrapper_name}({parameters}):")

    with code.indent():
        code.writeline("ndim = in0.ndim")
        code.writeline("pad_size = len(pad)")
        code.writeline("assert pad_size % 2 == 0")
        code.newline()
        code.writeline("pad_before = [0 for _ in range(ndim)]")
        code.writeline("pad_after = [0 for _ in range(ndim)]")
        code.newline()
        code.writeline("pad_pair = pad_size // 2 ")
        # More pairs than dimensions would make `ndim - i - 1` negative below,
        # silently wrapping around and padding the wrong dimensions.
        code.writeline('assert pad_pair <= ndim, "Padding length too large"')
        code.writeline("for i in range(pad_pair): ")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")
        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim): ")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")

        code.writeline(
            "out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)"
        )

        # call destination_passing_func
        output_names: str = output_ref_for_wrapper()
        code.writeline(
            f"{output_names} = {destination_passing_func_name}"
            f"({parameter_ref_for_wrapper()})"
        )
        code.writeline("return out0")
        code.newline()
        code.newline()

    return code
def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the destination-passing wrapper that launches the pad kernel.

    The generated ``wrapper_name(in0, out0, dst_shape, pad_before, pad_after,
    mode, value=0)`` does the following:

    * for each of the ``rank`` dimensions, computes the output index range
      ``[pad_before[i], dst_shape[i] - pad_after[i])`` that is copied straight
      from the input, and whether that dimension is padded at all;
    * turns ``mode`` into four boolean flags;
    * launches ``kernel_name`` on a 1-D grid of
      ``triton.cdiv(out0.numel(), BLOCK_SIZE)`` programs and returns ``out0``.

    Args:
        rank: number of dimensions the wrapper is specialized for.
        wrapper_name: name of the generated wrapper.
        kernel_name: name of the generated Triton kernel to launch.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer.
    """
    parameters: str = parameter_for_wrapper_out()
    code.writeline(f"def {wrapper_name}({parameters}):")

    with code.indent():
        # launch configuration: one program per BLOCK_SIZE output elements
        code.writeline("BLOCK_SIZE = 256")
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")

        if rank > 0:
            code.writeline(
                "# output index range of each dim that is copied from the input"
            )
            for i in range(rank):
                code.writeline(f"valid_dim{i}_start = pad_before[{i}]")
                code.writeline(f"valid_dim{i}_end = dst_shape[{i}] - pad_after[{i}]")
            code.newline()

        code.writeline("# Check which dimensions have padding")
        for i in range(rank):
            code.writeline(
                f"dim{i}_has_pad = pad_before[{i}] > 0 or pad_after[{i}] > 0"
            )
        code.writeline("IS_CONSTANT = mode == 'constant'")
        code.writeline("IS_REFLECT = mode == 'reflect'")
        code.writeline("IS_REPLICATE = mode == 'replicate'")
        code.writeline("IS_CIRCULAR = mode == 'circular'")
        code.newline()

        code.writeline("# kernel launch")
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            code.writeline(f"{kernel_name}[grid](")
            with code.indent():
                code.writeline("in0, out0, ")
                if rank > 0:
                    dims = range(rank)
                    # per-dimension argument groups, in kernel-signature order
                    code.writeline(
                        ", ".join(f"x_shape[{j}]" for j in dims) + ", # shape for x"
                    )
                    code.writeline(
                        ", ".join(f"in_strides0[{j}]" for j in dims)
                        + ", # stride for x"
                    )
                    code.writeline(
                        ", ".join(f"out_strides[{j}]" for j in dims)
                        + ", # stride for out"
                    )
                    code.writeline(
                        ", ".join(f"valid_dim{j}_start" for j in dims)
                        + ", # valid dim start"
                    )
                    code.writeline(
                        ", ".join(f"valid_dim{j}_end" for j in dims)
                        + ", # valid dim end"
                    )
                    code.writeline(
                        ", ".join(f"bool(dim{j}_has_pad)" for j in dims)
                        + ", # dim has padding flags"
                    )
                code.writeline("in0.numel(), ")
                code.writeline("out0.numel(), ")
                code.writeline("value, ")
                code.writeline("IS_CONSTANT, ")
                code.writeline("IS_REFLECT, ")
                code.writeline("IS_REPLICATE, ")
                code.writeline("IS_CIRCULAR, ")
                code.writeline("BLOCK_SIZE, ")
            code.writeline(")")

        code.writeline("return out0")
        code.newline()
        code.newline()
    return code
def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Triton pad kernel ``kernel_name`` specialized for ``rank`` dims.

    Each program handles ``BLOCK_SIZE`` consecutive flat output offsets. The
    kernel processes each offset as follows:

    * decomposes it into per-dimension destination indices by dividing by
      ``out_strides``. This assumes ``out0`` is contiguous, as it is when
      allocated by the functional wrapper with ``torch.empty``; TODO confirm
      for direct callers of the out-variant.
    * maps each index to a source index according to the padding mode;
    * loads from the input, or uses ``value`` outside the valid region in
      constant mode;
    * stores the result.

    Args:
        rank: number of dimensions the kernel is specialized for.
        kernel_name: name of the generated kernel.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer.
    """
    # make the inlined function visible in the context
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        if rank > 0:
            # per-dimension scalar arguments, in the order the wrapper passes them
            for arg_template, comment in (
                ("x_shape{}: int", "shape for x"),
                ("in_strides{}: int", "stride for x"),
                ("out_strides{}: int", "stride for out"),
                ("valid_dim{}_start: int", "valid dim start"),
                ("valid_dim{}_end: int", "valid dim end"),
            ):
                args = ", ".join(arg_template.format(j) for j in range(rank))
                code.writeline(f"{args}, # {comment}")

        for i in range(rank):
            code.writeline(f"dim{i}_has_pad: tl.constexpr, ")

        # Element counts are runtime ints: as tl.constexpr they would force a
        # fresh JIT compilation for every distinct tensor size.
        code.writeline("in_elem_cnt, ")
        code.writeline("out_elem_cnt, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")

    code.writeline("):")

    with code.indent():
        code.writeline("pid = tle.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.newline()

        # flat output offset -> per-dimension destination indices
        code.writeline("remaining = offset ")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()

        # if_pad is 1 where the output element lies outside the region that
        # is copied straight from the input in at least one dimension
        code.writeline("if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)")
        code.writeline("if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)")
        code.writeline(
            "cond = ((dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end))"
        )
        for i in range(1, rank):
            code.writeline(
                f"cond &= ((dst_index_{i} >= valid_dim{i}_start) "
                f"& (dst_index_{i} < valid_dim{i}_end))"
            )
        code.writeline(
            "if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)"
        )

        # default mapping: shift by the leading pad and clamp below at 0
        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start ")
        for i in range(rank):
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )

        # reflect: mirror about the first / last input element without
        # repeating it (start - dst before; 2 * (size - 1) - (dst - start) after)
        code.newline()
        code.writeline("if IS_REFLECT: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where("
                    f"dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"valid_dim{i}_start - dst_index_{i}, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where("
                    f"dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"(x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} "
                    f"- valid_dim{i}_start, src_index_{i})"
                )

        # replicate: clamp to the first / last input element
        code.newline()
        code.writeline("if IS_REPLICATE: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where("
                    f"dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"0, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where("
                    f"dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"x_shape{i} - 1, src_index_{i})"
                )

        # circular: wrap around by the input size
        code.newline()
        code.writeline("if IS_CIRCULAR: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where("
                    f"dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where("
                    f"dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"dst_index_{i} - valid_dim{i}_end, src_index_{i})"
                )

        code.newline()

        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")

        code.writeline("load_cond = src_index_0 < x_shape0")
        for i in range(1, rank):
            code.writeline(f"load_cond &= src_index_{i} < x_shape{i}")

        code.writeline("if IS_CONSTANT: ")
        with code.indent():
            # use explicit comparison and bitwise-and for non-scalar masks
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, "
                "mask=((if_pad == 0) & load_cond), other=value)"
            )
        code.writeline("else: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)")

    return code
def generate_code(
    inputs: Tuple[Any, ...],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate the full source of a rank-specialized pad module into ``code``.

    Args:
        inputs: the positional arguments of the pad call. Only the rank of
            ``inputs[0]`` (the input tensor) influences the generated code.
        wrapper_name: name of the functional wrapper to emit.
        destination_passing_func_name: name of the out-variant wrapper to emit.
        kernel_name: name of the Triton kernel to emit.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer, now holding the imports, both wrappers and
        the kernel.
    """
    # the only runtime-determined factor is the rank of the task space
    rank = len(inputs[0].shape)

    code = generate_imports(code)
    code = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_padding_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    code = generate_pad_kernel(rank, kernel_name, code)
    return code
class PadFunction:
    """Generate, cache and dispatch rank-specialized padding implementations.

    The first call for a given input rank generates Python source containing
    a functional wrapper (``_pad_wrapper``), a destination-passing wrapper
    (``_pad_wrapper_out``) and a Triton kernel (``_pad_jit_function``). The
    source is written to ``code_cache_dir()`` and imported from that file.
    The loaded ``_pad_wrapper`` is memoized in ``self.overloads``, so later
    calls with the same rank skip code generation.
    """

    def __init__(self):
        # Recorded at construction; not read anywhere in this class.
        self.pid = os.getpid()
        # str(rank) -> generated `_pad_wrapper` callable
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the generated implementation for the rank of ``args``.

        ``args`` and ``kwargs`` are forwarded unchanged to the generated
        ``_pad_wrapper(in0, pad, mode, value=0)``.
        """
        # note: kwargs should not be used in JITFunction directly
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # A bare `import importlib` does not load the `importlib.util`
            # submodule; import it explicitly instead of relying on some
            # other module having done so.
            import importlib.util

            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_pad_wrapper",
                "_pad_wrapper_out",
                "_pad_jit_function",
                code,
            )

            file_name = f"constant_pad_rank_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # load the module straight from the file
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )
            m = importlib.util.module_from_spec(spec)
            # deliberately not registered in sys.modules
            spec.loader.exec_module(m)
            overload = getattr(m, "_pad_wrapper")
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor arguments in ``args``.

        Raises ``ValueError`` (from ``max``) if ``args`` contains no tensor.
        """
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank


# Process-wide generator/cache shared by every call to `pad`.
_pad_func = PadFunction()
def pad(self, pad, mode="constant", value=None):
    """Pad the tensor ``self`` using the generated Triton implementation.

    Args:
        self: input tensor.
        pad: flat sequence of padding sizes. Pair ``i``
            (``pad[2 * i]``, ``pad[2 * i + 1]``) is the (before, after) padding
            of dimension ``ndim - 1 - i``, so the last dimension comes first.
        mode: one of ``"constant"``, ``"reflect"``, ``"replicate"``,
            ``"circular"``.
        value: fill value for ``"constant"`` mode; ``None`` means 0.

    Returns:
        A new tensor with every dimension enlarged by its before/after padding.

    Raises:
        AssertionError: if any of the following holds:
            ``pad`` has odd length or more pairs than ``self`` has dimensions;
            ``mode`` is unsupported;
            a ``"reflect"`` pad is not smaller than its dimension;
            a ``"circular"`` pad is larger than its dimension.
    """
    logger.debug("GEMS CONSTANT PAD ND")

    ndim = self.ndim

    if value is None:
        value = 0.0

    assert len(pad) % 2 == 0, (
        f"Padding length must be divisible by 2, but got {len(pad)}"
    )
    pad_pairs = len(pad) // 2
    # Reject extra pairs up front: `self.shape[ndim - 1 - i]` below (and the
    # generated wrapper) would otherwise wrap to the wrong dimension.
    assert pad_pairs <= ndim, (
        f"Padding length too large: {len(pad)} values for a {ndim}-d input"
    )
    # The generated kernel only has code paths for these four modes.
    assert mode in ("constant", "reflect", "replicate", "circular"), (
        f"Unsupported padding mode: {mode}"
    )

    if mode == "reflect":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert pad_l < input_size and pad_r < input_size, (
                "padding size should be less than the corresponding input dimension, "
                f"but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
            )

    if mode == "circular":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."

    out = _pad_func(self, pad, mode, float(value))
    return out
def constant_pad_nd(self, pad_list, value=0):
    """Pad ``self`` with the constant ``value``; thin alias of ``pad``."""
    return pad(self, pad_list, value=value, mode="constant")