Coverage for src/flag_gems/runtime/backend/_cambricon/ops/max.py: 0%

177 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2import math 

3from collections import namedtuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry, libtuner 

12from flag_gems.utils.limits import get_dtype_min 

13 

14from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op, prune_reduce_config 

15 

16logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

17 

18 

@libentry()
@triton.jit
def max_kernel_float_once(
    inp,
    out,
    M: tl.constexpr,
):
    # Single-program reduction: read M consecutive elements starting at
    # `inp` (unmasked), take their maximum along axis 0 and write the
    # scalar to `out`. M is a compile-time constant because it is used as
    # the extent of tl.arange.
    idx = tl.arange(0, M)
    vals = tl.load(inp + idx)
    tl.store(out, tl.max(vals, 0))

30 

31 

@libentry()
@libtuner(
    configs=cfggen_reduce_op(),
    key=["M"],
    strategy=["log"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        # True when BLOCK_SIZE * TOTAL_CORE_NUM elements are enough to cover M.
        "ONE_TILE_PER_CTA": lambda args: (
            args["M"] <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
        )
    }
)
@triton.jit
def max_kernel_float(
    inp, out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    # Each program reduces its share of the M input elements to one value
    # and merges it into the scalar at `out` with tl.atomic_max. Elements
    # past M are read as -inf so they never win the comparison.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    res = -float("inf")

    if ONE_TILE_PER_CTA:
        # One tile per program: a single masked load is enough.
        idx = block_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < M
        vals = tl.load(inp + idx, mask=in_bounds, other=-float("inf"))
        (res,) = tl.max(vals, 0, return_indices=True)
        tl.atomic_max(out, res)
    else:
        # Several tiles per program: advance by (number of programs *
        # BLOCK_SIZE) and keep an element-wise running maximum, then reduce
        # that vector once after the loop.
        n_programs = tl.num_programs(axis=0)
        stride = n_programs * BLOCK_SIZE
        running = tl.full(
            [BLOCK_SIZE], value=-float("inf"), dtype=inp.dtype.element_ty
        )
        for base in range(block_start, M, stride):
            idx = base + tl.arange(0, BLOCK_SIZE)
            in_bounds = idx < M
            vals = tl.load(inp + idx, mask=in_bounds, other=-float("inf"))
            running = tl.where(vals > running, vals, running)
        (res,) = tl.max(running, 0, return_indices=True)
        tl.atomic_max(out, res)

70 

71 

@libentry()
@libtuner(
    configs=cfggen_reduce_op(),
    key=["M"],
    strategy=["log"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        # True when BLOCK_SIZE * TOTAL_CORE_NUM elements are enough to cover M.
        "ONE_TILE_PER_CTA": lambda args: (
            args["M"] <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
        )
    }
)
@triton.jit
def max_kernel_int(
    inp, out, M, FILL_VALUE, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    # Integer counterpart of the float reduction: each program computes the
    # maximum of its share of the M elements and merges it into `out` with
    # tl.atomic_max. FILL_VALUE is substituted for out-of-range lanes.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    res = FILL_VALUE
    if ONE_TILE_PER_CTA:
        idx = block_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < M
        vals = tl.load(inp + idx, mask=in_bounds, other=FILL_VALUE)
        res = tl.max(vals)
    else:
        n_programs = tl.num_programs(axis=0)
        stride = n_programs * BLOCK_SIZE
        # Same value as computed above; kept as in the original kernel.
        block_start = pid * BLOCK_SIZE
        # The running maximum is held as int64 and starts at -(2**63).
        running = tl.full([BLOCK_SIZE], value=-(2**63), dtype=tl.int64)
        for base in range(block_start, M, stride):
            idx = base + tl.arange(0, BLOCK_SIZE)
            in_bounds = idx < M
            vals = tl.load(inp + idx, mask=in_bounds, other=FILL_VALUE)
            running = tl.where(vals > running, vals, running)
        res = tl.max(running)
    tl.atomic_max(out, res)

109 

110 

@libentry()
@libtuner(
    configs=cfggen_reduce_op(),
    key=["M"],
    strategy=["log"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        # True when BLOCK_SIZE * TOTAL_CORE_NUM elements are enough to cover M.
        "ONE_TILE_PER_CTA": lambda args: (
            args["M"] <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
        )
    }
)
@triton.jit
def max_kernel_int64_1(
    inp, mid, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    # First stage of the two-stage int64 reduction: every program writes the
    # maximum of its share of the input to mid[pid]. No atomics are used;
    # the partial results are combined by a second kernel.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    lowest = -(2**63)
    res = lowest
    if ONE_TILE_PER_CTA:
        idx = block_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < M
        vals = tl.load(inp + idx, mask=in_bounds, other=lowest)
        res = tl.max(vals)
    else:
        n_programs = tl.num_programs(axis=0)
        stride = n_programs * BLOCK_SIZE
        # Widen the loop start to int64 before iterating.
        block_start = block_start.to(tl.int64)
        running = tl.full([BLOCK_SIZE], value=lowest, dtype=tl.int64)
        for base in range(block_start, M, stride):
            idx = base + tl.arange(0, BLOCK_SIZE)
            in_bounds = idx < M
            vals = tl.load(inp + idx, mask=in_bounds, other=lowest)
            running = tl.where(vals > running, vals, running)
        res = tl.max(running)
    tl.store(mid + pid, res)

149 

150 

@libentry()
@triton.jit
def max_kernel_int64_2(mid, out, BLOCK_NUM: tl.constexpr):
    # Second stage of the int64 reduction: load BLOCK_NUM values from `mid`
    # (unmasked), reduce them to their maximum and store it at `out`.
    lanes = tl.arange(0, BLOCK_NUM)
    partials = tl.load(mid + lanes)
    tl.store(out, tl.max(partials))

158 

159 

def heur_block_n(args):
    """Return the smallest power of two that is >= ``args["N"]``."""
    n = args["N"]
    return triton.next_power_of_2(n)

162 

163 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("max"),
    key=["M", "N"],
    strategy=["log", "log"],
)
@triton.jit
def max_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    UPCAST: tl.constexpr = False,
):
    # Max-with-index along the middle axis of an input addressed as
    # (m, n, k) -> m * N * K + n * K + k. Program axis 0 selects a tile of
    # BLOCK_M rows, program axis 1 selects k. For each (m, k) the maximum
    # over n and the position of that maximum are written to
    # out_value / out_index at offset m * K + k.

    # Program ids are widened to int64 when UPCAST is set.
    if UPCAST:
        pid_m = tl.program_id(0).to(tl.int64)
        pid_k = tl.program_id(1).to(tl.int64)
    else:
        pid_m = tl.program_id(0)
        pid_k = tl.program_id(1)

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # NOTE(review): the running maximum is kept as float32 regardless of the
    # input element type -- confirm this is adequate for wide integer inputs.
    best_val = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
    best_idx = tl.zeros([BLOCK_M], dtype=tl.int64)
    # Value substituted for masked-out lanes.
    pad = get_dtype_min(inp.type.element_ty)

    for n_start in range(0, N, BLOCK_N):
        cols = n_start + tl.arange(0, BLOCK_N)
        flat = rows[:, None] * N * K + cols[None, :] * K + pid_k
        # Drop lanes beyond the M rows or the N columns.
        valid = rows[:, None] < M and cols[None, :] < N
        tile = tl.load(inp + flat, mask=valid, other=pad)
        tile_max, tile_arg = tl.max(tile, axis=1, return_indices=True)
        # Strict '>' keeps the earlier tile's result on ties.
        better = tile_max > best_val
        best_val = tl.where(better, tile_max, best_val)
        best_idx = tl.where(better, n_start + tile_arg, best_idx)

    row_valid = rows < M
    out_off = rows * K + pid_k

    tl.store(out_value + out_off, best_val, mask=row_valid)
    tl.store(out_index + out_off, best_idx, mask=row_valid)

214 

215 

def max(inp):
    """Return the maximum over all elements of ``inp`` as a 0-dim tensor.

    The input is made contiguous and reduced as a flat buffer of
    ``inp.numel()`` elements. The kernel path depends on the dtype:

    * floating point: a single-program kernel for at most 65536 elements,
      otherwise an atomic-max reduction into a float32 scalar;
    * ``torch.int64``: two kernels with an intermediate buffer of
      ``TOTAL_CORE_NUM`` partial results;
    * any other dtype: an atomic-max reduction into an int32 scalar, with
      ``torch.iinfo(dtype).min`` used for out-of-range lanes.

    The result is converted back to ``inp.dtype`` before it is returned.
    """
    logger.debug("GEMS_CAMBRICON MAX")
    inp = inp.contiguous()
    M = inp.numel()
    dtype, device = inp.dtype, inp.device
    mid_size = TOTAL_CORE_NUM

    def grid(meta):
        # Never launch more than mid_size programs.
        return (min(triton.cdiv(M, meta["BLOCK_SIZE"]), mid_size),)

    # NOTE(review): M == 0 is not special-cased here -- confirm callers
    # never pass an empty tensor.
    with torch_device_fn.device(inp.device):
        if torch.is_floating_point(inp):
            if M > 65536:
                out = torch.full(
                    [], float("-inf"), dtype=torch.float32, device=device
                )
                max_kernel_float[grid](inp, out, M)
            else:
                out = torch.empty([], dtype=dtype, device=device)
                max_kernel_float_once[(1, 1, 1)](inp, out, M)
        elif dtype == torch.int64:
            mid = torch.empty([mid_size], dtype=dtype, device=device)
            out = torch.empty([], dtype=dtype, device=device)
            # Because atomic op don't support i64, use two kernels.
            max_kernel_int64_1[(mid_size, 1, 1)](
                inp, mid, M, enable_soft_i64=True
            )
            max_kernel_int64_2[(1, 1, 1)](
                mid, out, BLOCK_NUM=mid_size, enable_soft_i64=True
            )
        else:
            fill_value = torch.iinfo(dtype).min
            out = torch.full([], -(2**31), dtype=torch.int32, device=device)
            max_kernel_int[grid](inp, out, M, fill_value)
    return out.to(dtype)

246 

247 

def max_dim(inp, dim=None, keepdim=False):
    """Return the maximum values and their indices along ``dim``.

    Args:
        inp: Input tensor; it is made contiguous before the kernel runs.
        dim: Dimension to reduce, in ``[-inp.ndim, inp.ndim)``. Required
            despite the ``None`` default.
        keepdim: If true, the reduced dimension is kept with size 1;
            otherwise it is squeezed out of both outputs.

    Returns:
        A ``max`` namedtuple ``(values, indices)``. ``values`` has
        ``inp.dtype`` and ``indices`` is ``torch.int64``.

    Raises:
        TypeError: If ``dim`` is ``None``.
        AssertionError: If ``dim`` is out of range.
        IndexError: If the reduced dimension has size zero.
    """
    logger.debug("GEMS_CAMBRICON MAX DIM")
    if dim is None:
        # Fail with a clear message instead of the opaque comparison error
        # that `None >= int` would raise on the next line.
        raise TypeError("max_dim() requires an integer dim, got None")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    if N == 0:
        # There is nothing to take a maximum of along an empty dimension.
        raise IndexError(
            f"max(): Expected reduction dim {dim} to have non-zero size."
        )
    M = math.prod(shape[:dim])
    # Product of the trailing sizes. Equal to numel // M // N whenever that
    # division is defined, but does not divide by zero when M == 0.
    K = math.prod(shape[dim + 1 :])

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_value = torch.empty(shape_list, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)

    if not keepdim:
        out_value = torch.squeeze(out_value, dim)
        out_index = torch.squeeze(out_index, dim)

    Max_out = namedtuple("max", ["values", "indices"])
    if M == 0 or K == 0:
        # Empty outputs: skip the launch rather than use a zero-sized grid.
        return Max_out(values=out_value, indices=out_index)

    # Ask the kernel for 64-bit program ids once this product reaches 2**31.
    UPCAST = inp.shape[0] * inp.stride(0) >= 1 << 31

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]), K)

    with torch_device_fn.device(inp.device):
        max_kernel[grid](inp, out_value, out_index, M, N, K, UPCAST=UPCAST)
    return Max_out(values=out_value, indices=out_index)

277 return out