Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/sum.py: 0%

148 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12from ..utils.block_size_utils import get_block_size_1d 

13 

# Module logger nested under the package-wide "flag_gems" logger;
# lstrip(".") drops any leading dot so the child name stays well-formed.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

@libentry()
@triton.jit
def sum_kernel_1(
    inp,  # *T: pointer to the flattened input, M elements
    mid,  # *T: pointer to per-block partial sums, one slot per program
    M,  # total number of input elements
    BLOCK_SIZE: tl.constexpr,  # elements reduced by each program
):
    """Stage 1 of a two-stage full reduction: each program sums one
    BLOCK_SIZE-wide slice of `inp` and stores the partial sum in `mid[pid]`.

    fp16/bf16 inputs are accumulated in fp32 for accuracy; all other dtypes
    accumulate in their native element type.
    """
    # Pick the accumulation dtype at compile time (constexpr branch).
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    # Mask the tail block so out-of-range lanes contribute 0 to the sum.
    mask = offset < M

    inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(inp_val)
    mid_ptr = mid + pid
    tl.store(mid_ptr, sum_val)

41 

42 

@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the full reduction: a single program folds the `mid_size`
    partial sums from stage 1 into the scalar `out`.

    BLOCK_MID must be >= mid_size (callers pass next_power_of_2(mid_size))
    so the whole `mid` buffer is covered in one masked load.
    """
    # Same compile-time promotion as stage 1: accumulate fp16/bf16 in fp32.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty

    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    # Lanes past mid_size read 0 and therefore do not affect the total.
    mask = offset < mid_size
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(mid_val)
    tl.store(out, sum_val)

59 

60 

def heur_m_block_size(args):
    """BLOCK_M heuristic: divide the M rows across the 12 clusters of the
    device, rounded up to the next power of two."""
    rows_per_cluster = triton.cdiv(args["M"], 12)  # 12 = cluster_num
    return triton.next_power_of_2(rows_per_cluster)

63 

64 

def heur_n_block_size(args):
    """BLOCK_N heuristic: next power of two covering N, capped at 8192."""
    # builtins.min avoids any shadowing of `min` in the heuristics scope.
    import builtins

    candidate = triton.next_power_of_2(args["N"])
    return builtins.min(candidate, 8192)

69 

70 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def sum_kernel(
    inp,  # *T: contiguous (M, N) input, rows are the kept dims
    out,  # *T: M output elements, one per row
    M,  # number of rows (non-reduced elements)
    N,  # row length (product of reduced dims)
    BLOCK_M: tl.constexpr,  # rows handled per program (heuristic)
    BLOCK_N: tl.constexpr,  # columns loaded per inner iteration (heuristic)
):
    """Row-wise reduction: sums each length-N row of a contiguous (M, N)
    view of `inp` into the corresponding element of `out`.

    fp16/bf16 rows accumulate in fp32; other dtypes accumulate natively.
    """
    # Compile-time accumulation-dtype selection, as in sum_kernel_1/2.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the row of inp it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + pid * N
    out = out + pid
    row_mask = pid < M

    # Sweep the row in BLOCK_N-wide tiles, accumulating elementwise.
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # NOTE(review): `and` on two mask tensors — Triton's elementwise
        # form is usually `&`; presumably this backend accepts `and` here,
        # confirm before restyling.
        mask = row_mask and col_mask

        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    # Collapse the tile accumulator over columns; store one sum per row.
    sum = tl.sum(_sum, axis=1)[:, None]
    tl.store(out, sum, row_mask)

110 

111 

def sum(inp, *, dtype=None):
    """Sum all elements of ``inp`` into a 0-d tensor.

    Two-stage reduction: ``sum_kernel_1`` writes one partial sum per block
    into ``mid``; ``sum_kernel_2`` folds ``mid`` into a scalar. When a
    single block suffices, the partial sum already is the answer and the
    second pass is skipped.

    Args:
        inp: input tensor; reduced over all elements.
        dtype: optional output/accumulation dtype. When omitted it follows
            the input dtype, with ``torch.bool`` promoted to
            ``torch.int64`` (matching ``torch.sum`` semantics).

    Returns:
        A 0-d tensor on ``inp``'s device holding the total.
    """
    logger.debug("GEMS SUM")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64
    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size, buffer_size_limit=2048)
        if mid_size == 1:
            # One partial sum only: it is the final result.
            return mid.reshape([])
        # Allocate the scalar output only when the second pass actually
        # runs (previously it was allocated unconditionally and discarded
        # on the early-return path).
        out = torch.empty([], dtype=dtype, device=inp.device)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
    return out

133 

134 

def sum_out(inp, *, dtype=None, out):
    """Sum all elements of ``inp`` into the caller-supplied tensor ``out``.

    Same two-stage reduction as :func:`sum`, but the result is always
    materialized in ``out`` (the out-variant contract).

    Args:
        inp: input tensor; reduced over all elements.
        dtype: optional accumulation dtype; ``torch.bool`` inputs are
            promoted to ``torch.int64`` when dtype is omitted.
        out: pre-allocated result tensor (numel == 1) that receives the sum.

    Returns:
        ``out``.
    """
    logger.debug("GEMS SUM_OUT")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64
    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size, buffer_size_limit=2048)
        if mid_size == 1:
            # BUG FIX: this path previously returned `mid` without ever
            # writing the caller's `out`, leaving it uninitialized. Copy
            # the single partial sum into `out` to honor the out-variant
            # contract.
            out.copy_(mid.reshape(out.shape))
            return out
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
    return out

154 

155 

def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum ``inp`` over the dimensions listed in ``dim``.

    Args:
        inp: input tensor.
        dim: list of dimensions to reduce. ``None`` or ``[]`` means reduce
            over all dimensions (ATen convention). Negative dims allowed.
        keepdim: keep reduced dimensions as size 1 when True.
        dtype: optional output dtype; ``torch.bool`` promotes to
            ``torch.int64``.

    Returns:
        The reduced tensor (squeezed over ``dim`` unless ``keepdim``).
    """
    logger.debug("GEMS SUM DIM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    # BUG FIX: the signature defaults dim=None, but None previously fell
    # through to `for d in dim` and raised TypeError. Treat None like the
    # empty list: a full reduction.
    if dim is None or dim == []:
        if not keepdim:
            return sum(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(sum(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)
    # Normalize negative dims, then move reduced dims to be innermost so
    # the kernel sees a contiguous (M, N) layout.
    dim = [d % inp.ndim for d in dim]
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    out = torch.empty(shape, dtype=dtype, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

187 

188 

def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """Sum ``inp`` over dimensions ``dim``, writing into ``out``.

    Args:
        inp: input tensor.
        dim: list of dimensions to reduce. ``None`` or ``[]`` means reduce
            over all dimensions (ATen convention). Negative dims allowed.
        keepdim: keep reduced dimensions as size 1 when True.
        dtype: optional output dtype; ``torch.bool`` promotes to
            ``torch.int64``.
        out: pre-allocated result tensor with the reduced shape.

    Returns:
        ``out``.
    """
    logger.debug("GEMS SUM_DIM_OUT")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    # BUG FIX: dim=None previously raised TypeError in the list
    # comprehension below; treat it as a full reduction, consistent with
    # sum_dim.
    if dim is None or dim == []:
        if not keepdim:
            return sum_out(inp, dtype=dtype, out=out)
        else:
            dim_num = inp.ndim
            return torch.reshape(sum_out(inp, dtype=dtype, out=out), [1] * dim_num)

    shape = list(inp.shape)
    # Normalize negative dims, then make reduced dims innermost so the
    # kernel sees a contiguous (M, N) layout.
    dim = [d % inp.ndim for d in dim]
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out.squeeze_(dim=dim)
    return out

217 return out