Coverage for src/flag_gems/ops/mean.py: 47%

193 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2import math 

3from functools import reduce 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import dim_compress, libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as tle 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First reduction stage: each program sums one BLOCK_SIZE slice of
    `inp` and writes the partial sum to mid[pid]."""
    # Accumulate half-precision inputs in fp32 to limit rounding error.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = inp.dtype.element_ty

    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M
    vals = tl.load(inp + idx, mask=in_bounds, other=0).to(acc_ty)
    tl.store(mid + block_id, tl.sum(vals))

42 

43 

@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second reduction stage: sum the per-block partials in `mid` and
    divide by the total element count M to produce the scalar mean."""
    # Accumulate half-precision partials in fp32 to limit rounding error.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = mid.dtype.element_ty

    idx = tl.arange(0, BLOCK_MID)
    partials = tl.load(mid + idx, mask=idx < MID_SIZE, other=0).to(acc_ty)
    # divide by total element count M to get mean
    tl.store(out, tl.sum(partials) / M)

62 

63 

def mean(inp, *, dtype=None):
    """Mean over all elements of `inp` via a two-stage block reduction."""
    logger.debug("GEMS MEAN")
    inp = inp.contiguous()
    numel = inp.numel()
    if dtype is None:
        dtype = inp.dtype

    # Pick ~sqrt(numel) elements per block so both stages do similar work.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    mid_size = triton.cdiv(numel, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(mid_size, 1, 1)](inp, mid, numel, block_size)
        mean_kernel_2[(1, 1, 1)](mid, out, numel, mid_size, block_mid)
    return out

81 

82 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("mean_non_inner"))
@triton.jit
def mean_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Mean over the middle axis of an (M, N, K) view of the input; each
    program handles one m-row and one TILE_K slab of the innermost axis."""
    # Accumulate half-precision inputs in fp32 to limit rounding error.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = input_ptr.dtype.element_ty

    m_id = tle.program_id(0)
    k_id = tle.program_id(1)

    k_idx = k_id * TILE_K + tl.arange(0, TILE_K)[None, :]
    k_valid = k_idx < K
    out_ptrs = output_ptr + m_id * K + k_idx

    if ONE_TILE_PER_CTA:
        # The whole reduction axis fits in a single tile: load, reduce, store.
        n_idx = tl.arange(0, TILE_N)[:, None]
        valid = (n_idx < N) & k_valid
        tile = tl.load(
            input_ptr + m_id * N * K + n_idx * K + k_idx, mask=valid, other=0
        ).to(acc_ty)
        # Reduce along N (axis 0, keep_dims so shape matches k_idx), then
        # divide by N to turn the sum into a mean.
        result = tl.sum(tile, axis=0, keep_dims=True) / N
        tl.store(out_ptrs, result, mask=k_valid)
    else:
        # Reduction axis is longer than one tile: accumulate tile-by-tile
        # before the final within-tile reduction.
        acc = tl.zeros([TILE_N, TILE_K], dtype=acc_ty)
        for n_start in range(0, N, TILE_N):
            n_idx = n_start + tl.arange(0, TILE_N)[:, None]
            valid = (n_idx < N) & k_valid
            acc += tl.load(
                input_ptr + m_id * N * K + n_idx * K + k_idx, mask=valid, other=0
            ).to(acc_ty)
        result = tl.sum(acc, axis=0, keep_dims=True) / N
        tl.store(out_ptrs, result, mask=k_valid)

135 

136 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def mean_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Mean over the innermost (contiguous) axis of an (M, N) view of the
    input; each program reduces one row."""
    # Accumulate half-precision inputs in fp32 to limit rounding error.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        acc_ty = tl.float32
    else:
        acc_ty = input_ptr.dtype.element_ty

    row = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # The whole row fits in one tile: load, reduce, store.
        cols = tl.arange(0, TILE_N)
        vals = tl.load(
            input_ptr + row * N + cols, mask=cols < N, other=0
        ).to(acc_ty)
        tl.store(output_ptr + row, tl.sum(vals, axis=0) / N)
    else:
        # Row is longer than one tile: accumulate piecewise, then reduce.
        acc = tl.zeros([TILE_N], dtype=acc_ty)
        for col_start in range(0, N, TILE_N):
            cols = col_start + tl.arange(0, TILE_N)
            acc += tl.load(
                input_ptr + row * N + cols, mask=cols < N, other=0
            ).to(acc_ty)
        tl.store(output_ptr + row, tl.sum(acc, axis=0) / N)

185 

186 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise mean for an (M, N) view produced by dim_compress; each
    program reduces BLOCK_M rows of length N."""
    # Accumulate half-precision inputs in fp32 to limit rounding error.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the rows of inp it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + pid * N
    out = out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # Fix: combine masks with elementwise `&`, not Python `and` --
        # `and` is a scalar boolean operator and is not guaranteed to
        # broadcast over Triton tensors.
        mask = row_mask & col_mask

        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    summed = tl.sum(_sum, axis=1)[:, None]
    # Divide by the reduction length to turn the sum into a mean.
    mean = summed / N
    tl.store(out, mean, row_mask)

225 

226 

def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Mean of `inp` over the dimension(s) in `dim`.

    Args:
        inp: input tensor.
        dim: an int, an iterable of ints, ``[]`` or ``None``; ``[]`` and
            ``None`` both mean "reduce over every element" (torch.mean
            semantics).
        keepdim: keep each reduced dimension with size 1.
        dtype: output dtype; defaults to ``inp.dtype`` (bool is promoted
            to int64).
        out: optional pre-allocated output tensor for the per-dim
            reduction paths (ignored for the full reduction).

    Returns:
        Tensor holding the mean(s).

    Raises:
        TypeError: if dim is none of the accepted forms.
    """
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    # Fix: treat the default ``dim=None`` like ``dim == []`` (full
    # reduction). Previously None fell through to ``list(None)`` below and
    # raised TypeError, so calling without a dim argument always failed.
    if dim is None or dim == []:
        # mean over all elements
        if not keepdim:
            return mean(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(mean(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)

    # -------- normalize dim to a list of non-negative ints --------
    if isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            )

    dim = [d % inp.ndim for d in dim]
    # -------------------------------------------------

    if len(dim) == 1:
        # Single reduced dim: view inp as (M, N, K) with N the reduced axis.
        dim0 = dim[0]
        N = inp.shape[dim0]  # reduction length
        # product of dims before dim0; use initializer 1 for empty slice
        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                # Reduced axis is not the innermost one.
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                mean_dim_kernel_non_inner[grid](out, inp, M, N, K)
            else:
                # Reduced axis is the innermost (contiguous) one.
                grid = (M, 1, 1)
                mean_dim_kernel_inner[grid](out, inp, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out
    else:
        # Multiple reduced dims: move them to the end, flatten to (M, N).
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            mean_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out

307 

308 

def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Public mean-over-dim entry point; delegates to mean_dim_comm."""
    logger.debug("GEMS MEAN_DIM (wrapper)")

    result = mean_dim_comm(inp, dim, keepdim, dtype=dtype)
    return result