Coverage for src/flag_gems/runtime/backend/_cambricon/ops/quantile.py: 0%

225 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6import triton.language.core as core 

7from torch import Tensor 

8 

9try: 

10 # TODO: Triton 2.1 does not implement _log2. 

11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton. 

12 from triton.language.standard import _log2, zeros_like 

13except ImportError: 

14 pass 

15from flag_gems.runtime import torch_device_fn 

16from flag_gems.utils import libentry, tl_extra_shim 

17from flag_gems.utils import triton_lang_extension as tle 

18 

19from ..utils import MAX_GRID_SIZE_X 

20from .topk import _get_finfo_val 

21 

# Module logger, namespaced under the shared "flag_gems" root logger.
# NOTE(review): the previous first assignment `logging.getLogger(__name__)`
# was dead code (immediately overwritten), so only the child logger is kept.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Interpolation modes accepted by `quantile` (mirrors torch.quantile).
INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"]
# Rows of length <= MAX_BITONIC_M are sorted in-kernel with the bitonic
# network; longer rows fall back to torch's sort plus the gather kernel.
MAX_BITONIC_M = 1024

"""
Note(Zhengzekang):
Refer from triton2.2 official `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
Just add indices to sort with values.
"""

34 

35 

@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """One compare-and-swap round of a bitonic sorting network, with indices.

    Pairs elements of `x` that are 2**(n_dims - i - 1) apart and conditionally
    swaps them (and the matching entries of `ids`) according to `flip`, which
    encodes the desired ordering per sub-sequence. Returns the updated
    (values, indices) pair.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    # View as [outer, 2, inner]: axis 1 separates the two halves being compared.
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    # tl.device_print("shape is: ", shape)
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # slice left/right with 'stride' 2**(n_dims - i - 1)
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    # Same left/right extraction for the carried indices.
    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # actual compare-and-swap: pick an integer type of the same bitwidth so the
    # swap can be expressed branchlessly with XOR on the bit patterns.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # Swap where the pair is out of order relative to `flip`:
    # x ^ (left ^ right) turns left into right and vice versa.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))

    # Repeat the bitwidth dispatch for the index dtype.
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    # Apply the identical swap decision (`cond`) to the indices.
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))

    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)

97 

98 

@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """One bitonic merge stage over `x` (with carried indices `ids`).

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        # Build the alternating 0/1 pattern with period 2**(stage+1).
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        # Uniform direction: 0 -> ascending, 1 -> descending.
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids

126 

127 

@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic sort of `x` along `dim`, permuting `ids` alongside the values.

    Returns (sorted_x, permuted_ids). The sorted extent must be a power of two
    (padding is the caller's responsibility).
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    # Stages 1..n_dims-1 merge with alternating order (2) to build bitonic
    # runs; the final stage merges in the requested direction.
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids

136 

137 

def heur_block_q(args):
    """Heuristic BLOCK_Q: ceil(Q / 8), capped at 16, rounded up to a power of 2."""
    capped = min(triton.cdiv(args["Q"], 8), 16)
    return triton.next_power_of_2(capped)

140 

141 

def heur_block_n(args):
    """Heuristic BLOCK_N chosen from the row count N (larger N -> larger blocks)."""
    n = args["N"]
    if n >= 65536:
        return triton.next_power_of_2(triton.cdiv(n, 512))
    if n >= 4096:
        return triton.next_power_of_2(triton.cdiv(n, 128))
    if n >= 64:
        return 32
    return 4 if n >= 32 else 1

153 

154 

@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,  # (N, M) rows, assumed pre-sorted along the last dim by the caller
    q,    # (Q,) quantile values in [0, 1]
    out,  # (N, Q) output
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,
):
    """Gather/interpolate quantiles from rows that are already sorted.

    Launched on a 2-D grid (cdiv(Q, BLOCK_Q), cdiv(N, BLOCK_N)); each program
    writes a BLOCK_N x BLOCK_Q tile of the (N, Q) output.
    """
    pid_Q = tle.program_id(0)
    pid_N = tle.program_id(1)
    ctype = inp.dtype.element_ty

    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q

    offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_N = offsets_N < N

    out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]

    # Scale q from [0, 1] to a fractional index in [0, M - 1]; the two
    # bracketing integer indices select the neighbors to interpolate between.
    q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1)
    q_lower = tl.floor(q_block).to(tl.int32)
    q_upper = tl.ceil(q_block).to(tl.int32)

    inp_lower = tl.load(
        inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
    )
    inp_upper = tl.load(
        inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
    )

    # `interpolation` is a constexpr, so only the selected branch is compiled.
    if interpolation == "linear":
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)

    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)

    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)

    elif interpolation == "nearest":
        # Round-to-nearest(-even, per rint) decides which neighbor to take.
        q_round = tl_extra_shim.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)

    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)

211 

212 

@libentry()
@triton.jit
def quantile_bitonic_kernel(
    inp,  # (N, M) unsorted rows
    q,    # (Q,) quantile values in [0, 1]
    out,  # (N, Q) output
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,  # power of two >= min(Q, 16) per the launcher
    BLOCK_M: tl.constexpr,  # power of two >= M; whole row fits in one block
    interpolation: tl.constexpr,
):
    """Compute quantiles of unsorted rows by sorting each row in-kernel.

    Programs stride over rows (pid, pid + num_programs, ...). For each row the
    whole M elements are loaded, sorted with the bitonic `argsort`, and the
    requested quantiles are gathered/interpolated from the sorted values.
    """
    pid = tle.program_id(0)
    grid_0 = tl.num_programs(0)
    ctype = inp.dtype.element_ty

    while pid < N:
        cols = tl.arange(0, BLOCK_M)
        mask_M = cols < M
        row_ptr = inp + pid * M
        # Pad out-of-range lanes with the dtype's max value so they sort to
        # the end and never get selected (q indices stay < M).
        mask_val = _get_finfo_val(ctype, return_max=True)
        vals = tl.load(row_ptr + cols, mask=mask_M, other=mask_val)
        # NOTE(review): appears intended to upcast non-fp64 inputs to fp32
        # before sorting; confirm tl.where on a dtype predicate folds at
        # compile time as expected.
        vals = tl.where(vals.dtype.is_fp64(), vals, vals.to(tl.float32))
        ids = tl.arange(0, BLOCK_M)
        sorted_vals, _ = argsort(vals, ids, 0, descending=False)

        offsets_Q = tl.arange(0, BLOCK_Q)
        mask_Q = offsets_Q < Q
        q_vals = tl.load(q + offsets_Q, mask=mask_Q, other=0.0).to(tl.float32)
        # Scale q from [0, 1] to a fractional index in [0, M - 1].
        q_scaled = q_vals * (M - 1)
        q_lower = tl.floor(q_scaled).to(tl.int32)
        q_upper = tl.ceil(q_scaled).to(tl.int32)

        # Gather sorted_vals[q_lower] / sorted_vals[q_upper] via one-hot
        # masked sums (no dynamic indexing needed).
        idx = tl.arange(0, BLOCK_M)[:, None]
        mask_lower = idx == q_lower[None, :]
        mask_upper = idx == q_upper[None, :]
        mask_lower_f = mask_lower.to(tl.float32)
        mask_upper_f = mask_upper.to(tl.float32)
        lower_vals = tl.sum(sorted_vals[:, None] * mask_lower_f, axis=0)
        upper_vals = tl.sum(sorted_vals[:, None] * mask_upper_f, axis=0)

        # `interpolation` is a constexpr; only the selected branch is compiled.
        if interpolation == "linear":
            q_frac = q_scaled - q_lower
            out_vals = lower_vals + (upper_vals - lower_vals) * q_frac
        elif interpolation == "lower":
            out_vals = lower_vals
        elif interpolation == "higher":
            out_vals = upper_vals
        elif interpolation == "nearest":
            q_round = tl_extra_shim.rint(q_scaled).to(tl.int32)
            out_vals = tl.where(q_round == q_upper, upper_vals, lower_vals)
        elif interpolation == "midpoint":
            out_vals = (lower_vals + upper_vals) * 0.5

        out_ptr = out + pid * Q + offsets_Q
        tl.store(out_ptr, out_vals.to(ctype), mask=mask_Q)
        pid += grid_0

272 

def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    """torch.quantile-style quantile computation along `dim`.

    Args:
        inp: floating-point input tensor.
        q: quantile(s) in [0, 1]; a float or a tensor of quantiles.
        dim: dimension to reduce; None flattens the input first.
        keepdim: keep the reduced dimension as size 1.
        interpolation: one of INTERPOLATION_METHOD.
        out: optional output tensor, filled via copy_ when provided.

    Returns:
        Tensor of quantiles; with Q > 1 quantiles the result gains a leading
        dimension of size Q.
    """
    logger.debug("GEMS_CAMBRICON QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD

    # Handle dim
    if dim is None:
        inp = inp.ravel()
        dim = 0
    if dim < 0:
        dim = dim + inp.ndim

    # Handle q: normalize to a tensor on the input's device/dtype and record
    # whether the scalar fast path (q exactly 0.0 or 1.0) applies.
    q_all_ones = False
    q_all_zeros = False
    if isinstance(q, float):
        q_all_ones = q == 1.0
        q_all_zeros = q == 0.0
        q = torch.tensor(q, device=inp.device, dtype=inp.dtype)
        Q = 1
    else:
        q = q.to(device=inp.device, dtype=inp.dtype)
        # NOTE(review): len(q) assumes q is 1-D when it has > 1 element.
        Q = 1 if q.numel() == 1 else len(q)

    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    # Fast path: q == 0.0 -> min, q == 1.0 -> max (no sort needed)
    if q_all_ones or q_all_zeros:
        reduce_fn = torch.amax if q_all_ones else torch.amin
        if out is not None and Q == 1:
            reduce_fn(inp, dim=dim, keepdim=keepdim, out=out)
            return out
        output = reduce_fn(inp, dim=dim, keepdim=keepdim)
        if Q > 1:
            # Replicate the single reduction across the leading Q dimension.
            output = output.unsqueeze(0).expand(Q, *output.shape)
        if out is not None:
            out.copy_(output)
            return out
        return output

    # handle input tensor: move the reduced dim to the end so rows are
    # contiguous along the sorted/gathered axis.
    if dim != inp.ndim - 1:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()

    M = inp.size(-1)          # row length (reduced extent)
    N = inp.numel() // M      # number of rows

    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)
    if M <= MAX_BITONIC_M:
        # Short rows: sort each row inside the kernel (bitonic network).
        BLOCK_M = triton.next_power_of_2(M)
        BLOCK_Q = triton.next_power_of_2(min(Q, 16))
        grid = min(N, MAX_GRID_SIZE_X // 4)
        with torch_device_fn.device(inp.device):
            quantile_bitonic_kernel[(grid,)](
                inp,
                q,
                output,
                N,
                M,
                Q,
                BLOCK_Q=BLOCK_Q,
                BLOCK_M=BLOCK_M,
                interpolation=interpolation,
            )
    else:
        # Long rows: sort with torch, then gather/interpolate in the kernel.
        sorted_vals, _ = inp.sort(dim=-1)
        grid = lambda meta: (
            triton.cdiv(Q, meta["BLOCK_Q"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )
        with torch_device_fn.device(inp.device):
            quantile_kernel[grid](
                sorted_vals, q, output, N, M, Q, interpolation=interpolation
            )

    # Reshape to torch.quantile's convention: drop the Q dim when scalar,
    # otherwise move it to the front.
    if Q == 1:
        output = output.squeeze(-1)
    else:
        output = output.movedim(-1, 0)
    if keepdim:
        # With Q > 1 the leading quantile dim shifts the original dim by one.
        output = output.unsqueeze(dim + (1 if Q != 1 else 0))

    if out is not None:
        out.copy_(output)
    # NOTE(review): unlike the fast path above, this returns `output` rather
    # than `out` after the copy — confirm this asymmetry is intended.
    return output