Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/sum.py: 0%

169 statements (coverage.py v7.6.9, created at 2026-03-30 03:43 +0800)

import logging

import torch
import triton
import triton.language as tl

# from flag_gems import runtime
from flag_gems.ops.zeros import zero_
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import dim_compress, libentry
from flag_gems.utils import triton_lang_extension as tle

from ..utils.block_size_utils import get_block_size_1d

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
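
# Stage 1 of the flat (all-elements) reduction: each program loads one
# BLOCK_SIZE-wide slice of the flattened input and writes its partial sum to
# mid[pid]. Half-precision inputs (float16/bfloat16) are accumulated in
# float32 to limit rounding error; other dtypes accumulate in their own type.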

@libentry()
@triton.jit
def sum_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M

    inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(inp_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, sum_val)
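
# Stage 2: a single program reduces the buffer of partial sums to a scalar.
# BLOCK_MID is the next power of two >= mid_size, so one tl.arange covers the
# whole buffer and out-of-range lanes are masked off.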

@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty

    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(mid_val)
    tl.store(out, sum_val)
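
# Block-size heuristics for the dim-wise kernel below: BLOCK_M spreads the M
# rows over roughly 12 programs (one per cluster, per the cluster_num note),
# and BLOCK_N covers N in a single tile capped at 8192 columns.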

def heur_m_block_size(args):
    return triton.next_power_of_2(triton.cdiv(args["M"], 12))  # cluster_num


def heur_n_block_size(args):
    import builtins

    return builtins.min(triton.next_power_of_2(args["N"]), 8192)
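
# Dim-wise reduction kernel: the input is viewed as an (M, N) matrix in which
# N is the product of the reduced dimensions (made innermost by dim_compress).
# Each program owns BLOCK_M rows and walks their N columns in BLOCK_N tiles,
# accumulating into a BLOCK_M x BLOCK_N register block before the final
# row-wise tl.sum.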

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def sum_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the rows of inp it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + pid * N
    out = out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    sum = tl.sum(_sum, axis=1)[:, None]
    tl.store(out, sum, row_mask)
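
# Host-side wrappers. sum/sum_out launch the two flat-reduction kernels:
# stage 1 with one program per BLOCK_SIZE chunk (grid = mid_size), stage 2
# with a single program; when stage 1 already produced a single partial sum,
# the mid buffer is returned directly as a 0-d tensor.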

def sum(inp, *, dtype=None):
    logger.debug("GEMS SUM")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64
    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size, buffer_size_limit=2048)
        if mid_size == 1:
            return mid.reshape([])
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
    return out

def sum_out(inp, *, dtype=None, out):
    logger.debug("GEMS SUM_OUT")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64
    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size, buffer_size_limit=2048)
        if mid_size == 1:
            return mid.reshape([])
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
    return out
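
# sum_dim/sum_dim_out reduce over an explicit list of dimensions. dim_compress
# permutes the reduced dimensions to be innermost and contiguous, so the
# kernel can treat the tensor as M rows of N elements each.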

def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    logger.debug("GEMS SUM DIM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if inp.numel() == 0:
        out_shape = list(inp.shape)
        if dim is None:
            out_shape = [1] * len(out_shape) if keepdim else []
        else:
            dims = dim if isinstance(dim, (list, tuple)) else [dim]
            if keepdim:
                for d in dims:
                    out_shape[d % inp.ndim] = 1
            else:
                for d in sorted(dims, key=lambda x: x % inp.ndim, reverse=True):
                    out_shape.pop(d % inp.ndim)
        out = torch.empty(out_shape, dtype=dtype, device=inp.device)
        zero_(out)
        return out

    if dim == []:
        if not keepdim:
            return sum(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(sum(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    out = torch.empty(shape, dtype=dtype, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
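
# out= variant of sum_dim: the caller supplies `out` with the expected shape,
# so the empty-input path only needs to zero it.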

def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    logger.debug("GEMS SUM_DIM_OUT")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if inp.numel() == 0:
        # `out` already has the correct shape from the caller; just zero it.
        zero_(out)
        return out

    if dim == []:
        if not keepdim:
            return sum_out(inp, dtype=dtype, out=out)
        else:
            dim_num = inp.ndim
            return torch.reshape(sum_out(inp, dtype=dtype, out=out), [1] * dim_num)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out.squeeze_(dim=dim)
    return out
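
# Rough usage sketch, assuming these functions are dispatched as the aten sum
# overloads through flag_gems' registration (illustrative only; the exact
# Kunlunxin device string is not assumed here):
#
#   x = torch.randn(4, 8, dtype=torch.float16)        # placed on the Kunlunxin device
#   total = sum(x)                                     # 0-d tensor, ~ torch.sum(x)
#   per_row = sum_dim(x, dim=[1])                      # shape (4,), ~ torch.sum(x, dim=1)
#   kept = sum_dim(x, dim=[1], keepdim=True)           # shape (4, 1)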