Coverage for src/flag_gems/experimental_ops/_adaptive_avg_pool3d.py: 0%

114 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def adaptive_avg_pool3d_kernel(
    in_ptr,
    out_ptr,
    N,
    C,
    D_in,
    H_in,
    W_in,
    D_out,
    H_out,
    W_out,
    stride_in_n,
    stride_in_c,
    stride_in_d,
    stride_in_h,
    stride_in_w,
    stride_out_n,
    stride_out_c,
    stride_out_d,
    stride_out_h,
    stride_out_w,
):
    """Average one pooling window and write a single output element.

    Each program instance handles exactly one output element: the 1-D
    program id is decoded as a flattened (n, c, d_o, h_o, w_o) index with
    w_o varying fastest. Along every spatial axis the window is
    [floor(o * in / out), ceil((o + 1) * in / out)). All index arithmetic
    is carried out in int64.

    The running sum is always kept in float32, whatever the element type
    behind ``in_ptr``. ``N`` is accepted but never read in the body.
    """
    # int64 copies of the extents used for index decoding and window bounds.
    out_w = tl.full((), W_out, tl.int64)
    out_h = tl.full((), H_out, tl.int64)
    out_d = tl.full((), D_out, tl.int64)
    channels = tl.full((), C, tl.int64)
    in_d = tl.full((), D_in, tl.int64)
    in_h = tl.full((), H_in, tl.int64)
    in_w = tl.full((), W_in, tl.int64)

    # int64 copies of the element strides.
    in_sn = tl.full((), stride_in_n, tl.int64)
    in_sc = tl.full((), stride_in_c, tl.int64)
    in_sd = tl.full((), stride_in_d, tl.int64)
    in_sh = tl.full((), stride_in_h, tl.int64)
    in_sw = tl.full((), stride_in_w, tl.int64)

    out_sn = tl.full((), stride_out_n, tl.int64)
    out_sc = tl.full((), stride_out_c, tl.int64)
    out_sd = tl.full((), stride_out_d, tl.int64)
    out_sh = tl.full((), stride_out_h, tl.int64)
    out_sw = tl.full((), stride_out_w, tl.int64)

    # Peel the flat program id apart, innermost (w) coordinate first.
    remaining = tl.program_id(axis=0).to(tl.int64)
    w_o = remaining % out_w
    remaining = remaining // out_w
    h_o = remaining % out_h
    remaining = remaining // out_h
    d_o = remaining % out_d
    remaining = remaining // out_d
    c = remaining % channels
    n = remaining // channels

    # Half-open window [lo, hi) per axis: floor for the start, ceil for the end.
    d_lo = (d_o * in_d) // out_d
    d_hi = ((d_o + 1) * in_d + out_d - 1) // out_d
    h_lo = (h_o * in_h) // out_h
    h_hi = ((h_o + 1) * in_h + out_h - 1) // out_h
    w_lo = (w_o * in_w) // out_w
    w_hi = ((w_o + 1) * in_w + out_w - 1) // out_w

    nc_offset = n * in_sn + c * in_sc

    total = tl.zeros((), dtype=tl.float32)

    d = d_lo
    while d < d_hi:
        plane_offset = nc_offset + d * in_sd
        h = h_lo
        while h < h_hi:
            row_offset = plane_offset + h * in_sh
            w = w_lo
            while w < w_hi:
                elem = tl.load(in_ptr + (row_offset + w * in_sw))
                total = total + elem.to(tl.float32)
                w = w + 1
            h = h + 1
        d = d + 1

    # Number of input elements inside the window.
    window = (d_hi - d_lo) * (h_hi - h_lo) * (w_hi - w_lo)
    mean = total / window.to(tl.float32)

    dst_offset = (
        n * out_sn
        + c * out_sc
        + d_o * out_sd
        + h_o * out_sh
        + w_o * out_sw
    )
    tl.store(out_ptr + dst_offset, mean)

110 

111 

def _normalize_output_size_3d(output_size):
    """Return ``output_size`` as a ``(D_out, H_out, W_out)`` tuple of ints.

    Args:
        output_size: a list, tuple or ``torch.Size`` with three entries;
            every entry is converted with ``int()``.

    Raises:
        TypeError: if ``output_size`` is not a list, tuple or ``torch.Size``.
        ValueError: if it does not contain exactly three entries.
    """
    if not isinstance(output_size, (torch.Size, list, tuple)):
        raise TypeError("output_size must be a sequence of three integers")

    dims = tuple(output_size)
    if len(dims) != 3:
        raise ValueError(
            "output_size for _adaptive_avg_pool3d must have 3 elements (D_out, H_out, W_out)"
        )
    return tuple(map(int, dims))

122 

123 

def _prepare_5d_input(t):
    """Return ``t`` as a 5-D tensor plus a flag telling whether a dim was added.

    A 5-D tensor is returned unchanged with ``False``. A 4-D tensor gets a
    leading dimension of size 1 via ``unsqueeze(0)`` and is returned with
    ``True``.

    Raises:
        ValueError: if ``t`` is neither 4-D nor 5-D.
    """
    rank = t.dim()
    if rank not in (4, 5):
        raise ValueError(
            "input for _adaptive_avg_pool3d must be 4D (C,D,H,W) or 5D (N,C,D,H,W)"
        )
    if rank == 4:
        # Insert N=1 in front so later code can assume five dimensions.
        return t.unsqueeze(0), True
    return t, False

132 

133 

def _launch_adaptive_avg_pool3d_kernel(x, out):
    """Launch ``adaptive_avg_pool3d_kernel`` with one program per output element.

    Args:
        x: 5-D CUDA input tensor; its shape is unpacked as
            (N, C, D_in, H_in, W_in) and its five strides are forwarded.
        out: CUDA output tensor with five strides; its last three sizes are
            taken as (D_out, H_out, W_out). It is written in place.

    Returns:
        None. Nothing is launched when the output has zero elements.

    Raises:
        AssertionError: if either tensor is not a CUDA tensor.
    """
    assert x.is_cuda and out.is_cuda, "Tensors must be CUDA tensors"
    N, C, D_in, H_in, W_in = x.shape
    D_out, H_out, W_out = out.shape[-3], out.shape[-2], out.shape[-1]

    stride_in_n, stride_in_c, stride_in_d, stride_in_h, stride_in_w = x.stride()
    stride_out_n, stride_out_c, stride_out_d, stride_out_h, stride_out_w = out.stride()

    total = N * C * D_out * H_out * W_out
    if total == 0:
        # Empty output: an empty grid must not be launched.
        return

    grid = (total,)
    # Launch on the device that owns the tensors. Without this guard the
    # kernel is issued on the *current* CUDA device, which is wrong when the
    # tensors live on a different GPU.
    with torch.cuda.device(x.device):
        adaptive_avg_pool3d_kernel[grid](
            x,
            out,
            N,
            C,
            D_in,
            H_in,
            W_in,
            D_out,
            H_out,
            W_out,
            stride_in_n,
            stride_in_c,
            stride_in_d,
            stride_in_h,
            stride_in_w,
            stride_out_n,
            stride_out_c,
            stride_out_d,
            stride_out_h,
            stride_out_w,
            num_warps=4,
        )

170 

171 

def _adaptive_avg_pool3d(input: torch.Tensor, output_size):
    """Adaptive 3-D average pooling that allocates and returns its result.

    Args:
        input: 4-D (C,D,H,W) or 5-D (N,C,D,H,W) tensor.
        output_size: list, tuple or ``torch.Size`` of three target sizes
            (D_out, H_out, W_out).

    Returns:
        A new tensor with the device, dtype and layout of ``input``. Its
        shape is (N, C, D_out, H_out, W_out), or (C, D_out, H_out, W_out)
        when ``input`` was 4-D.
    """
    batched, added_batch = _prepare_5d_input(input)
    target_dhw = _normalize_output_size_3d(output_size)

    # Keep the leading (N, C) sizes and swap in the requested spatial sizes.
    result = torch.empty(
        (*batched.shape[:2], *target_dhw),
        device=batched.device,
        dtype=batched.dtype,
        layout=batched.layout,
    )

    _launch_adaptive_avg_pool3d_kernel(batched, result)

    # Drop the batch dimension again if it was only added for the launch.
    return result.squeeze(0) if added_batch else result

187 

188 

def _adaptive_avg_pool3d_out(input: torch.Tensor, output_size, out: torch.Tensor):
    """Adaptive 3-D average pooling written into a caller-provided tensor.

    Args:
        input: 4-D (C,D,H,W) or 5-D (N,C,D,H,W) tensor.
        output_size: list, tuple or ``torch.Size`` of three target sizes
            (D_out, H_out, W_out).
        out: destination tensor. For a 5-D input it must be 5-D; for a 4-D
            input it may be 4-D, or 5-D with a leading size of 1.

    Returns:
        The ``out`` object that was passed in.

    Raises:
        ValueError: if ``out`` has the wrong rank or shape, or differs from
            ``input`` in device or dtype.
    """
    x5d, squeezed = _prepare_5d_input(input)
    target_dhw = _normalize_output_size_3d(output_size)

    # Obtain a 5-D view of ``out`` that lines up with the 5-D input.
    if not squeezed:
        if out.dim() != 5:
            raise ValueError("Provided 'out' must be 5D (N,C,D,H,W) when input is 5D")
        out5d = out
    elif out.dim() == 4:
        out5d = out.unsqueeze(0)
    elif out.dim() == 5 and out.size(0) == 1:
        out5d = out
    else:
        raise ValueError("Provided 'out' must be 4D (C,D,H,W) when input is 4D")

    expected_shape = (x5d.size(0), x5d.size(1), *target_dhw)
    if tuple(out5d.shape) != expected_shape:
        raise ValueError(
            f"out has incorrect shape. Expected {expected_shape}, got {tuple(out5d.shape)}"
        )

    if not (out5d.device == x5d.device and out5d.dtype == x5d.dtype):
        raise ValueError(
            "out must be on the same device and have the same dtype as input"
        )

    _launch_adaptive_avg_pool3d_kernel(x5d, out5d)

    return out