Coverage for src/flag_gems/experimental_ops/replication_pad1d.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def replication_pad1d_kernel(
    in_ptr,
    out_ptr,
    B,
    C,
    W_in,
    W_out,
    pad_left,
    in_stride_n,
    in_stride_c,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_w,
    BLOCK_SIZE: tl.constexpr,
):
    """Write one BLOCK_SIZE chunk of one (n, c) output row with edge replication.

    The launch grid is 1D with ``B * C * cdiv(W_out, BLOCK_SIZE)`` programs.
    Each program id is split into a flattened row index ``pid_nc`` (n * C + c)
    and a chunk index ``pid_w`` along the output width. A 1D grid avoids the
    CUDA limit of 65535 on grid axes 1/2, which the previous layout hit once
    N * C exceeded it.

    Output element ``w`` of a row reads input element
    ``clamp(w - pad_left, 0, W_in - 1)`` of the same row.
    ``B`` is not read here; it is kept so the call signature stays stable.
    ``B``/``C`` are runtime ints (not constexpr) so new shapes do not force a
    recompilation.
    """
    pid = tl.program_id(axis=0)
    num_w_blocks = tl.cdiv(W_out, BLOCK_SIZE)
    pid_nc = pid // num_w_blocks
    pid_w = pid % num_w_blocks

    n = pid_nc // C
    c = pid_nc % C

    off_w = pid_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = off_w < W_out

    # Clamp the source column into [0, W_in - 1]; masked-off lanes are clamped
    # too, so every computed address stays inside the row.
    w_in = tl.minimum(tl.maximum(off_w - pad_left, 0), W_in - 1)

    # 64-bit row offsets so large tensors do not overflow int32 arithmetic.
    base_in = n.to(tl.int64) * in_stride_n + c.to(tl.int64) * in_stride_c
    base_out = n.to(tl.int64) * out_stride_n + c.to(tl.int64) * out_stride_c

    ptrs_in = in_ptr + base_in + w_in.to(tl.int64) * in_stride_w
    ptrs_out = out_ptr + base_out + off_w.to(tl.int64) * out_stride_w

    x = tl.load(ptrs_in, mask=mask, other=0)
    tl.store(ptrs_out, x, mask=mask)


def _launch_replication_pad1d_kernel(input: torch.Tensor, padding, out: torch.Tensor):
    """Validate arguments and launch ``replication_pad1d_kernel`` to fill ``out``.

    Args:
        input: 2D ``(C, W)`` or 3D ``(N, C, W)`` CUDA tensor; arbitrary strides.
        padding: ``(left, right)`` non-negative ints, or a tensor holding them.
        out: CUDA tensor of shape ``input.shape[:-1] + (W + left + right,)``;
            arbitrary strides.

    Returns:
        ``out``, filled in place.

    Raises:
        RuntimeError: if ``input`` or ``out`` is not a CUDA tensor.
        ValueError: on negative padding, unsupported input rank, wrong output
            shape, or a zero-width input that would need to be replicated.
    """
    if not input.is_cuda or not out.is_cuda:
        raise RuntimeError("Triton kernels require CUDA tensors")

    if isinstance(padding, torch.Tensor):
        padding = tuple(padding.tolist())
    left, right = int(padding[0]), int(padding[1])
    if left < 0 or right < 0:
        raise ValueError("Padding values must be non-negative for replication_pad1d")

    dim = input.dim()
    if dim not in (2, 3):
        raise ValueError("replication_pad1d expects 2D (C, W) or 3D (N, C, W) input")

    W_in = input.shape[-1]
    W_out = W_in + left + right

    # Check the full output shape before unpacking strides, so a wrongly-ranked
    # `out` yields this message rather than a tuple-unpacking error.
    expected_out_shape = (*input.shape[:-1], W_out)
    if tuple(out.shape) != expected_out_shape:
        raise ValueError(
            f"Output tensor has incorrect shape. Expected {expected_out_shape}, got {tuple(out.shape)}"
        )

    if dim == 3:
        N, C = input.shape[0], input.shape[1]
        in_s_n, in_s_c, in_s_w = input.stride()
        out_s_n, out_s_c, out_s_w = out.stride()
    else:
        # A 2D (C, W) tensor is a single batch; its batch stride is never used.
        N, C = 1, input.shape[0]
        in_s_c, in_s_w = input.stride()
        out_s_c, out_s_w = out.stride()
        in_s_n = out_s_n = 0

    num_rows = N * C
    if num_rows == 0 or W_out == 0:
        return out  # nothing to write
    if W_in == 0:
        # The clamp would produce index -1 and read out of bounds.
        raise ValueError("replication_pad1d cannot pad an input with zero width")

    BLOCK_SIZE = 256
    grid = (num_rows * triton.cdiv(W_out, BLOCK_SIZE),)
    replication_pad1d_kernel[grid](
        input,
        out,
        N,
        C,
        W_in,
        W_out,
        left,
        in_s_n,
        in_s_c,
        in_s_w,
        out_s_n,
        out_s_c,
        out_s_w,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out

109 

110 

def replication_pad1d(input: torch.Tensor, padding):
    """Replication-pad the last dimension of a 2D ``(C, W)`` or 3D ``(N, C, W)`` tensor.

    Allocates a fresh output tensor and fills it with
    ``_launch_replication_pad1d_kernel``.

    Args:
        input: rank-2 or rank-3 tensor; its last dimension is padded.
        padding: ``(left, right)`` pair of non-negative ints, or a tensor
            holding them.

    Returns:
        New tensor with the same leading dimensions, dtype, device and layout
        as ``input`` and last dimension ``W + left + right``.

    Raises:
        ValueError: if ``input`` is not 2D/3D or a padding value is negative.
    """
    if isinstance(padding, torch.Tensor):
        padding = tuple(padding.tolist())
    left, right = int(padding[0]), int(padding[1])

    if input.dim() not in (2, 3):
        raise ValueError("replication_pad1d expects 2D (C, W) or 3D (N, C, W) input")
    # Validate before allocating: otherwise a large negative pad surfaces as a
    # torch.empty "negative dimension" RuntimeError instead of this ValueError.
    if left < 0 or right < 0:
        raise ValueError("Padding values must be non-negative for replication_pad1d")

    out_shape = (*input.shape[:-1], input.shape[-1] + left + right)
    out = torch.empty(
        out_shape,
        device=input.device,
        dtype=input.dtype,
        layout=input.layout,
    )
    return _launch_replication_pad1d_kernel(input, (left, right), out)

134 

135 

def replication_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Replication-pad ``input`` along its last dimension into caller-provided ``out``.

    Args:
        input: tensor to pad.
        padding: ``(left, right)`` pair, or a tensor holding the two values.
        out: destination tensor; must share ``input``'s dtype and device.

    Returns:
        Whatever ``_launch_replication_pad1d_kernel`` returns for ``out``.

    Raises:
        ValueError: if ``out`` has a different dtype or device than ``input``.
    """
    pad = tuple(padding.tolist()) if isinstance(padding, torch.Tensor) else padding
    left_right = (int(pad[0]), int(pad[1]))

    # dtype/device are checked here; rank/shape checks are left to the launcher.
    mismatches = (
        (out.dtype != input.dtype, "Output dtype must match input dtype"),
        (out.device != input.device, "Output device must match input device"),
    )
    for is_mismatch, message in mismatches:
        if is_mismatch:
            raise ValueError(message)

    return _launch_replication_pad1d_kernel(input, left_right, out)