Coverage for src/flag_gems/ops/reflection_pad1d.py: 54% (78 statements)
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
11logger = logging.getLogger(__name__)
@triton.jit
def reflection_pad1d_kernel(
    in_ptr, out_ptr, B, W_in, pad_left, W_out, BLOCK_W: tl.constexpr
):
    """Write one BLOCK_W-wide slice of one padded row of the output.

    Grid layout: axis 0 indexes the flattened leading dims (rows, ``B`` total),
    axis 1 tiles the output width in chunks of ``BLOCK_W``. Both buffers are
    assumed row-contiguous: row ``b`` starts at ``b * W_in`` (input) and
    ``b * W_out`` (output) — the host launcher calls ``.contiguous()`` first.
    """
    pid_b = tl.program_id(axis=0)
    pid_w = tl.program_id(axis=1)
    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    # Guard both the width tail and any over-provisioned rows.
    mask = (offs_w < W_out) & (pid_b < B)
    base_in = pid_b * W_in
    base_out = pid_b * W_out
    # Compute reflected indices: map output column ow to input column by
    # reflecting x = ow - pad_left into [0, W_in - 1]. Reflection (without
    # repeating the edge) is periodic with period p = 2 * (W_in - 1):
    # fold x to m = |x| % p, then mirror the second half-period as p - m.
    x = offs_w.to(tl.int32) - pad_left  # shift by left pad
    Wm1 = W_in - 1
    p = 2 * Wm1  # period for reflection; guaranteed > 0 when this kernel is used
    # (the launcher rejects W_in < 2 whenever any padding is requested)
    t = tl.abs(x)
    m = t % p
    iw = tl.where(m < W_in, m, p - m)
    vals = tl.load(in_ptr + base_in + iw, mask=mask, other=0)
    tl.store(out_ptr + base_out + offs_w, vals, mask=mask)
@triton.jit
def _copy_rows_kernel(in_ptr, out_ptr, B, W, BLOCK_W: tl.constexpr):
    """Element-wise copy of a row-contiguous (B, W) buffer.

    Grid axis 0 picks the row, axis 1 picks a BLOCK_W-wide column tile;
    out-of-range lanes are masked off on both the load and the store.
    """
    row = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)
    cols = tile * BLOCK_W + tl.arange(0, BLOCK_W)
    in_bounds = (row < B) & (cols < W)
    row_start = row * W
    data = tl.load(in_ptr + row_start + cols, mask=in_bounds, other=0)
    tl.store(out_ptr + row_start + cols, data, mask=in_bounds)
def _launch_reflection_pad1d(input: torch.Tensor, padding, out: torch.Tensor = None):
    """Validate arguments and launch the 1-D reflection-pad Triton kernels.

    Args:
        input: Tensor with at least 1 dimension; the last dimension is padded.
        padding: Sequence of two non-negative ints ``(pad_left, pad_right)``;
            each must be strictly less than the input width.
        out: Optional pre-allocated result tensor of shape
            ``(*input.shape[:-1], W_in + pad_left + pad_right)`` with matching
            dtype/device. When given, it is filled and returned.

    Returns:
        The padded tensor (``out`` when provided, else a new tensor).

    Raises:
        ValueError: on malformed ``padding``, bad shapes/dtypes/devices, or
            padding values that reflection cannot satisfy.
    """
    BLOCK_W = 256  # column tile width shared by both kernels

    if not isinstance(padding, (list, tuple)) or len(padding) != 2:
        raise ValueError(
            "padding must be a sequence of length 2: (pad_left, pad_right)"
        )
    pad_left, pad_right = int(padding[0]), int(padding[1])
    if pad_left < 0 or pad_right < 0:
        raise ValueError("padding values must be >= 0")
    if input.dim() < 1:
        raise ValueError("input must have at least 1 dimension")

    x = input.contiguous()
    W_in = int(x.shape[-1])
    if W_in <= 0:
        raise ValueError("last dimension (width) must be > 0")

    W_out = W_in + pad_left + pad_right
    leading_shape = x.shape[:-1]
    # Flatten all leading dims into a single row count for the 2-D grid.
    B = int(math.prod(leading_shape)) if len(leading_shape) > 0 else 1

    user_out = out  # remember the caller's tensor so we can honor it
    if out is None:
        out = torch.empty((*leading_shape, W_out), device=x.device, dtype=x.dtype)
    else:
        expected_shape = (*leading_shape, W_out)
        if tuple(out.shape) != expected_shape:
            raise ValueError(
                f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}"
            )
        if out.dtype != x.dtype:
            raise ValueError(
                f"out dtype {out.dtype} does not match input dtype {x.dtype}"
            )
        if out.device != x.device:
            raise ValueError("out must be on the same device as input")
        # BUG FIX: ``.contiguous()`` on a non-contiguous tensor returns a
        # *copy*. Previously the kernel wrote into that copy and the caller's
        # ``out`` was silently left untouched. We still run the kernel on a
        # contiguous buffer, but copy the result back into the caller's
        # tensor at the end when the buffer is a distinct object.
        out = out.contiguous()

    if pad_left == 0 and pad_right == 0:
        # No padding: just copy
        if W_out != W_in:
            raise RuntimeError(
                "Internal error: W_out should equal W_in when no padding"
            )
        grid = (B, triton.cdiv(W_in, BLOCK_W))
        with torch_device_fn.device(x.device):
            _copy_rows_kernel[grid](x, out, B, W_in, BLOCK_W=BLOCK_W)
    else:
        # Validate reflection padding constraints: reflection needs at least
        # two columns, and each pad must be smaller than the input width.
        if W_in < 2:
            raise ValueError(
                "input width must be at least 2 for reflection padding when padding > 0"
            )
        if pad_left >= W_in or pad_right >= W_in:
            raise ValueError(
                "padding values must be less than the input width for reflection padding"
            )
        grid = (B, triton.cdiv(W_out, BLOCK_W))
        with torch_device_fn.device(x.device):
            reflection_pad1d_kernel[grid](
                x, out, B, W_in, pad_left, W_out, BLOCK_W=BLOCK_W
            )

    if user_out is not None and out is not user_out:
        # The caller's tensor was non-contiguous; propagate the result back.
        user_out.copy_(out)
        return user_out
    return out
def reflection_pad1d(input: torch.Tensor, padding):
    """Reflection-pad the last dimension of ``input`` by ``(left, right)``.

    Allocates and returns a fresh output tensor; see the launcher for the
    padding constraints that are enforced.
    """
    logger.debug("GEMS REFLECTION_PAD1D")
    padded = _launch_reflection_pad1d(input, padding)
    return padded
def reflection_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Out-variant of :func:`reflection_pad1d`: fill ``out`` and return it.

    ``out`` must already have the padded shape and matching dtype/device;
    the launcher validates this and raises ``ValueError`` otherwise.
    """
    logger.debug("GEMS REFLECTION_PAD1D_OUT")
    result = _launch_reflection_pad1d(input, padding, out=out)
    return result