Coverage for src/flag_gems/ops/quantile.py: 44%
153 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, tl_extra_shim
10from flag_gems.utils import triton_lang_extension as tle
12from .topk import _get_finfo_val, argsort
logger = logging.getLogger(__name__)

# Interpolation modes accepted by `quantile` (same names as torch.quantile).
INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"]
# Rows with at most this many elements are sorted in-kernel by the bitonic
# path (`quantile_bitonic_kernel`); longer rows are pre-sorted with
# torch.sort and handled by `quantile_kernel`.
MAX_BITONIC_M = 1024
def heur_block_q(args):
    """Choose BLOCK_Q: roughly an eighth of Q, capped at 16, rounded up
    to the next power of two (Triton block sizes must be powers of two)."""
    per_program = triton.cdiv(args["Q"], 8)
    capped = min(per_program, 16)
    return triton.next_power_of_2(capped)
def heur_block_n(args):
    """Choose BLOCK_N for `quantile_kernel` from the batch size N."""
    n = args["N"]
    # Large batches: size the block so a roughly fixed number of programs
    # covers N, rounded up to a power of two.
    if n >= 65536:
        return triton.next_power_of_2(triton.cdiv(n, 512))
    if n >= 4096:
        return triton.next_power_of_2(triton.cdiv(n, 128))
    # Small batches: fixed block sizes.
    if n >= 64:
        return 32
    if n >= 32:
        return 4
    return 1
@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,
):
    # Gather/interpolate quantiles from pre-sorted rows.
    #   inp: (N, M) rows, each sorted ascending (the host sorts before launch)
    #   q:   (Q,) quantile fractions in [0, 1]
    #   out: (N, Q), out[n, j] = q[j]-quantile of row n
    # Grid is (cdiv(Q, BLOCK_Q), cdiv(N, BLOCK_N)); each program produces a
    # BLOCK_N x BLOCK_Q tile of `out`.
    pid_Q = tle.program_id(0)
    pid_N = tle.program_id(1)
    ctype = inp.dtype.element_ty

    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q

    offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_N = offsets_N < N

    out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]

    # Map each q in [0, 1] to a fractional index in [0, M-1]; the two
    # neighbouring sorted elements bracket the requested quantile.
    q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1)
    q_lower = tl.floor(q_block).to(tl.int32)
    q_upper = tl.ceil(q_block).to(tl.int32)

    # Load the bracketing values for every (row, quantile) pair in the tile.
    inp_lower = tl.load(
        inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
    )
    inp_upper = tl.load(
        inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
    )

    # `interpolation` is a constexpr, so only the selected branch is compiled.
    if interpolation == "linear":
        # Weight the bracketing values by the fractional part of the index.
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)

    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)

    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)

    elif interpolation == "nearest":
        # Round to the nearest index; tie/round result picks which neighbour.
        q_round = tl_extra_shim.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)

    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)
@libentry()
@triton.jit
def quantile_bitonic_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_M: tl.constexpr,
    interpolation: tl.constexpr,
):
    # Single-pass quantile for short rows (M <= MAX_BITONIC_M): each program
    # loads one UNSORTED row, sorts it with a bitonic argsort, and emits all
    # Q quantiles for that row.
    #   inp: (N, M) unsorted rows; q: (Q,) fractions; out: (N, Q)
    # Grid is (N,) — one program per row.
    pid = tle.program_id(0)
    ctype = inp.dtype.element_ty

    cols = tl.arange(0, BLOCK_M)
    mask_M = cols < M
    row_ptr = inp + pid * M
    # Pad out-of-range lanes with the dtype's max (`return_max=True`) so they
    # sort to the tail and never land on a valid quantile index.
    mask_val = _get_finfo_val(ctype, return_max=True)
    vals = tl.load(row_ptr + cols, mask=mask_M, other=mask_val)
    # Sort in fp32 unless the input is already fp64.
    vals = tl.where(vals.dtype.is_fp64(), vals, vals.to(tl.float32))
    ids = tl.arange(0, BLOCK_M)
    sorted_vals, _ = argsort(vals, ids, 0, descending=False)

    offsets_Q = tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_vals = tl.load(q + offsets_Q, mask=mask_Q, other=0.0).to(tl.float32)
    # Map each q in [0, 1] to a fractional index in [0, M-1]; the two
    # neighbouring sorted elements bracket the requested quantile.
    q_scaled = q_vals * (M - 1)
    q_lower = tl.floor(q_scaled).to(tl.int32)
    q_upper = tl.ceil(q_scaled).to(tl.int32)

    # Gather sorted_vals[q_lower] / sorted_vals[q_upper] via one-hot masked
    # sums over the row axis (no per-lane gather on register tiles).
    idx = tl.arange(0, BLOCK_M)[:, None]
    mask_lower = idx == q_lower[None, :]
    mask_upper = idx == q_upper[None, :]
    mask_lower_f = mask_lower.to(tl.float32)
    mask_upper_f = mask_upper.to(tl.float32)
    lower_vals = tl.sum(sorted_vals[:, None] * mask_lower_f, axis=0)
    upper_vals = tl.sum(sorted_vals[:, None] * mask_upper_f, axis=0)

    # `interpolation` is a constexpr, so only the chosen branch is compiled;
    # the host asserts it is one of these five values.
    if interpolation == "linear":
        q_frac = q_scaled - q_lower
        out_vals = lower_vals + (upper_vals - lower_vals) * q_frac
    elif interpolation == "lower":
        out_vals = lower_vals
    elif interpolation == "higher":
        out_vals = upper_vals
    elif interpolation == "nearest":
        q_round = tl_extra_shim.rint(q_scaled).to(tl.int32)
        out_vals = tl.where(q_round == q_upper, upper_vals, lower_vals)
    elif interpolation == "midpoint":
        out_vals = (lower_vals + upper_vals) * 0.5

    out_ptr = out + pid * Q + offsets_Q
    tl.store(out_ptr, out_vals.to(ctype), mask=mask_Q)
def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    """Compute the q-th quantile(s) of `inp` along `dim`.

    Args:
        inp: floating-point input tensor.
        q: quantile(s) in [0, 1]; a python float or a tensor.
        dim: dimension to reduce; None flattens the input first.
        keepdim: keep the reduced dimension with size 1.
        interpolation: one of "linear", "lower", "higher", "nearest",
            "midpoint" (see INTERPOLATION_METHOD).
        out: optional output tensor; filled and returned when supplied.

    Returns:
        Tensor of quantiles. When q has more than one element, the quantile
        axis is the leading dimension of the result.
    """
    logger.debug("GEMS QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD

    # Handle dim: None means quantiles over the flattened input.
    if dim is None:
        inp = inp.ravel()
        dim = 0
    if dim < 0:
        dim = dim + inp.ndim

    # Handle q: normalize to a tensor on the input's device/dtype, and record
    # whether the scalar fast path (pure min/max, no sort) applies.
    q_all_ones = False
    q_all_zeros = False
    if isinstance(q, float):
        q_all_ones = q == 1.0
        q_all_zeros = q == 0.0
        q = torch.tensor(q, device=inp.device, dtype=inp.dtype)
        Q = 1
    else:
        q = q.to(device=inp.device, dtype=inp.dtype)
        Q = 1 if q.numel() == 1 else len(q)

    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    # Fast path: q == 0.0 -> min, q == 1.0 -> max (no sort needed)
    if q_all_ones or q_all_zeros:
        reduce_fn = torch.amax if q_all_ones else torch.amin
        if out is not None and Q == 1:
            reduce_fn(inp, dim=dim, keepdim=keepdim, out=out)
            return out
        output = reduce_fn(inp, dim=dim, keepdim=keepdim)
        if Q > 1:
            output = output.unsqueeze(0).expand(Q, *output.shape)
        if out is not None:
            out.copy_(output)
            return out
        return output

    # Move the reduced dim last so each row is contiguous for the kernels.
    if dim != inp.ndim - 1:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()

    M = inp.size(-1)
    N = inp.numel() // M

    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)
    if M <= MAX_BITONIC_M:
        # Short rows: sort in-kernel with the bitonic sorter, one program/row.
        BLOCK_M = triton.next_power_of_2(M)
        BLOCK_Q = triton.next_power_of_2(min(Q, 16))
        grid = (N,)
        with torch_device_fn.device(inp.device):
            quantile_bitonic_kernel[grid](
                inp,
                q,
                output,
                N,
                M,
                Q,
                BLOCK_Q=BLOCK_Q,
                BLOCK_M=BLOCK_M,
                interpolation=interpolation,
            )
    else:
        # Long rows: pre-sort with torch, then gather/interpolate in-kernel.
        sorted_vals, _ = inp.sort(dim=-1)
        grid = lambda meta: (
            triton.cdiv(Q, meta["BLOCK_Q"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )
        with torch_device_fn.device(inp.device):
            quantile_kernel[grid](
                sorted_vals, q, output, N, M, Q, interpolation=interpolation
            )

    if Q == 1:
        output = output.squeeze(-1)
    else:
        # Quantile axis leads the result, matching torch.quantile.
        output = output.movedim(-1, 0)
    if keepdim:
        output = output.unsqueeze(dim + (1 if Q != 1 else 0))

    if out is not None:
        out.copy_(output)
        # Fix: previously this path returned `output` even when `out` was
        # supplied; return `out` for consistency with the fast path above
        # and with torch.quantile's out-parameter convention.
        return out
    return output