Coverage for src/flag_gems/runtime/backend/_cambricon/ops/min.py: 0%

170 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2import math 

3from collections import namedtuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils.limits import get_dtype_max 

13 

14from ..utils import TOTAL_CORE_NUM, cfggen_reduce_op, prune_reduce_config 

15 

# Module logger: a child of the "flag_gems" logger, named after this module
# with any leading dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip("."),
)

17 

18 

@libentry()
@triton.jit
def min_kernel_float_once(
    inp,
    out,
    M: tl.constexpr,
):
    """Reduce ``inp[0:M]`` to its minimum and store the scalar at ``out``.

    ``M`` is a compile-time constant, so a single ``tl.arange`` covers the
    whole input and the load is issued without a mask.
    """
    values = tl.load(inp + tl.arange(0, M))
    tl.store(out, tl.min(values, 0))

30 

31 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    }
)
@triton.jit
def min_kernel_float(
    inp, out, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Fold this program's partial minimum of ``inp[0:M]`` into ``out``.

    Each program reduces the elements it owns to one value and merges it into
    ``out`` with ``tl.atomic_min``; ``out`` must therefore already hold a value
    no smaller than the true minimum before launch.

    * ``ONE_TILE_PER_CTA`` (heuristic: ``M <= BLOCK_SIZE * TOTAL_CORE_NUM``):
      the program reduces the single ``BLOCK_SIZE`` tile starting at
      ``pid * BLOCK_SIZE``.
    * otherwise: the program walks the input from ``pid * BLOCK_SIZE`` in
      steps of ``num_programs * BLOCK_SIZE``, keeping a per-lane running
      minimum that is reduced once after the loop.

    Out-of-range lanes are loaded as ``+inf`` so they never win.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    res = float("inf")
    if ONE_TILE_PER_CTA:
        offset = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offset < M
        inp_val = tl.load(inp + offset, mask=mask, other=float("inf"))
        # return_indices=True yields (value, index); unpack both explicitly
        # instead of relying on a length-mismatched one-element target.
        res, _min_idx = tl.min(inp_val, 0, return_indices=True)
    else:
        num_jobs = tl.num_programs(axis=0)
        step = num_jobs * BLOCK_SIZE
        _tmp = tl.full([BLOCK_SIZE], value=float("inf"), dtype=inp.dtype.element_ty)
        for off in range(block_start, M, step):
            offset = off + tl.arange(0, BLOCK_SIZE)
            mask = offset < M
            inp_val = tl.load(inp + offset, mask=mask, other=float("inf"))
            # Element-wise running minimum across the tiles this program visits.
            _tmp = tl.where((inp_val < _tmp), inp_val, _tmp)
        res, _min_idx = tl.min(_tmp, 0, return_indices=True)
    tl.atomic_min(out, res)

67 

68 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    }
)
@triton.jit
def min_kernel_int(
    inp, out, FILL_VALUE, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """Fold this program's partial integer minimum of ``inp[0:M]`` into ``out``.

    Masked-off lanes are loaded as ``FILL_VALUE``. The partial result is
    merged into ``out`` with ``tl.atomic_min``. When ``ONE_TILE_PER_CTA`` is
    false the program strides over the input by ``num_programs * BLOCK_SIZE``
    and keeps an int32 per-lane running minimum seeded with ``2**31 - 1``.
    """
    pid = tl.program_id(0)
    lanes = tl.arange(0, BLOCK_SIZE)
    first = pid * BLOCK_SIZE
    res = FILL_VALUE
    if ONE_TILE_PER_CTA:
        # Exactly one tile per program: masked load, then reduce.
        idx = first + lanes
        tile = tl.load(inp + idx, mask=idx < M, other=FILL_VALUE)
        res = tl.min(tile)
    else:
        stride = tl.num_programs(axis=0) * BLOCK_SIZE
        # Widen the start so the loop offsets are computed in 64 bits.
        first = first.to(tl.int64)
        running = tl.full([BLOCK_SIZE], value=2**31 - 1, dtype=tl.int32)
        for base in range(first, M, stride):
            idx = base + lanes
            tile = tl.load(inp + idx, mask=idx < M, other=FILL_VALUE)
            running = tl.where((tile < running), tile, running)
        res = tl.min(running)
    tl.atomic_min(out, res)

105 

106 

@libentry()
@triton.autotune(
    configs=cfggen_reduce_op(),
    key=["M"],
    prune_configs_by={"early_config_prune": prune_reduce_config},
)
@triton.heuristics(
    values={
        "ONE_TILE_PER_CTA": lambda args: args["M"]
        <= args["BLOCK_SIZE"] * TOTAL_CORE_NUM
    }
)
@triton.jit
def min_kernel_int64_1(
    inp, mid, M, BLOCK_SIZE: tl.constexpr, ONE_TILE_PER_CTA: tl.constexpr
):
    """First int64 pass: write each program's partial minimum to ``mid[pid]``.

    Masked-off lanes and the running minimum both start from the largest
    signed 64-bit value, so they never win the reduction. Unlike the other
    partial-reduction kernels in this file, the result is stored per program
    rather than merged atomically.
    """
    pid = tl.program_id(0)
    lanes = tl.arange(0, BLOCK_SIZE)
    first = pid * BLOCK_SIZE
    # Maximum signed 64-bit integer, used as the neutral element for min.
    i64_max = 2**63 - 1
    res = i64_max
    if ONE_TILE_PER_CTA:
        idx = first + lanes
        tile = tl.load(inp + idx, mask=idx < M, other=i64_max)
        res = tl.min(tile)
    else:
        stride = tl.num_programs(axis=0) * BLOCK_SIZE
        first = first.to(tl.int64)
        running = tl.full([BLOCK_SIZE], value=i64_max, dtype=tl.int64)
        for base in range(first, M, stride):
            idx = base + lanes
            tile = tl.load(inp + idx, mask=idx < M, other=i64_max)
            running = tl.where((tile < running), tile, running)
        res = tl.min(running)
    tl.store(mid + pid, res)

145 

146 

@libentry()
@triton.jit
def min_kernel_int64_2(mid, out, BLOCK_NUM: tl.constexpr):
    """Second int64 pass: reduce ``mid[0:BLOCK_NUM]`` to one value at ``out``.

    All ``BLOCK_NUM`` entries are loaded unmasked, so every slot of ``mid``
    must have been written beforehand.
    """
    partials = tl.load(mid + tl.arange(0, BLOCK_NUM))
    tl.store(out, tl.min(partials))

154 

155 

def heur_block_n(args):
    """Return ``triton.next_power_of_2`` applied to the ``"N"`` entry of ``args``."""
    n = args["N"]
    return triton.next_power_of_2(n)

158 

159 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("min"),
    key=[
        "M",
        "N",
    ],
)
@triton.jit
def min_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Minimum and argmin along the middle axis of a buffer addressed as (M, N, K).

    Program axis 0 selects a tile of ``BLOCK_M`` rows and program axis 1
    selects the ``K`` index. Element ``(m, n, k)`` is read from
    ``inp + m * N * K + n * K + k``; results for row ``m`` are written to
    ``out_value + m * K + k`` and ``out_index + m * K + k``.

    The reduction axis is scanned in ``BLOCK_N``-wide tiles. A later tile
    replaces the running result only when it is strictly smaller, so on ties
    the earlier tile's index is kept.
    """
    row_tile = tl.program_id(0)
    k_idx = tl.program_id(1)
    rows = row_tile * BLOCK_M + tl.arange(0, BLOCK_M)

    # NOTE(review): the running minimum is held in float32 for every input
    # dtype - confirm this is acceptable for wide integer inputs.
    best_val = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf"))
    best_idx = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    # Masked-off lanes are loaded as the input dtype's maximum.
    pad_value = get_dtype_max(inp.type.element_ty)
    for n_start in range(0, N, BLOCK_N):
        cols = n_start + tl.arange(0, BLOCK_N)
        flat = rows[:, None] * N * K + cols[None, :] * K + k_idx
        valid = rows[:, None] < M and cols[None, :] < N
        tile = tl.load(inp + flat, mask=valid, other=pad_value)
        # Should return_indices be unavailable, tl.argmin(tile, 1) could
        # supply the index separately.
        tile_min, tile_argmin = tl.min(tile, 1, return_indices=True)
        improved = tile_min < best_val
        best_val = tl.where(improved, tile_min, best_val)
        # tile_argmin is relative to the tile; shift it to a position along N.
        best_idx = tl.where(improved, n_start + tile_argmin, best_idx)

    out_pos = rows * K + k_idx
    row_ok = rows < M
    tl.store(out_value + out_pos, best_val, mask=row_ok)
    tl.store(out_index + out_pos, best_idx, mask=row_ok)

206 

207 

def min(inp):
    """Return the minimum over all elements of ``inp`` as a 0-dim tensor.

    Dispatch by dtype:

    * floating point: one single-program kernel when ``numel() <= 65536``;
      otherwise ``TOTAL_CORE_NUM`` programs merge partial minima with an
      atomic min into a float32 scalar initialised to ``+inf``.
    * ``torch.int64``: two kernels (per-program partials into ``mid``, then a
      final reduction), because the atomic op does not support i64.
    * any other dtype: partial minima merged with an atomic min into an int32
      scalar initialised to ``2**31 - 1``; masked lanes use
      ``torch.iinfo(dtype).max``.

    Args:
        inp: Input tensor. It is packed with ``contiguous()`` first because
            every kernel addresses it as a flat buffer of ``numel()`` elements.

    Returns:
        A 0-dim tensor converted to ``inp``'s dtype.
    """
    logger.debug("GEMS_CAMBRICON MIN")
    # A strided view would otherwise be read straight from its underlying
    # storage, picking up elements that are not part of the view. This is a
    # no-op for tensors that are already contiguous.
    inp = inp.contiguous()
    M = inp.numel()
    mid_size = TOTAL_CORE_NUM
    dtype = inp.dtype
    device = inp.device

    with torch_device_fn.device(device):
        if torch.is_floating_point(inp):
            if M <= 65536:
                out = torch.empty([], dtype=dtype, device=device)
                min_kernel_float_once[(1, 1, 1)](inp, out, M)
            else:
                # Seed with +inf so the first atomic min always wins.
                out = torch.full([], float("inf"), dtype=torch.float32, device=device)
                min_kernel_float[(mid_size, 1, 1)](inp, out, M)
        elif dtype == torch.int64:
            mid = torch.empty([mid_size], dtype=dtype, device=device)
            out = torch.empty([], dtype=dtype, device=device)
            # Because the atomic op doesn't support i64, use two kernels.
            min_kernel_int64_1[(mid_size, 1, 1)](inp, mid, M, enable_soft_i64=True)
            min_kernel_int64_2[(1, 1, 1)](
                mid, out, BLOCK_NUM=mid_size, enable_soft_i64=True
            )
        else:
            fill_value = torch.iinfo(dtype).max
            out = torch.full([], 2**31 - 1, dtype=torch.int32, device=device)
            min_kernel_int[(mid_size, 1, 1)](inp, out, fill_value, M)
    return out.to(dtype)

236 

237 

def min_dim(inp, dim=None, keepdim=False):
    """Return the minimum values and their indices along ``dim``.

    The input is viewed as ``(M, N, K)``, where ``N`` is the size of ``dim``,
    ``M`` the product of the sizes before it and ``K`` the remaining factor of
    ``numel()``. One kernel launch writes both outputs.

    Args:
        inp: Input tensor; a contiguous copy is used for the kernel.
        dim: Dimension to reduce; negative values count from the end. Must
            satisfy ``-inp.ndim <= dim < inp.ndim``.
        keepdim: When false, the reduced dimension is squeezed out of both
            outputs.

    Returns:
        A named tuple ``(values, indices)``; ``values`` has ``inp``'s dtype and
        ``indices`` is ``torch.int64``.
    """
    logger.debug("GEMS_CAMBRICON MIN DIM")
    ndim = inp.ndim
    assert dim >= -ndim and dim < ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % ndim

    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()

    # Output shape: same as the input with the reduced dimension set to 1.
    reduced_shape = [1 if axis == dim else size for axis, size in enumerate(shape)]
    out_value = torch.empty(reduced_shape, dtype=inp.dtype, device=inp.device)
    out_index = torch.empty(reduced_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_value = out_value.squeeze(dim)
        out_index = out_index.squeeze(dim)

    def grid(meta):
        # One program per (BLOCK_M row tile, K index) pair.
        return (triton.cdiv(M, meta["BLOCK_M"]), K)

    with torch_device_fn.device(inp.device):
        min_kernel[grid](inp, out_value, out_index, M, N, K)

    Min_out = namedtuple("min", ["values", "indices"])
    return Min_out(values=out_value, indices=out_index)