Coverage for src/flag_gems/ops/sum.py: 40%

206 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2import math 

3from functools import reduce 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import dim_compress, libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as tle 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@libentry()
@triton.jit
def sum_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the two-stage full reduction.

    Each program sums one BLOCK_SIZE chunk of the flattened, contiguous
    input `inp` (M elements total) and stores its partial sum to mid[pid].
    """
    # Accumulate fp16/bf16 in fp32 to limit rounding error; every other
    # dtype accumulates in its own element type.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M  # guard the ragged tail block

    # Masked-out lanes load 0, the identity element for summation.
    inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(inp_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, sum_val)

41 

42 

@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the two-stage full reduction.

    A single program loads all `mid_size` partial sums produced by
    sum_kernel_1 (BLOCK_MID is mid_size rounded up to a power of two)
    and stores the final scalar result to `out`.
    """
    # Same accumulation-dtype policy as sum_kernel_1: widen fp16/bf16 to fp32.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty

    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(mid_val)
    tl.store(out, sum_val)

59 

60 

def sum(inp, *, dtype=None):
    """Sum every element of *inp*, returning a 0-d tensor.

    bool inputs are promoted to int64 (matching torch.sum); fp16/bf16 are
    accumulated in fp32 inside the kernels. The reduction runs in two
    stages: per-block partial sums, then a single-program final sum.
    """
    logger.debug("GEMS SUM")
    inp = inp.contiguous()
    numel = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64
    # Blocks of ~sqrt(numel) balance work between the two stages.
    block = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    num_blocks = triton.cdiv(numel, block)
    final_block = triton.next_power_of_2(num_blocks)

    partial = torch.empty((num_blocks,), dtype=dtype, device=inp.device)
    result = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[(num_blocks, 1, 1)](inp, partial, numel, block)
        sum_kernel_2[(1, 1, 1)](partial, result, num_blocks, final_block)
    return result

81 

82 

def sum_out(inp, *, dtype=None, out):
    """Sum every element of *inp* into the provided 0-d tensor *out*.

    Mirrors sum(): bool promotes to int64, and the reduction is the same
    two-stage kernel pipeline.
    """
    logger.debug("GEMS SUM_OUT")
    # Fix: sum_kernel_1 addresses the input with flat contiguous offsets,
    # so a non-contiguous view would be reduced incorrectly. sum() already
    # does this; sum_out was missing it.
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64
    # ~sqrt(M)-sized blocks balance work between the two reduction stages.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out

100 

101 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def sum_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Reduce the middle axis of an (M, N, K) view of the input.

    Grid is (M, cdiv(K, TILE_K)); each program sums over N for one
    TILE_K-wide slice of K and writes an (1, TILE_K) row of the (M, K)
    output. ONE_TILE_PER_CTA means TILE_N covers all of N, so no loop over
    N is needed.
    """
    # Accumulate fp16/bf16 in fp32; other dtypes keep their element type.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]

    if ONE_TILE_PER_CTA:
        # Whole reduction axis fits in one tile: single load, single sum.
        n_offsets = tl.arange(0, TILE_N)[:, None]
        inp_offset = pid_m * N * K + n_offsets * K + k_offsets
        mask = (n_offsets < N) & (k_offsets < K)
        input_ptrs = input_ptr + inp_offset
        inp = tl.load(input_ptrs, mask=mask, other=0).to(cdtype)
        out = tl.sum(inp, axis=0, keep_dims=True)
        out_offset = pid_m * K + k_offsets
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out, mask=k_offsets < K)
    else:
        # Tile over N, accumulating elementwise before the final reduce.
        # NOTE: `sum` shadows the Python builtin, harmless inside a jitted body.
        sum = tl.zeros([TILE_N, TILE_K], dtype=cdtype)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)[:, None]
            inp_offsets = pid_m * N * K + n_offsets * K + k_offsets
            mask = (n_offsets < N) & (k_offsets < K)
            inp = tl.load(input_ptr + inp_offsets, mask=mask, other=0).to(cdtype)
            sum += inp
        out = tl.sum(sum, axis=0, keep_dims=True)
        out_offset = pid_m * K + k_offsets
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out, mask=k_offsets < K)

151 

152 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def sum_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Reduce the innermost (contiguous) axis of an (M, N) view.

    Grid is (M,); each program sums one row of length N and stores a
    single scalar to output_ptr[pid_m]. ONE_TILE_PER_CTA means TILE_N
    covers all of N, so no loop over N is needed.
    """
    # Accumulate fp16/bf16 in fp32; other dtypes keep their element type.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # Row fits in a single tile: one masked load, one reduce.
        n_offsets = tl.arange(0, TILE_N)
        inp_offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + inp_offset
        mask = n_offsets < N
        inp = tl.load(input_ptrs, mask=mask, other=0).to(cdtype)
        out = tl.sum(inp, axis=0)
        out_offset = pid_m
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out)
    else:
        # Tile over the row, accumulating elementwise before the final reduce.
        # NOTE: `sum` shadows the Python builtin, harmless inside a jitted body.
        sum = tl.zeros(
            [
                TILE_N,
            ],
            dtype=cdtype,
        )
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp_offsets = pid_m * N + n_offsets
            mask = n_offsets < N
            inp = tl.load(input_ptr + inp_offsets, mask=mask, other=0).to(cdtype)
            sum += inp
        out = tl.sum(sum, axis=0)
        out_offset = pid_m
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out)

199 

200 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def sum_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise reduction of an (M, N) view: out[m] = sum over inp[m, :].

    Used for the multi-dim path where the reduced dims have been permuted
    to the end (see dim_compress in sum_dim_comm). Each program handles
    BLOCK_M rows, tiling across N in BLOCK_N-wide chunks.
    """
    # Accumulate fp16/bf16 in fp32; other dtypes keep their element type.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the row of inp it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + pid * N
    out = out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): Python `and` between two tensors — Triton's JIT
        # appears to accept this here, but `&` is the documented tensor
        # operator; confirm against the Triton version in use.
        mask = row_mask and col_mask

        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    # Reduce across the column tiles; keep a [:, None] shape to match the
    # per-row output pointers.
    sum = tl.sum(_sum, axis=1)[:, None]
    tl.store(out, sum, row_mask)

238 

239 

def sum_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Shared implementation behind sum_dim and sum_dim_out.

    Reduces `inp` over `dim` — a list of ints, a single int, None for a
    full reduction, or [] (which torch treats as "reduce everything").
    `keepdim` keeps the reduced dims as size 1; `dtype` casts the result
    (bool promotes to int64); `out` receives the result when given on the
    kernel paths.
    """
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if dim is None:
        # Full reduction delegated to torch.
        # NOTE(review): `out` is ignored on this path — confirm callers
        # never reach here with an explicit `out`.
        result = torch.sum(inp, dtype=dtype)
        if keepdim:
            result = result.reshape([1] * inp.ndim)
        return result

    # Robustness: accept a bare int dim (previously crashed when iterated
    # below); torch dispatch normally passes a list.
    if isinstance(dim, int):
        dim = [dim]

    if dim == []:
        # torch semantics: an empty dim list means reduce over all dims.
        if not keepdim:
            return sum(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(sum(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]  # normalize negative dims

    if len(dim) == 1:
        # Single-dim fast path: view the tensor as (M, N, K) with N the
        # reduced extent. K == 1 means the reduction is over the innermost
        # contiguous axis, which has a dedicated kernel.
        dim = dim[0]
        N = inp.shape[dim]
        M = reduce(lambda x, y: x * y, shape[:dim], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[dim] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                sum_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                grid = (M, 1, 1)
                sum_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out
    else:
        # Multi-dim path: permute the reduced dims to the end so the
        # problem becomes a row-wise (M, N) reduction.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            sum_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out

309 

310 

def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum *inp* over *dim*, delegating to the shared implementation."""
    logger.debug("GEMS SUM_DIM")
    return sum_dim_comm(inp, dim=dim, keepdim=keepdim, dtype=dtype)

314 

315 

def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """Sum *inp* over *dim* into *out*, delegating to the shared implementation."""
    logger.debug("GEMS SUM_DIM_OUT")
    return sum_dim_comm(inp, dim=dim, keepdim=keepdim, dtype=dtype, out=out)