Coverage for src/flag_gems/ops/sum.py: 45%

232 statements  

coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3from functools import reduce 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.ops.zeros import zero_ 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import dim_compress, libentry, libtuner 

13from flag_gems.utils import triton_lang_extension as tle 

14 

15logger = logging.getLogger(__name__) 

16 

17 

@libentry()
@triton.jit
def sum_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of a two-stage full reduction: each program sums one
    BLOCK_SIZE-sized slice of the flattened input and writes its partial
    sum to ``mid[pid]``.

    Assumes ``inp`` is contiguous; ``M`` is the total element count.
    """
    # Accumulate half-precision inputs in float32 to limit rounding error.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M  # guard the ragged tail block

    # Masked lanes contribute 0, so the partial sum is unaffected.
    inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(inp_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, sum_val)

42 

43 

@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the two-stage full reduction: a single program folds all
    ``mid_size`` stage-1 partial sums into the scalar ``out``.

    BLOCK_MID must be >= mid_size so one tile covers every partial.
    """
    # Accumulate half-precision partials in float32 to limit rounding error.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty

    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size  # BLOCK_MID is a power of two >= mid_size
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(mid_val)
    tl.store(out, sum_val)

60 

61 

def sum(inp, *, dtype=None):
    """Sum every element of ``inp``, returning a 0-d tensor.

    Two-stage reduction: stage one emits one partial sum per block,
    stage two folds the partials into a single scalar.
    Bool inputs are promoted to int64 (PyTorch semantics).
    """
    logger.debug("GEMS SUM")
    inp = inp.contiguous()
    numel = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    # Balance the two stages: roughly sqrt(numel) elements per block.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    num_blocks = triton.cdiv(numel, block_size)
    block_mid = triton.next_power_of_2(num_blocks)

    partials = torch.empty((num_blocks,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[(num_blocks, 1, 1)](inp, partials, numel, block_size)
        sum_kernel_2[(1, 1, 1)](partials, out, num_blocks, block_mid)
    return out

82 

83 

def sum_out(inp, *, dtype=None, out):
    """Sum all elements of ``inp`` into the preallocated 0-d tensor ``out``.

    Mirrors ``sum`` above; bool inputs are promoted to int64.
    """
    logger.debug("GEMS SUM_OUT")
    # Bug fix: sum_kernel_1 indexes the input as a flat contiguous buffer,
    # so force contiguity here just like ``sum`` does (it was missing).
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64
    # Roughly sqrt(M) elements per block balances the two stages.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out

101 

102 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def sum_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Sum along the middle axis of a logical (M, N, K) view of the input.

    Used when the reduced dim is NOT the innermost axis: N is the reduced
    extent and K the product of the trailing dims.  Grid is
    (M, cdiv(K, TILE_K)).  Tile sizes come from the shared
    "softmax_non_inner" heuristic config.
    """
    # Accumulate half-precision inputs in float32 to limit rounding error.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)

    # K-axis offsets handled by this program; shape (1, TILE_K).
    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]

    if ONE_TILE_PER_CTA:
        # TILE_N covers all of N: load one tile, reduce over axis 0, store.
        n_offsets = tl.arange(0, TILE_N)[:, None]
        inp_offset = pid_m * N * K + n_offsets * K + k_offsets
        mask = (n_offsets < N) & (k_offsets < K)
        input_ptrs = input_ptr + inp_offset
        inp = tl.load(input_ptrs, mask=mask, other=0).to(cdtype)
        out = tl.sum(inp, axis=0, keep_dims=True)
        out_offset = pid_m * K + k_offsets
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out, mask=k_offsets < K)
    else:
        # N exceeds TILE_N: accumulate tile-by-tile, then reduce the tile.
        sum = tl.zeros([TILE_N, TILE_K], dtype=cdtype)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)[:, None]
            inp_offsets = pid_m * N * K + n_offsets * K + k_offsets
            mask = (n_offsets < N) & (k_offsets < K)
            inp = tl.load(input_ptr + inp_offsets, mask=mask, other=0).to(cdtype)
            sum += inp
        out = tl.sum(sum, axis=0, keep_dims=True)
        out_offset = pid_m * K + k_offsets
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out, mask=k_offsets < K)

152 

153 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def sum_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Sum each length-N row of a contiguous (M, N) view; one program per row.

    Used when the reduced dim is the innermost (contiguous) axis.  Tile
    sizes come from the shared "softmax_inner" heuristic config.
    """
    # Accumulate half-precision inputs in float32 to limit rounding error.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # TILE_N covers the whole row: single load + reduce.
        n_offsets = tl.arange(0, TILE_N)
        inp_offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + inp_offset
        mask = n_offsets < N
        inp = tl.load(input_ptrs, mask=mask, other=0).to(cdtype)
        out = tl.sum(inp, axis=0)
        out_offset = pid_m
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out)
    else:
        # Row is longer than TILE_N: accumulate per-lane partials first.
        sum = tl.zeros(
            [
                TILE_N,
            ],
            dtype=cdtype,
        )
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp_offsets = pid_m * N + n_offsets
            mask = n_offsets < N
            inp = tl.load(input_ptr + inp_offsets, mask=mask, other=0).to(cdtype)
            sum += inp
        # Final reduction across the lane-partial vector.
        out = tl.sum(sum, axis=0)
        out_offset = pid_m
        output_ptrs = output_ptr + out_offset
        tl.store(output_ptrs, out)

200 

201 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def sum_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise sum of an (M, N) view whose reduced dims were moved
    innermost (see dim_compress); each program sums BLOCK_M rows.
    """
    # Accumulate half-precision inputs in float32 to limit rounding error.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the rows of inp it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + pid * N
    out = out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): relies on Triton lowering `and` to an elementwise
        # logical-and of the two masks; `&` is the conventional spelling —
        # confirm before changing.
        mask = row_mask and col_mask

        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    # Reduce across columns; keep a trailing unit axis to match out's shape.
    sum = tl.sum(_sum, axis=1)[:, None]
    tl.store(out, sum, row_mask)

239 

240 

def sum_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Shared implementation for ``sum.dim_IntList`` and its ``out=`` variant.

    Args:
        inp: input tensor.
        dim: ``None`` (reduce all dims), a bare int, or a list/tuple of ints.
        keepdim: keep reduced dimensions as size-1 axes.
        dtype: result dtype; defaults to ``inp.dtype`` (bool promotes to int64).
        out: optional preallocated output tensor.

    Returns:
        The reduced tensor (``out`` when it was supplied).
    """
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if dim is None:
        result = torch.sum(inp, dtype=dtype)
        if keepdim:
            result = result.reshape([1] * inp.ndim)
        return result

    # Bug fix: accept a bare int for dim (sum_dim's empty-tensor path already
    # does); previously ``for d in dim`` would raise TypeError on an int.
    if isinstance(dim, int):
        dim = [dim]

    # An empty dim list/tuple means "reduce over all dims" (the original only
    # recognized the literal []; a tuple () fell through incorrectly).
    if len(dim) == 0:
        if not keepdim:
            return sum(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(sum(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]  # normalize negative dims

    if len(dim) == 1:
        d = dim[0]
        N = inp.shape[d]
        # M = product of the leading dims, K = product of the trailing dims.
        M = reduce(lambda x, y: x * y, shape[:d], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[d] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                # Reduced dim is not innermost: (M, N, K) kernel.
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                sum_dim_kernel_non_inner[grid](out, inp, M, N, K)
            else:
                # Reduced dim is the innermost (contiguous) axis.
                grid = (M, 1, 1)
                sum_dim_kernel_inner[grid](out, inp, M, N)
        if not keepdim:
            out = out.squeeze(dim=d)
        return out
    else:
        # Multiple dims: permute them innermost, then one row-wise reduction
        # of length N = prod(reduced extents) over M rows.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            sum_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out

310 

311 

def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum ``inp`` over ``dim`` (int, list of ints, or None for all dims).

    Empty inputs short-circuit to a zero-filled tensor of the correct
    reduced shape, consistent with PyTorch; everything else is delegated
    to ``sum_dim_comm``.
    """
    logger.debug("GEMS SUM_DIM")
    if inp.numel() != 0:
        return sum_dim_comm(inp, dim, keepdim, dtype=dtype)

    # --- empty-tensor path: build the result shape directly ---
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    out_shape = list(inp.shape)
    reduce_all = dim is None or (
        isinstance(dim, (list, tuple)) and len(dim) == 0
    )
    if reduce_all:
        # Reducing every dim: scalar, or all-ones shape when keepdim.
        out_shape = [1] * len(out_shape) if keepdim else []
    else:
        dims = dim if isinstance(dim, (list, tuple)) else [dim]
        if keepdim:
            for d in dims:
                out_shape[d % inp.ndim] = 1
        else:
            # Drop reduced axes from the back so earlier indices stay valid.
            for idx in sorted((d % inp.ndim for d in dims), reverse=True):
                out_shape.pop(idx)

    result = torch.empty(out_shape, dtype=dtype, device=inp.device)
    zero_(result)  # sum over an empty extent is 0
    return result

348 

349 

def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """Out-variant of ``sum_dim``: writes into ``out`` via the shared impl."""
    logger.debug("GEMS SUM_DIM_OUT")
    result = sum_dim_comm(inp, dim, keepdim, dtype=dtype, out=out)
    return result