Coverage for src/flag_gems/experimental_ops/replication_pad2d.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

# Triton kernel that writes the replication-padded copy of a contiguous
# NCHW input into a contiguous NCHW output. Each program instance handles
# BLOCK_SIZE consecutive flat output elements; offsets past TOTAL_ELEMS
# are masked off in both the load and the store.
@triton.jit
def replication_pad2d_kernel(
    in_ptr,  # *Pointer* to input tensor
    out_ptr,  # *Pointer* to output tensor
    N,
    C,
    H,
    W,  # input dimensions
    OH,
    OW,  # output H and W
    PAD_LEFT,
    PAD_TOP,  # padding sizes
    TOTAL_ELEMS,  # total number of output elements
    BLOCK_SIZE: tl.constexpr,  # flat elements processed per program instance
):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < TOTAL_ELEMS

    # Cast to int64 for safe indexing
    offs64 = offs.to(tl.int64)

    # Promote the scalar dimension/padding arguments to 1-element int64
    # tensors so all subsequent index arithmetic stays in 64-bit and
    # cannot overflow for large tensors.
    OW_i64 = tl.full([1], OW, dtype=tl.int64)
    OH_i64 = tl.full([1], OH, dtype=tl.int64)
    C_i64 = tl.full([1], C, dtype=tl.int64)
    W_i64 = tl.full([1], W, dtype=tl.int64)
    H_i64 = tl.full([1], H, dtype=tl.int64)
    PAD_LEFT_i64 = tl.full([1], PAD_LEFT, dtype=tl.int64)
    PAD_TOP_i64 = tl.full([1], PAD_TOP, dtype=tl.int64)

    # Decompose the flat output offset into (n, c, oh, ow) coordinates
    # of the contiguous NCHW output.
    ow = offs64 % OW_i64
    tmp = offs64 // OW_i64
    oh = tmp % OH_i64
    tmp = tmp // OH_i64
    c = tmp % C_i64
    n = tmp // C_i64

    # Map output coordinates back into input space; inside the padded
    # border these become negative or >= H/W.
    ih = oh - PAD_TOP_i64
    iw = ow - PAD_LEFT_i64

    zero = tl.full([1], 0, dtype=tl.int64)
    Hm1 = H_i64 - 1
    Wm1 = W_i64 - 1

    # Replication padding: clamp to the nearest valid input row/column,
    # so border pixels are repeated into the pad region.
    ih = tl.maximum(zero, tl.minimum(Hm1, ih))
    iw = tl.maximum(zero, tl.minimum(Wm1, iw))

    # Flat index into the contiguous NCHW input.
    in_index = ((n * C_i64 + c) * H_i64 + ih) * W_i64 + iw
    out_index = offs64

    x = tl.load(in_ptr + in_index, mask=mask)
    tl.store(out_ptr + out_index, x, mask=mask)

58 

59 

60def _prepare_dims_and_out(input: torch.Tensor, padding, out: torch.Tensor | None): 

61 if not isinstance(padding, (tuple, list)) or len(padding) != 4: 

62 raise ValueError( 

63 "padding must be a sequence of 4 integers: (pad_left, pad_right, pad_top, pad_bottom)" 

64 ) 

65 pad_left, pad_right, pad_top, pad_bottom = map(int, padding) 

66 if pad_left < 0 or pad_right < 0 or pad_top < 0 or pad_bottom < 0: 

67 raise ValueError("replication_pad2d does not support negative padding") 

68 

69 if input.dim() == 4: 

70 N, C, H, W = input.shape 

71 out_shape = (N, C, H + pad_top + pad_bottom, W + pad_left + pad_right) 

72 kernel_N, kernel_C = N, C 

73 elif input.dim() == 3: 

74 C, H, W = input.shape 

75 out_shape = (C, H + pad_top + pad_bottom, W + pad_left + pad_right) 

76 kernel_N, kernel_C = 1, C 

77 else: 

78 raise ValueError( 

79 "replication_pad2d expects a 3D (C, H, W) or 4D (N, C, H, W) input" 

80 ) 

81 

82 if H <= 0 or W <= 0: 

83 raise ValueError( 

84 "Input height and width must be greater than 0 for replication padding" 

85 ) 

86 

87 if out is None: 

88 out = torch.empty(out_shape, device=input.device, dtype=input.dtype) 

89 else: 

90 if tuple(out.shape) != tuple(out_shape): 

91 raise ValueError( 

92 f"Provided out tensor has shape {tuple(out.shape)}, expected {out_shape}" 

93 ) 

94 if out.device != input.device: 

95 raise ValueError("Input and out must be on the same device") 

96 if out.dtype != input.dtype: 

97 raise ValueError("Input and out must have the same dtype") 

98 

99 return ( 

100 kernel_N, 

101 kernel_C, 

102 H, 

103 W, 

104 out.shape[-2], 

105 out.shape[-1], 

106 pad_left, 

107 pad_top, 

108 ), out 

109 

110 

def _launch_replication_pad2d_kernel(
    input: torch.Tensor, out: torch.Tensor, kernel_params
):
    """Launch ``replication_pad2d_kernel`` to fill every element of ``out``.

    Both tensors must be contiguous CUDA tensors; ``kernel_params`` is the
    ``(N, C, H, W, OH, OW, pad_left, pad_top)`` tuple produced by
    ``_prepare_dims_and_out``. Returns ``out``.
    """
    if not (input.is_cuda and out.is_cuda):
        raise ValueError("Tensors must be CUDA tensors")
    if not (input.is_contiguous() and out.is_contiguous()):
        raise ValueError("Only contiguous tensors are supported")

    n, c, h, w, oh, ow, pad_left, pad_top = kernel_params
    numel = out.numel()
    if numel == 0:
        # Empty output: nothing to compute, skip the kernel launch.
        return out

    block = 1024
    # One program instance per `block` flat output elements (ceil division).
    num_programs = (numel + block - 1) // block

    replication_pad2d_kernel[(num_programs,)](
        input,
        out,
        n,
        c,
        h,
        w,
        oh,
        ow,
        pad_left,
        pad_top,
        numel,
        BLOCK_SIZE=block,
    )
    return out

142 

143 

def replication_pad2d(input: torch.Tensor, padding):
    """Replication-pad a 3D (C, H, W) or 4D (N, C, H, W) CUDA tensor.

    ``padding`` is ``(pad_left, pad_right, pad_top, pad_bottom)``; a new
    tensor of the padded shape is allocated and returned.
    """
    params, result = _prepare_dims_and_out(input, padding, out=None)
    return _launch_replication_pad2d_kernel(input, result, params)

147 

148 

def replication_pad2d_out(input: torch.Tensor, padding, out: torch.Tensor):
    """Replication-pad ``input`` into the caller-provided ``out`` tensor.

    ``out`` must already have the padded shape and match ``input``'s device
    and dtype; it is filled in place and returned.
    """
    params, target = _prepare_dims_and_out(input, padding, out=out)
    return _launch_replication_pad2d_kernel(input, target, params)