Coverage for src/flag_gems/runtime/backend/_cambricon/ops/sum.py: 0%

142 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry, libtuner 

10 

11from ..utils import MAX_GRID_SIZE_X, TOTAL_CORE_NUM, cfggen_reduce_op 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14 

15 

@libentry()
@libtuner(
    configs=cfggen_reduce_op(), key=["M"], strategy=["log"], reset_to_zero=["out"]
)
@triton.jit
def sum_kernel_1(
    inp,  # pointer to the flattened input tensor
    out,  # pointer to a single-element accumulator (atomic_add target)
    M,  # total number of elements in `inp`
    BLOCK_SIZE: tl.constexpr,  # elements each program handles per step
):
    """Full reduction: each program sums a strided slice of `inp` and
    atomically adds its partial result into the scalar `out`.

    `out` must start at zero for the result to be correct (the host
    wrappers in this file create it with torch.zeros; the tuner also
    declares reset_to_zero=["out"]).
    """
    # Accumulate fp16/bf16 inputs in fp32; all other dtypes keep their own
    # element type as the accumulator type.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    _tmp = tl.zeros([BLOCK_SIZE], dtype=cdtype)
    # Use 64-bit offsets so very large inputs do not overflow 32-bit math.
    block_start = block_start.to(tl.int64)
    # Grid-stride loop: the fixed-size grid cooperatively covers [0, M).
    for off in range(block_start, M, step):
        offset = off + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=0.0)
        _tmp = inp_val + _tmp

    # Collapse the per-lane partials, then combine across programs atomically.
    sum_val = tl.sum(_tmp)
    tl.atomic_add(out, sum_val)

48 

49 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("sum"),
    key=["M", "N"],
    strategy=["log", "log"],
)
@triton.jit
def sum_kernel(
    inp,  # pointer to input viewed as a row-major (M, N) matrix
    out,  # pointer to M output elements, one sum per row
    M,  # number of rows (kept dimensions, after dim_compress on the host)
    N,  # number of columns (product of the reduced dimensions)
    BLOCK_M: tl.constexpr,  # rows handled per program iteration
    BLOCK_N: tl.constexpr,  # columns loaded per inner step
):
    """Row-wise reduction: out[m] = sum(inp[m, :]) for every row m."""
    # Pick the accumulation dtype: widen fp16/bf16 to fp32 and bool (int1)
    # to int32; everything else accumulates in its own element type.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    elif tl.constexpr(inp.dtype.element_ty == tl.int1):
        cdtype = tl.int32
    else:
        cdtype = inp.dtype.element_ty
    prog_num = tl.num_programs(0).to(tl.uint64)
    sub_pid = tl.program_id(0).to(tl.uint64)
    task_num = tl.cdiv(M, BLOCK_M).to(tl.uint64)
    # The launch grid may hold fewer programs than there are row-blocks, so
    # each program strides over row-blocks by the total program count.
    while sub_pid < task_num:
        # Map the program id to the row of inp it should compute.
        pid = sub_pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
        inp_ = inp + pid * N
        out_ = out + pid
        row_mask = pid < M

        _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
        # Sweep each row in BLOCK_N-wide tiles, accumulating lane-wise;
        # masked-off lanes load 0 and so do not affect the sum.
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            mask = row_mask and col_mask

            a = tl.load(inp_ + cols, mask, other=0).to(cdtype)
            _sum += a
        # Collapse the column axis to one sum per row and write it out.
        sum = tl.sum(_sum, axis=1)[:, None]
        tl.store(out_, sum, row_mask)
        sub_pid += prog_num

94 

95 

def sum(inp, *, dtype=None):
    """Return the sum of every element of *inp* as a zero-dim tensor.

    When *dtype* is omitted it defaults to the input dtype, except that
    boolean inputs are summed as int32.
    """
    logger.debug("GEMS_CAMBRICON SUM")
    numel = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            # Sum booleans as 32-bit integers.
            inp = inp.to(torch.int32)
            dtype = torch.int32

    def grid(meta):
        return (min(triton.cdiv(numel, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    # The kernel accumulates via atomic_add, so the accumulator starts at zero.
    out = torch.zeros([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[grid](inp, out, numel)
    return out.to(dtype)

111 

112 

def sum_out(inp, *, dtype=None, out):
    """Sum every element of *inp* into the preallocated tensor *out*.

    *dtype* defaults to the input dtype; boolean inputs are summed as int32.
    """
    logger.debug("GEMS_CAMBRICON SUM_OUT")
    numel = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            # Sum booleans as 32-bit integers.
            inp = inp.to(torch.int32)
            dtype = torch.int32

    def grid(meta):
        return (min(triton.cdiv(numel, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    # NOTE(review): sum_kernel_1 accumulates into `out` with atomic_add and,
    # unlike sum(), this path does not zero `out` first — presumably callers
    # provide a zero-initialized tensor. Confirm against call sites.
    with torch_device_fn.device(inp.device):
        sum_kernel_1[grid](inp, out, numel)
    return out.to(dtype)

127 

128 

def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum *inp* over the dimensions listed in *dim*.

    `dim=None` reduces everything via the eager torch op; an empty list
    also reduces everything but goes through the Triton full-sum path.
    With *keepdim* the reduced dimensions are retained as size 1.
    """
    logger.debug("GEMS_CAMBRICON SUM DIM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    # No dim argument: delegate the full reduction to torch.
    if dim is None:
        result = torch.sum(inp, dtype=dtype)
        return result.reshape([1] * inp.ndim) if keepdim else result

    # An empty dim list also means "reduce over everything".
    if dim == []:
        total = sum(inp, dtype=dtype)
        return torch.reshape(total, [1] * inp.ndim) if keepdim else total

    # Move the reduced dimensions to the end so the kernel sees a
    # contiguous (rows, reduce_size) matrix.
    out_shape = list(inp.shape)
    axes = [d % inp.ndim for d in dim]
    inp = dim_compress(inp, axes)
    reduce_size = 1
    for ax in axes:
        reduce_size *= out_shape[ax]
        out_shape[ax] = 1
    row_count = inp.numel() // reduce_size

    out = torch.empty(out_shape, dtype=dtype, device=inp.device)

    def grid(meta):
        return (min(triton.cdiv(row_count, meta["BLOCK_M"]), MAX_GRID_SIZE_X // 4),)

    with torch_device_fn.device(inp.device):
        sum_kernel[grid](inp, out, row_count, reduce_size)
    # Drop the size-1 reduced dimensions unless the caller wants them kept.
    return out if keepdim else out.squeeze(dim=axes)

166 

167 

def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """Sum *inp* over the dimensions in *dim*, writing into *out*.

    Mirrors sum_dim() but stores results in the caller-supplied tensor;
    *out* is returned (and squeezed in place when *keepdim* is false).
    """
    logger.debug("GEMS_CAMBRICON SUM_DIM_OUT")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    # No dim argument: delegate the full reduction to torch.
    if dim is None:
        result = torch.sum(inp, dtype=dtype)
        return result.reshape([1] * inp.ndim) if keepdim else result

    # An empty dim list also means "reduce over everything".
    if dim == []:
        reduced = sum_out(inp, dtype=dtype, out=out)
        return torch.reshape(reduced, [1] * inp.ndim) if keepdim else reduced

    # Move the reduced dimensions to the end so the kernel sees a
    # contiguous (rows, reduce_size) matrix.
    out_shape = list(inp.shape)
    axes = [d % inp.ndim for d in dim]
    inp = dim_compress(inp, axes)
    reduce_size = 1
    for ax in axes:
        reduce_size *= out_shape[ax]
        out_shape[ax] = 1
    row_count = inp.numel() // reduce_size

    def grid(meta):
        return (min(triton.cdiv(row_count, meta["BLOCK_M"]), MAX_GRID_SIZE_X // 4),)

    with torch_device_fn.device(inp.device):
        sum_kernel[grid](inp, out, row_count, reduce_size)
    if not keepdim:
        # In-place squeeze so `out` keeps aliasing the caller's tensor.
        out.squeeze_(dim=axes)
    return out