Coverage for src/flag_gems/experimental_ops/reflection_pad1d.py: 0%
75 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import math
3import torch
4import triton
5import triton.language as tl
@triton.jit
def _reflection_pad1d_kernel(
    in_ptr, out_ptr, B, W_in, pad_left, W_out, BLOCK_W: tl.constexpr
):
    # 1-D reflection padding over B contiguous rows: row b of in_ptr has
    # width W_in, row b of out_ptr has width W_out = W_in + pad_left + pad_right.
    # Grid: axis 0 = row index, axis 1 = BLOCK_W-wide tile of output columns.
    pid_b = tl.program_id(axis=0)
    pid_w = tl.program_id(axis=1)

    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)
    # Mask out columns past W_out and (defensively) rows past B.
    mask = (offs_w < W_out) & (pid_b < B)

    # Rows are stored contiguously, so row b starts at b*W_in / b*W_out.
    base_in = pid_b * W_in
    base_out = pid_b * W_out

    # Compute reflected indices
    x = offs_w.to(tl.int32) - pad_left  # shift by left pad
    Wm1 = W_in - 1
    p = 2 * Wm1  # period for reflection; guaranteed > 0 when this kernel is used

    # Reflection without edge repetition: |x| folded modulo the period
    # 2*(W_in-1) lands in [0, p); values >= W_in are mirrored back as p - m,
    # which maps pad positions onto interior columns (PyTorch 'reflect' mode).
    t = tl.abs(x)
    m = t % p
    iw = tl.where(m < W_in, m, p - m)

    vals = tl.load(in_ptr + base_in + iw, mask=mask, other=0)
    tl.store(out_ptr + base_out + offs_w, vals, mask=mask)
@triton.jit
def _copy_rows_kernel(in_ptr, out_ptr, B, W, BLOCK_W: tl.constexpr):
    # Straight copy of B contiguous rows of width W (used for zero padding).
    # Grid: axis 0 = row index, axis 1 = BLOCK_W-wide tile of columns.
    row = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    cols = tile * BLOCK_W + tl.arange(0, BLOCK_W)
    valid = (cols < W) & (row < B)

    start = row * W
    data = tl.load(in_ptr + start + cols, mask=valid, other=0)
    tl.store(out_ptr + start + cols, data, mask=valid)
47def _launch_reflection_pad1d(input: torch.Tensor, padding, out: torch.Tensor = None):
48 if not isinstance(padding, (list, tuple)) or len(padding) != 2:
49 raise ValueError(
50 "padding must be a sequence of length 2: (pad_left, pad_right)"
51 )
52 pad_left, pad_right = int(padding[0]), int(padding[1])
53 if pad_left < 0 or pad_right < 0:
54 raise ValueError("padding values must be >= 0")
55 if input.dim() < 1:
56 raise ValueError("input must have at least 1 dimension")
57 if not input.is_cuda:
58 raise ValueError("input must be a CUDA tensor")
60 x = input.contiguous()
61 W_in = int(x.shape[-1])
62 if W_in <= 0:
63 raise ValueError("last dimension (width) must be > 0")
65 W_out = W_in + pad_left + pad_right
66 leading_shape = x.shape[:-1]
67 B = int(math.prod(leading_shape)) if len(leading_shape) > 0 else 1
69 if out is None:
70 out = torch.empty((*leading_shape, W_out), device=x.device, dtype=x.dtype)
71 else:
72 if not out.is_cuda:
73 raise ValueError("out must be a CUDA tensor")
74 expected_shape = (*leading_shape, W_out)
75 if tuple(out.shape) != expected_shape:
76 raise ValueError(
77 f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}"
78 )
79 if out.dtype != x.dtype:
80 raise ValueError(
81 f"out dtype {out.dtype} does not match input dtype {x.dtype}"
82 )
83 if out.device != x.device:
84 raise ValueError("out must be on the same device as input")
85 out = out.contiguous()
87 # No padding: just copy
88 if pad_left == 0 and pad_right == 0:
89 if W_out != W_in:
90 raise RuntimeError(
91 "Internal error: W_out should equal W_in when no padding"
92 )
93 grid = (B, triton.cdiv(W_in, 256))
94 _copy_rows_kernel[grid](x, out, B, W_in, BLOCK_W=256)
95 return out
97 # Validate reflection padding constraints
98 if W_in < 2:
99 raise ValueError(
100 "input width must be at least 2 for reflection padding when padding > 0"
101 )
102 if pad_left >= W_in or pad_right >= W_in:
103 raise ValueError(
104 "padding values must be less than the input width for reflection padding"
105 )
107 grid = (B, triton.cdiv(W_out, 256))
108 _reflection_pad1d_kernel[grid](x, out, B, W_in, pad_left, W_out, BLOCK_W=256)
109 return out
def reflection_pad1d(input: torch.Tensor, padding):
    """Reflection-pad the last dimension of *input*, allocating the result."""
    return _launch_reflection_pad1d(input, padding)
def reflection_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Reflection-pad the last dimension of *input* into the preallocated *out*."""
    result = _launch_reflection_pad1d(input, padding, out=out)
    return result