Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/quantile.py: 0%
83 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry
10from flag_gems.utils import triton_lang_extension as tle
# Child logger under the shared "flag_gems" root so log configuration is centralized.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Interpolation modes accepted by quantile(); mirrors torch.quantile's options.
INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"]
16# def heur_block_q(args):
17# return triton.next_power_of_2(min(triton.cdiv(args["Q"], 8), 16))
20# def heur_block_n(args):
21# if args["N"] >= 65536:
22# return triton.next_power_of_2(triton.cdiv(args["N"], 512))
23# elif args["N"] >= 4096:
24# return triton.next_power_of_2(triton.cdiv(args["N"], 128))
25# elif args["N"] >= 64:
26# return 32
27# elif args["N"] >= 32:
28# return 4
29# else:
30# return 1
def heur_block_q(args):
    """Heuristic tile width along the quantile axis.

    Returns the next power of two >= Q, capped at 1024 so a single tile
    never exceeds a reasonable register budget.

    Args:
        args: heuristics dict supplied by @triton.heuristics; reads "Q".
    """
    # The function-local `import builtins` + `builtins.min` of the original was
    # redundant: nothing in this module shadows `min`, so the built-in is
    # directly reachable.
    return min(triton.next_power_of_2(args["Q"]), 1024)
def heur_block_n(args):
    """Heuristic tile width along the row axis.

    Returns the next power of two >= N, capped at 1024, matching the
    shape of heur_block_q.

    Args:
        args: heuristics dict supplied by @triton.heuristics; reads "N".
    """
    # Redundant `import builtins` / `builtins.min` removed: `min` is not
    # shadowed anywhere in this module.
    return min(triton.next_power_of_2(args["N"]), 1024)
@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,  # pointer: input sorted along its last axis, logical shape (N, M)
    q,  # pointer: Q quantile fractions, each expected in [0, 1]
    out,  # pointer: result buffer, logical shape (N, Q)
    N,  # number of independent rows
    M,  # length of the sorted reduction axis
    Q,  # number of quantiles requested
    BLOCK_Q: tl.constexpr,  # tile width along the quantile axis
    BLOCK_N: tl.constexpr,  # tile width along the row axis
    interpolation: tl.constexpr,  # mode string; branches specialize at compile time
):
    # Each program instance fills one (BLOCK_N x BLOCK_Q) tile of the (N, Q) output.
    pid_Q = tle.program_id(0)
    pid_N = tle.program_id(1)
    ctype = inp.dtype.element_ty

    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q

    offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_N = offsets_N < N

    # Row-major output: out[n, qi] lives at n * Q + qi.
    out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]

    # Scale each fraction q to a fractional index in [0, M - 1].
    # Masked-off quantile lanes load 0.0 -> index 0, which is in bounds and
    # discarded by mask_out on store.
    q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1)
    q_lower = tl.floor(q_block).to(tl.int32)
    q_upper = tl.ceil(q_block).to(tl.int32)

    # Gather the two neighboring sorted values. Only the row dimension needs
    # masking: q_lower/q_upper already lie in [0, M - 1] when q is in [0, 1].
    inp_lower = tl.load(
        inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
    )
    inp_upper = tl.load(
        inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
    )

    if interpolation == "linear":
        # Blend the two neighbors by the fractional part of the scaled index.
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)

    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)

    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)

    elif interpolation == "nearest":
        # XPU (kunlunxin) libdevice intrinsic; presumably rounds half-to-even
        # like C rint — TODO confirm against torch.quantile's "nearest" ties.
        q_round = tl.extra.xpu.libdevice.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)

    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)
def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    """Compute quantiles of `inp` along `dim`, mirroring torch.quantile.

    Args:
        inp: floating-point input tensor.
        q: quantile fraction(s) in [0, 1]; a Python float or a tensor.
        dim: int dimension to reduce, or None to reduce the flattened input.
        keepdim: keep the reduced dimension (size 1) in the output.
        interpolation: one of INTERPOLATION_METHOD.
        out: optional tensor the result is copied into.

    Returns:
        Tensor of quantiles with the quantile axis leading; that axis is
        squeezed away when a single quantile is requested.
    """
    logger.debug("GEMS QUANTILE DIM")
    # NOTE(review): asserts are used for input validation; they vanish under -O.
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD

    M = inp.numel()
    if isinstance(q, float):
        q = torch.tensor(q, device=inp.device)
        Q = 1
    else:
        # len(q) is the first-dim size; assumes q is 0-D or 1-D — TODO confirm
        # callers never pass a multi-dimensional q.
        Q = 1 if q.numel() == 1 else len(q)

    assert M > 0
    assert Q > 0
    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    # With no dim, reduce over the flattened tensor (dim 0 of the 1-D view).
    if dim is None:
        inp = inp.ravel()
        dim = 0

    shape = list(inp.shape)

    dim %= inp.ndim  # normalize negative dim
    # dim_compress presumably permutes `dim` to the last axis and makes the
    # result contiguous, so each row of length M is one reduction slice —
    # TODO confirm against flag_gems.utils.dim_compress.
    inp = dim_compress(inp, dim)
    M = shape[dim]  # length of the reduction axis
    N = inp.numel() // M  # number of independent rows

    inp, _ = inp.sort()  # Sort the input with torch.sort()
    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)

    # 2-D launch grid: axis 0 tiles the Q quantiles, axis 1 tiles the N rows;
    # block sizes come from the @triton.heuristics on the kernel.
    grid = lambda meta: (
        triton.cdiv(Q, meta["BLOCK_Q"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    with torch_device_fn.device(inp.device):
        quantile_kernel[grid](inp, q, output, N, M, Q, interpolation=interpolation)

    # Move the trailing quantile axis to the front, matching torch.quantile.
    output = output.permute(
        (-1,) + tuple(range(0, inp.ndim - 1))
    )  # Same as torch.quantile()
    if keepdim:
        # +1 skips the leading quantile axis added by the permute above.
        output = output.unsqueeze(dim + 1)
    if Q == 1:
        # Scalar q: drop the leading quantile axis.
        output = output.squeeze(0)

    if out is not None:
        # NOTE(review): `output` (not `out`) is returned; contents match after copy_.
        out.copy_(output)
    return output