Coverage for src/flag_gems/runtime/backend/_cambricon/ops/quantile.py: 0%
225 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6import triton.language.core as core
7from torch import Tensor
9try:
10 # TODO: Triton 2.1 does not implement _log2.
11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton.
12 from triton.language.standard import _log2, zeros_like
13except ImportError:
14 pass
15from flag_gems.runtime import torch_device_fn
16from flag_gems.utils import libentry, tl_extra_shim
17from flag_gems.utils import triton_lang_extension as tle
19from ..utils import MAX_GRID_SIZE_X
20from .topk import _get_finfo_val
# Child logger under the "flag_gems" root so handler/level configuration on the
# package logger applies here too. (A previous duplicate
# `logging.getLogger(__name__)` assignment was dead code and has been removed.)
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Interpolation modes accepted by `quantile`, mirroring torch.quantile.
INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"]
# Rows no longer than this are sorted in-kernel with the bitonic network;
# longer rows are sorted on the host instead.
MAX_BITONIC_M = 1024
"""
Note(Zhengzekang):
Adapted from the official Triton 2.2 `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
The only change is that an index tensor is carried along and permuted together
with the values, so the sort can also be used as an argsort.
"""
@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """One compare-and-swap round of a bitonic sort, carrying indices along.

    Elements 2**(n_dims - i - 1) apart are paired; each pair is put in
    ascending order where ``flip`` is 0 and descending where ``flip`` is 1
    (``flip`` may be a broadcastable tensor for alternating order).  ``ids``
    receives exactly the same permutation as ``x`` so the caller can recover
    an argsort.  Returns the permuted ``(x, ids)``.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]
    # tl.device_print("shape is: ", shape)
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    # (a masked sum along the middle axis extracts one lane of each pair
    # without needing a gather)
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)
    # Same extraction for the index payload.
    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)
    # actual compare-and-swap
    # Pick an integer type of the same bit width so the conditional swap can
    # be done branch-free on raw bits: x ^ (left ^ right) swaps the pair.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")
    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)
    # True where the pair is out of order for the requested direction.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))
    # Width-matched integer view for the index payload; apply the SAME `cond`
    # so indices move exactly with their values.
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")
    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))
    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)
@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """One merge stage of the bitonic sorter; indices travel with values.

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids
@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic sort of ``x`` along ``dim``, permuting ``ids`` identically.

    ``x.shape[dim]`` must be a power of two (``_log2`` assumes this).
    Returns the sorted values and the correspondingly permuted ids.
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    # log2(len) merge stages: intermediate stages use alternating order (2)
    # to build bitonic sequences; the final stage applies the requested order.
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids
def heur_block_q(args):
    """Heuristic BLOCK_Q: ~Q/8 quantiles per tile, capped at 16, power of two."""
    per_block = triton.cdiv(args["Q"], 8)
    capped = 16 if per_block > 16 else per_block
    return triton.next_power_of_2(capped)
def heur_block_n(args):
    """Heuristic BLOCK_N: larger N gets proportionally larger row tiles."""
    n = args["N"]
    # Very wide problems: split N into ~512 tiles.
    if n >= 65536:
        return triton.next_power_of_2(triton.cdiv(n, 512))
    # Medium problems: ~128 tiles.
    if n >= 4096:
        return triton.next_power_of_2(triton.cdiv(n, 128))
    if n >= 64:
        return 32
    return 4 if n >= 32 else 1
@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,
):
    """Interpolate quantiles from pre-sorted rows.

    inp: (N, M) rows, each already sorted ascending (the host wrapper sorts
    before launching this kernel).  q: Q quantile fractions in [0, 1].
    out: (N, Q).  2-D grid: axis 0 tiles Q, axis 1 tiles N.
    """
    pid_Q = tle.program_id(0)
    pid_N = tle.program_id(1)
    ctype = inp.dtype.element_ty
    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q
    offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_N = offsets_N < N
    out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]
    # Fractional index of each quantile inside a sorted row of length M.
    q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1)
    q_lower = tl.floor(q_block).to(tl.int32)
    q_upper = tl.ceil(q_block).to(tl.int32)
    # Neighbouring sorted values bracketing each quantile position.
    inp_lower = tl.load(
        inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
    )
    inp_upper = tl.load(
        inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
    )
    # `interpolation` is a constexpr string, so exactly one branch is compiled.
    if interpolation == "linear":
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)
    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)
    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)
    elif interpolation == "nearest":
        q_round = tl_extra_shim.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)
    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)
@libentry()
@triton.jit
def quantile_bitonic_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_M: tl.constexpr,
    interpolation: tl.constexpr,
):
    """Quantiles for short rows: sort each row in-kernel, then interpolate.

    inp: (N, M) unsorted rows; q: Q quantile fractions in [0, 1];
    out: (N, Q).  All Q quantiles are produced in one tile, so the launch
    must provide BLOCK_Q >= Q; BLOCK_M must be a power of two >= M for
    the bitonic network.
    """
    pid = tle.program_id(0)
    grid_0 = tl.num_programs(0)
    ctype = inp.dtype.element_ty
    # Grid-stride loop over rows: program `pid` handles rows pid, pid+grid_0, ...
    while pid < N:
        cols = tl.arange(0, BLOCK_M)
        mask_M = cols < M
        row_ptr = inp + pid * M
        # Pad out-of-range lanes with the dtype max so padding sorts to the
        # end and never lands below a valid quantile index.
        mask_val = _get_finfo_val(ctype, return_max=True)
        vals = tl.load(row_ptr + cols, mask=mask_M, other=mask_val)
        # Keep fp64 as-is; promote narrower floats to fp32 for the sort.
        vals = tl.where(vals.dtype.is_fp64(), vals, vals.to(tl.float32))
        ids = tl.arange(0, BLOCK_M)
        sorted_vals, _ = argsort(vals, ids, 0, descending=False)
        offsets_Q = tl.arange(0, BLOCK_Q)
        mask_Q = offsets_Q < Q
        q_vals = tl.load(q + offsets_Q, mask=mask_Q, other=0.0).to(tl.float32)
        # Fractional index of each quantile inside the sorted row.
        q_scaled = q_vals * (M - 1)
        q_lower = tl.floor(q_scaled).to(tl.int32)
        q_upper = tl.ceil(q_scaled).to(tl.int32)
        # One-hot masks pick sorted_vals[q_lower] / sorted_vals[q_upper]
        # via a masked sum (gather-free indexing).
        idx = tl.arange(0, BLOCK_M)[:, None]
        mask_lower = idx == q_lower[None, :]
        mask_upper = idx == q_upper[None, :]
        mask_lower_f = mask_lower.to(tl.float32)
        mask_upper_f = mask_upper.to(tl.float32)
        lower_vals = tl.sum(sorted_vals[:, None] * mask_lower_f, axis=0)
        upper_vals = tl.sum(sorted_vals[:, None] * mask_upper_f, axis=0)
        # `interpolation` is a constexpr string: one branch survives compilation.
        if interpolation == "linear":
            q_frac = q_scaled - q_lower
            out_vals = lower_vals + (upper_vals - lower_vals) * q_frac
        elif interpolation == "lower":
            out_vals = lower_vals
        elif interpolation == "higher":
            out_vals = upper_vals
        elif interpolation == "nearest":
            q_round = tl_extra_shim.rint(q_scaled).to(tl.int32)
            out_vals = tl.where(q_round == q_upper, upper_vals, lower_vals)
        elif interpolation == "midpoint":
            out_vals = (lower_vals + upper_vals) * 0.5
        out_ptr = out + pid * Q + offsets_Q
        tl.store(out_ptr, out_vals.to(ctype), mask=mask_Q)
        pid += grid_0
def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    """Compute the q-th quantile(s) of ``inp`` along ``dim``.

    Args:
        inp: floating-point input tensor.
        q: quantile fraction(s) in [0, 1]; a python float or a 1-D tensor.
        dim: dimension to reduce; ``None`` flattens the input first.
        keepdim: keep the reduced dimension with size 1.
        interpolation: one of INTERPOLATION_METHOD (as in ``torch.quantile``).
        out: optional pre-allocated output tensor; filled and returned when given.

    Returns:
        Tensor of quantiles. When ``q`` has Q > 1 entries the quantile axis is
        the leading dimension, matching ``torch.quantile``.
    """
    logger.debug("GEMS_CAMBRICON QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD

    # Handle dim: None means "flatten, then reduce dim 0".
    if dim is None:
        inp = inp.ravel()
        dim = 0
    if dim < 0:
        dim = dim + inp.ndim

    # Handle q: normalize to a tensor on the input's device/dtype and record
    # whether the (scalar) quantile is exactly 0.0 or 1.0.
    q_all_ones = False
    q_all_zeros = False
    if isinstance(q, float):
        q_all_ones = q == 1.0
        q_all_zeros = q == 0.0
        q = torch.tensor(q, device=inp.device, dtype=inp.dtype)
        Q = 1
    else:
        q = q.to(device=inp.device, dtype=inp.dtype)
        Q = 1 if q.numel() == 1 else len(q)

    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    # Fast path: q == 0.0 -> min, q == 1.0 -> max (no sort needed)
    if q_all_ones or q_all_zeros:
        reduce_fn = torch.amax if q_all_ones else torch.amin
        if out is not None and Q == 1:
            reduce_fn(inp, dim=dim, keepdim=keepdim, out=out)
            return out
        output = reduce_fn(inp, dim=dim, keepdim=keepdim)
        if Q > 1:
            output = output.unsqueeze(0).expand(Q, *output.shape)
        if out is not None:
            out.copy_(output)
            return out
        return output

    # Handle the input tensor: move the reduced dim innermost so each
    # length-M row is contiguous.
    if dim != inp.ndim - 1:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()

    M = inp.size(-1)
    N = inp.numel() // M

    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)
    if M <= MAX_BITONIC_M:
        # Short rows: sort in-kernel with a bitonic network.
        BLOCK_M = triton.next_power_of_2(M)
        # BUGFIX: the bitonic kernel emits all Q quantiles in a single
        # tl.arange(0, BLOCK_Q) tile (there is no grid axis over Q), so
        # BLOCK_Q must cover Q entirely. The previous
        # `next_power_of_2(min(Q, 16))` left quantiles past the 16th
        # unwritten (uninitialized torch.empty memory) whenever Q > 16.
        BLOCK_Q = triton.next_power_of_2(Q)
        grid = min(N, MAX_GRID_SIZE_X // 4)
        with torch_device_fn.device(inp.device):
            quantile_bitonic_kernel[(grid,)](
                inp,
                q,
                output,
                N,
                M,
                Q,
                BLOCK_Q=BLOCK_Q,
                BLOCK_M=BLOCK_M,
                interpolation=interpolation,
            )
    else:
        # Long rows: sort once on the host, then interpolate in the kernel.
        sorted_vals, _ = inp.sort(dim=-1)
        grid = lambda meta: (
            triton.cdiv(Q, meta["BLOCK_Q"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )
        with torch_device_fn.device(inp.device):
            quantile_kernel[grid](
                sorted_vals, q, output, N, M, Q, interpolation=interpolation
            )

    # Match torch.quantile's layout: drop the quantile axis when Q == 1,
    # otherwise move it to the front. For keepdim, re-insert the reduced dim
    # at its original position (shifted by 1 when the Q axis leads).
    if Q == 1:
        output = output.squeeze(-1)
    else:
        output = output.movedim(-1, 0)
    if keepdim:
        output = output.unsqueeze(dim + (1 if Q != 1 else 0))

    if out is not None:
        out.copy_(output)
        # Return the caller-provided tensor, consistent with the fast path
        # above and with torch's out= convention (previously returned the
        # freshly allocated `output` instead).
        return out
    return output