Coverage for src/flag_gems/ops/pad.py: 99%
274 statements
coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

import importlib
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic

logger = logging.getLogger(__name__)


# --------------------------- padding wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Generate the parameter declaration for the functional wrapper.
    Example: in0, pad, mode, value=0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("pad")
    parameters.append("mode")
    parameters.append("value=0")
    return ", ".join(parameters)


def parameter_for_wrapper_out() -> str:
    """Generate the parameter declaration for the destination-passing wrapper.
    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value=0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")
    parameters.append("dst_shape")
    parameters.append("pad_before")
    parameters.append("pad_after")
    parameters.append("mode")
    parameters.append("value=0")

    return ", ".join(parameters)


def parameter_ref_for_wrapper() -> str:
    """Generate the argument list used to call the destination-passing wrapper.
    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")
    parameters.append("dst_shape")
    parameters.append("pad_before")
    parameters.append("pad_after")
    parameters.append("mode")
    parameters.append("value")

    return ", ".join(parameters)


def output_ref_for_wrapper() -> str:
    return "out0"


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import math")
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("from triton import language as tl")
    code.newline()
    code.writeline("from flag_gems.utils.libentry import libentry")
    code.writeline("from flag_gems.runtime import torch_device_fn")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.writeline("from flag_gems.utils.type_utils import type_promotion")
    code.newline()
    code.newline()
    return code


def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("ndim = in0.ndim")
        code.writeline("pad_size = len(pad)")
        code.writeline("assert pad_size % 2 == 0")
        code.newline()

        code.writeline("pad_before = [0 for _ in range(ndim)]")
        code.writeline("pad_after = [0 for _ in range(ndim)]")
        code.newline()

        code.writeline("pad_pair = pad_size // 2 ")
        code.writeline("for i in range(pad_pair): ")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")

        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim): ")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")

        code.writeline(
            "out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)"
        )

        # call destination_passing_func
        output_names: str = output_ref_for_wrapper()
        call_str = (
            f"{output_names} = {destination_passing_func_name}"
            f"({parameter_ref_for_wrapper()})"
        )
        code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
        code.newline()
        code.newline()

    return code
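
# For reference, a sketch (not emitted verbatim) of the functional wrapper produced by the
# generator above, assuming wrapper_name="_pad_wrapper" and
# destination_passing_func_name="_pad_wrapper_out" (the names used by PadFunction below):
#
#     def _pad_wrapper(in0, pad, mode, value=0):
#         ndim = in0.ndim
#         pad_size = len(pad)
#         assert pad_size % 2 == 0
#
#         pad_before = [0 for _ in range(ndim)]
#         pad_after = [0 for _ in range(ndim)]
#
#         pad_pair = pad_size // 2
#         for i in range(pad_pair):
#             pad_before[ndim - i - 1] = pad[2 * i]
#             pad_after[ndim - i - 1] = pad[2 * i + 1]
#         dst_shape = list(in0.shape)
#         for i in range(ndim):
#             dst_shape[i] += pad_before[i] + pad_after[i]
#         out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)
#         out0 = _pad_wrapper_out(in0, out0, dst_shape, pad_before, pad_after, mode, value)
#         return out0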


def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # docstring
        code.writeline("BLOCK_SIZE = 256")
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")

        # input strides for each input tensor w.r.t. the task index space
        if rank > 0:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            for i in range(rank):
                code.writeline(f"valid_dim{i}_start = pad_before[{i}]")
                code.writeline(f"valid_dim{i}_end = dst_shape[{i}] - pad_after[{i}]")

            code.newline()

        code.writeline("IS_CONSTANT = mode == 'constant'")
        code.writeline("IS_REFLECT = mode == 'reflect'")
        code.writeline("IS_REPLICATE = mode == 'replicate'")
        code.writeline("IS_CIRCULAR = mode == 'circular'")

        code.newline()

        # grid
        code.writeline("# kernel launch")

        # launch kernel
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"x_shape[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # shape for x")

                    s = ", ".join(f"in_strides0[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for x")

                    s = ", ".join(f"out_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out")

                    s = ", ".join(f"valid_dim{j}_start" for j in range(rank))
                    code.writeline(f"{s}, # valid dim start")

                    s = ", ".join(f"valid_dim{j}_end" for j in range(rank))
                    code.writeline(f"{s}, # valid dim end")

                code.writeline("in0.numel(), ")
                code.writeline("out0.numel(), ")
                code.writeline("value, ")
                code.writeline("IS_CONSTANT, ")
                code.writeline("IS_REFLECT, ")
                code.writeline("IS_REPLICATE, ")
                code.writeline("IS_CIRCULAR, ")
                code.writeline("BLOCK_SIZE, ")
            code.writeline(")")

        code.writeline("return out0")

    code.newline()
    code.newline()
    return code
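
# A sketch (not emitted verbatim) of the rank-2 specialization this generator produces,
# assuming the same generated names as above:
#
#     def _pad_wrapper_out(in0, out0, dst_shape, pad_before, pad_after, mode, value=0):
#         BLOCK_SIZE = 256
#         grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)
#
#         x_shape = in0.shape
#         in_strides0 = in0.stride()
#         out_strides = out0.stride()
#         valid_dim0_start = pad_before[0]
#         valid_dim0_end = dst_shape[0] - pad_after[0]
#         valid_dim1_start = pad_before[1]
#         valid_dim1_end = dst_shape[1] - pad_after[1]
#
#         IS_CONSTANT = mode == 'constant'
#         IS_REFLECT = mode == 'reflect'
#         IS_REPLICATE = mode == 'replicate'
#         IS_CIRCULAR = mode == 'circular'
#
#         with torch_device_fn.device(in0.device):
#             _pad_jit_function[grid](
#                 in0, out0,
#                 x_shape[0], x_shape[1],              # shape for x
#                 in_strides0[0], in_strides0[1],      # stride for x
#                 out_strides[0], out_strides[1],      # stride for out
#                 valid_dim0_start, valid_dim1_start,  # valid dim start
#                 valid_dim0_end, valid_dim1_end,      # valid dim end
#                 in0.numel(), out0.numel(), value,
#                 IS_CONSTANT, IS_REFLECT, IS_REPLICATE, IS_CIRCULAR,
#                 BLOCK_SIZE,
#             )
#         return out0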


def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # make the inlined function visible in the context
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        if rank > 0:
            # shape of the input tensor
            shape_args = ", ".join(f"x_shape{j}: int" for j in range(rank))
            code.writeline(f"{shape_args}, # shape for x")

            # strides of the input tensor
            stride_args = ", ".join(f"in_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for x")

            # strides of the output tensor
            stride_args = ", ".join(f"out_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for out")

            # start of the valid (non-padded) region per dimension
            stride_args = ", ".join(f"valid_dim{j}_start: int" for j in range(rank))
            code.writeline(f"{stride_args}, # valid dim start")

            # end of the valid (non-padded) region per dimension
            stride_args = ", ".join(f"valid_dim{j}_end: int" for j in range(rank))
            code.writeline(f"{stride_args}, # valid dim end")

        code.writeline("in_elem_cnt: tl.constexpr, ")
        code.writeline("out_elem_cnt: tl.constexpr, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")

    code.writeline("):")

    with code.indent():
        code.writeline("pid = tle.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.newline()

        code.writeline("remaining = offset ")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()

        code.writeline("if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)")
        code.writeline("if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)")

        code.writeline(
            "cond = ((dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end))"
        )

        for i in range(1, rank):
            code.writeline(
                f"cond &= ((dst_index_{i} >= valid_dim{i}_start) & (dst_index_{i} < valid_dim{i}_end))"
            )

        code.writeline(
            "if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)"
        )

        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start ")

        for i in range(rank):
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )

        code.newline()
        code.writeline("if IS_REFLECT: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
                    valid_dim{i}_start - dst_index_{i}, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
                    (x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, src_index_{i})"""
                )

        code.newline()
        code.writeline("if IS_REPLICATE: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start, 0, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end, x_shape{i} - 1, src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_CIRCULAR: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
                    dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
                    dst_index_{i} - valid_dim{i}_end, src_index_{i})"""
                )

        code.newline()

        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")

        code.writeline("load_cond = src_index_0 < x_shape0")
        for i in range(1, rank):
            code.writeline(f"load_cond &= src_index_{i} < x_shape{i}")

        code.writeline("if IS_CONSTANT: ")
        with code.indent():
            # use explicit comparison and bitwise-and for non-scalar masks
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=((if_pad == 0) & load_cond), other=value)"
            )
        code.writeline("else: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)")

    return code
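
# A small worked example of the index arithmetic emitted above (illustration only).
# Take a 1-D input of length 4 with pad=(2, 1) and mode="reflect": the output has
# length 7, valid_dim0_start = 2 and valid_dim0_end = 6.
#   - output position 0 lies left of the valid region, so
#     src_index_0 = valid_dim0_start - dst_index_0 = 2 - 0 = 2
#   - output position 6 lies right of the valid region, so
#     src_index_0 = (x_shape0 + valid_dim0_start - 1) * 2 - dst_index_0 - valid_dim0_start
#                 = (4 + 2 - 1) * 2 - 6 - 2 = 2
#   - output positions 2..5 fall inside the valid region and copy src_index_0 = dst_index_0 - 2.
# This matches reflect padding of [a, b, c, d] -> [c, b, a, b, c, d, c].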


def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    shape = inputs[0].shape
    rank = len(shape)

    # the only runtime determined factor is the rank of the task space
    code = generate_imports(code)
    code = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_padding_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    code = generate_pad_kernel(rank, kernel_name, code)
    return code
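
# For debugging, the generated module source can be rendered without importing it.
# A minimal sketch (illustrative; `x` is any tensor whose rank picks the specialization):
#
#     buf = IndentedBuffer()
#     buf = generate_code(
#         (x,), "_pad_wrapper", "_pad_wrapper_out", "_pad_jit_function", buf
#     )
#     print(buf.getvalue())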


class PadFunction:
    """Utility to generate code for the padding operation. It generates a wrapper and a
    JITFunction which are specialized according to the rank of the task space (the shape
    of the padded output tensor).
    The generated code is written out to the cache directory (defaults to ~/.flaggems).
    """

    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # note: kwargs should not be used in JITFunction directly
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_pad_wrapper",
                "_pad_wrapper_out",
                "_pad_jit_function",
                code,
            )

            file_name = f"constant_pad_rank_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )
            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, "_pad_wrapper")
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank
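
# Overloads are cached by the maximum rank of the tensor arguments, so all inputs of the
# same rank reuse one generated module. An illustrative sketch (argument values assumed):
#
#     fn = PadFunction()
#     fn(torch.randn(2, 3, 4), (1, 1), "constant", 0.0)  # writes constant_pad_rank_3.py
#     fn(torch.randn(5, 6, 7), (2, 0), "constant", 0.0)  # reuses the cached rank-3 overload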


_pad_func = PadFunction()


def pad(self, pad, mode="constant", value=None):
    logger.debug("GEMS CONSTANT PAD ND")
    ndim = self.ndim

    if value is None:
        value = 0.0

    if mode == "reflect":
        ndim //= 2
        assert (
            len(pad) == 2 * ndim
        ), f"padding size is expected to be {2 * ndim}, but got {len(pad)}"
        for i in range(ndim):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_l, input_r = (
                self.shape[ndim - (2 * i + 1) - 1],
                self.shape[ndim - (2 * i + 1)],
            )
            assert pad_l < input_l and pad_r < input_r, (
                "padding size should be less than the corresponding input dimension, "
                f"but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
            )

    if mode == "circular":
        ndim //= 2
        assert (
            len(pad) == 2 * ndim
        ), f"padding size is expected to be {2 * ndim}, but got {len(pad)}"
        for i in range(ndim):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - i - 1]
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."

    out = _pad_func(self, pad, mode, float(value))
    return out


def constant_pad_nd(self, pad_list, value=0):
    return pad(self, pad_list, mode="constant", value=value)
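
# Example usage (illustrative; the device choice is an assumption): these entry points
# follow the semantics of torch.nn.functional.pad / torch.constant_pad_nd.
#
#     x = torch.randn(2, 3, 4, device="cuda")
#     y = pad(x, (1, 2), mode="constant", value=1.0)  # pads the last dim -> shape (2, 3, 7)
#     z = constant_pad_nd(x, (1, 2, 0, 1))            # pads the last two dims -> shape (2, 4, 7)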