Coverage for src/flag_gems/experimental_ops/reflection_pad1d.py: 0%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import math 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7 

@triton.jit
def _reflection_pad1d_kernel(
    in_ptr, out_ptr, B, W_in, pad_left, W_out, BLOCK_W: tl.constexpr
):
    """Fill one BLOCK_W-wide tile of one padded row using reflected indices.

    Grid is (rows, width tiles): axis 0 picks the flattened batch row, axis 1
    picks which tile of the output width this program handles.
    """
    row = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    cols = tile * BLOCK_W + tl.arange(0, BLOCK_W)
    active = (row < B) & (cols < W_out)

    # Map each output column back into input coordinates; values may be
    # negative (left pad) or >= W_in (right pad) before folding.
    src = cols.to(tl.int32) - pad_left

    # Reflection (edge not repeated) is periodic with period 2*(W_in - 1).
    # The launcher guarantees W_in >= 2 here, so the period is positive.
    period = 2 * (W_in - 1)
    folded = tl.abs(src) % period
    src_idx = tl.where(folded < W_in, folded, period - folded)

    picked = tl.load(in_ptr + row * W_in + src_idx, mask=active, other=0)
    tl.store(out_ptr + row * W_out + cols, picked, mask=active)

32 

33 

@triton.jit
def _copy_rows_kernel(in_ptr, out_ptr, B, W, BLOCK_W: tl.constexpr):
    """Copy rows of width W verbatim (fast path when both pads are zero)."""
    row = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    cols = tile * BLOCK_W + tl.arange(0, BLOCK_W)
    active = (row < B) & (cols < W)

    # Input and output share the same flat layout when no padding is applied.
    offset = row * W + cols
    vals = tl.load(in_ptr + offset, mask=active, other=0)
    tl.store(out_ptr + offset, vals, mask=active)

45 

46 

47def _launch_reflection_pad1d(input: torch.Tensor, padding, out: torch.Tensor = None): 

48 if not isinstance(padding, (list, tuple)) or len(padding) != 2: 

49 raise ValueError( 

50 "padding must be a sequence of length 2: (pad_left, pad_right)" 

51 ) 

52 pad_left, pad_right = int(padding[0]), int(padding[1]) 

53 if pad_left < 0 or pad_right < 0: 

54 raise ValueError("padding values must be >= 0") 

55 if input.dim() < 1: 

56 raise ValueError("input must have at least 1 dimension") 

57 if not input.is_cuda: 

58 raise ValueError("input must be a CUDA tensor") 

59 

60 x = input.contiguous() 

61 W_in = int(x.shape[-1]) 

62 if W_in <= 0: 

63 raise ValueError("last dimension (width) must be > 0") 

64 

65 W_out = W_in + pad_left + pad_right 

66 leading_shape = x.shape[:-1] 

67 B = int(math.prod(leading_shape)) if len(leading_shape) > 0 else 1 

68 

69 if out is None: 

70 out = torch.empty((*leading_shape, W_out), device=x.device, dtype=x.dtype) 

71 else: 

72 if not out.is_cuda: 

73 raise ValueError("out must be a CUDA tensor") 

74 expected_shape = (*leading_shape, W_out) 

75 if tuple(out.shape) != expected_shape: 

76 raise ValueError( 

77 f"out tensor has shape {tuple(out.shape)}, expected {expected_shape}" 

78 ) 

79 if out.dtype != x.dtype: 

80 raise ValueError( 

81 f"out dtype {out.dtype} does not match input dtype {x.dtype}" 

82 ) 

83 if out.device != x.device: 

84 raise ValueError("out must be on the same device as input") 

85 out = out.contiguous() 

86 

87 # No padding: just copy 

88 if pad_left == 0 and pad_right == 0: 

89 if W_out != W_in: 

90 raise RuntimeError( 

91 "Internal error: W_out should equal W_in when no padding" 

92 ) 

93 grid = (B, triton.cdiv(W_in, 256)) 

94 _copy_rows_kernel[grid](x, out, B, W_in, BLOCK_W=256) 

95 return out 

96 

97 # Validate reflection padding constraints 

98 if W_in < 2: 

99 raise ValueError( 

100 "input width must be at least 2 for reflection padding when padding > 0" 

101 ) 

102 if pad_left >= W_in or pad_right >= W_in: 

103 raise ValueError( 

104 "padding values must be less than the input width for reflection padding" 

105 ) 

106 

107 grid = (B, triton.cdiv(W_out, 256)) 

108 _reflection_pad1d_kernel[grid](x, out, B, W_in, pad_left, W_out, BLOCK_W=256) 

109 return out 

110 

111 

def reflection_pad1d(input: torch.Tensor, padding):
    """Return a new tensor with ``input`` reflection-padded along its last dim."""
    return _launch_reflection_pad1d(input, padding)

114 

115 

def reflection_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Reflection-pad ``input`` along its last dim, writing the result into ``out``."""
    return _launch_reflection_pad1d(input, padding, out=out)