Coverage for src/flag_gems/ops/replication_pad1d.py: 66%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def replication_pad1d_kernel(
    in_ptr,  # pointer to input tensor data
    out_ptr,  # pointer to output tensor data
    B: tl.constexpr,  # batch size (1 for 2D input); not read in the body — grid axis 1 spans B*C
    C: tl.constexpr,  # number of channels
    W_in,  # input width
    W_out,  # output width = W_in + pad_left + pad_right
    pad_left,  # number of replicated columns on the left edge
    in_stride_n,  # input strides in elements (batch stride is 0 for 2D input)
    in_stride_c,
    in_stride_w,
    out_stride_n,  # output strides in elements
    out_stride_c,
    out_stride_w,
    BLOCK_SIZE: tl.constexpr,  # output columns handled per program along W
):
    # Each program copies BLOCK_SIZE contiguous output columns of one (n, c) row.
    pid_w = tl.program_id(axis=0)
    pid_nc = tl.program_id(axis=1)

    # Decode the flattened (batch, channel) index from grid axis 1.
    n = pid_nc // C
    c = pid_nc % C

    off_w = pid_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = off_w < W_out

    # Replication pad: shift output columns back by pad_left, then clamp to
    # [0, W_in - 1] so out-of-range columns read the nearest edge value.
    w_in = off_w - pad_left
    w_in = tl.maximum(w_in, 0)
    w_in = tl.minimum(w_in, W_in - 1)

    # Base offsets; 64-bit arithmetic avoids overflow on large tensors.
    base_in = n.to(tl.int64) * in_stride_n + c.to(tl.int64) * in_stride_c
    base_out = n.to(tl.int64) * out_stride_n + c.to(tl.int64) * out_stride_c

    ptrs_in = in_ptr + base_in + w_in.to(tl.int64) * in_stride_w
    ptrs_out = out_ptr + base_out + off_w.to(tl.int64) * out_stride_w

    x = tl.load(ptrs_in, mask=mask, other=0)
    tl.store(ptrs_out, x, mask=mask)

53 

54 

55def _launch_replication_pad1d_kernel(input: torch.Tensor, padding, out: torch.Tensor): 

56 if isinstance(padding, torch.Tensor): 

57 padding = tuple(padding.tolist()) 

58 left, right = int(padding[0]), int(padding[1]) 

59 if left < 0 or right < 0: 

60 raise ValueError("Padding values must be non-negative for replication_pad1d") 

61 

62 dim = input.dim() 

63 if dim not in (2, 3): 

64 raise ValueError("replication_pad1d expects 2D (C, W) or 3D (N, C, W) input") 

65 

66 if dim == 3: 

67 N, C, W_in = input.shape 

68 B = N 

69 in_s_n, in_s_c, in_s_w = input.stride() 

70 out_s_n, out_s_c, out_s_w = out.stride() 

71 expected_out_shape = (N, C, W_in + left + right) 

72 else: 

73 C, W_in = input.shape 

74 B = 1 

75 in_s_c, in_s_w = input.stride() 

76 in_s_n = 0 

77 if out.dim() == 2: 

78 out_s_c, out_s_w = out.stride() 

79 out_s_n = 0 

80 elif out.dim() == 3: 

81 out_s_n, out_s_c, out_s_w = out.stride() 

82 else: 

83 raise ValueError("Output tensor has invalid dimensions") 

84 expected_out_shape = (C, W_in + left + right) 

85 

86 W_out = W_in + left + right 

87 

88 # Validate output shape 

89 if tuple(out.shape) != expected_out_shape: 

90 raise ValueError( 

91 f"Output tensor has incorrect shape. Expected {expected_out_shape}, got {tuple(out.shape)}" 

92 ) 

93 

94 grid = (triton.cdiv(W_out, 256), B * C) 

95 with torch_device_fn.device(input.device): 

96 replication_pad1d_kernel[grid]( 

97 input, 

98 out, 

99 B, 

100 C, 

101 W_in, 

102 W_out, 

103 left, 

104 in_s_n if dim == 3 else in_s_n, 

105 in_s_c, 

106 in_s_w, 

107 out_s_n if (dim == 3 or out.dim() == 3) else 0, 

108 out_s_c, 

109 out_s_w, 

110 BLOCK_SIZE=256, 

111 ) 

112 return out 

113 

114 

def replication_pad1d(input: torch.Tensor, padding):
    """Pad the last dimension of a 2D (C, W) or 3D (N, C, W) tensor by
    replicating its edge values.

    Args:
        input: tensor to pad.
        padding: two-element sequence or tensor (left, right).

    Returns:
        A freshly allocated tensor of shape ``(*input.shape[:-1],
        W + left + right)`` on the same device/dtype/layout as ``input``.

    Raises:
        ValueError: if ``input`` is not 2D or 3D.
    """
    logger.debug("GEMS REPLICATION_PAD1D")
    if isinstance(padding, torch.Tensor):
        padding = tuple(padding.tolist())
    left, right = int(padding[0]), int(padding[1])

    if input.dim() not in (2, 3):
        raise ValueError("replication_pad1d expects 2D (C, W) or 3D (N, C, W) input")

    # Only the width (last) dimension grows; all leading dims are preserved.
    padded_shape = list(input.shape)
    padded_shape[-1] += left + right
    out = torch.empty(
        tuple(padded_shape),
        device=input.device,
        dtype=input.dtype,
        layout=input.layout,
    )
    return _launch_replication_pad1d_kernel(input, (left, right), out)

139 

140 

def replication_pad1d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Replication-pad ``input`` along its last dimension into a
    caller-provided ``out`` tensor.

    Args:
        input: 2D (C, W) or 3D (N, C, W) tensor to pad.
        padding: two-element sequence or tensor (left, right).
        out: destination tensor; must match ``input``'s dtype and device.

    Returns:
        The ``out`` tensor.

    Raises:
        ValueError: if ``out`` disagrees with ``input`` on dtype or device
            (shape/rank validation happens in the launch helper).
    """
    logger.debug("GEMS REPLICATION_PAD1D_OUT")

    pad_seq = tuple(padding.tolist()) if isinstance(padding, torch.Tensor) else padding
    left = int(pad_seq[0])
    right = int(pad_seq[1])

    # The destination must agree with the source before any launch work.
    if input.dtype != out.dtype:
        raise ValueError("Output dtype must match input dtype")
    if input.device != out.device:
        raise ValueError("Output device must match input device")

    return _launch_replication_pad1d_kernel(input, (left, right), out)