Coverage for src/flag_gems/ops/quantile.py: 44%

153 statements  

coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

import logging

import torch
import triton
import triton.language as tl
from torch import Tensor

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, tl_extra_shim
from flag_gems.utils import triton_lang_extension as tle

from .topk import _get_finfo_val, argsort

logger = logging.getLogger(__name__)

INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"]
MAX_BITONIC_M = 1024

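# Block-size heuristics for the tiled kernel below: BLOCK_Q caps how many
# quantile values one program handles, while BLOCK_N grows with the batch
# size N so large batches are covered by fewer, wider programs.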

def heur_block_q(args):
    return triton.next_power_of_2(min(triton.cdiv(args["Q"], 8), 16))


def heur_block_n(args):
    if args["N"] >= 65536:
        return triton.next_power_of_2(triton.cdiv(args["N"], 512))
    elif args["N"] >= 4096:
        return triton.next_power_of_2(triton.cdiv(args["N"], 128))
    elif args["N"] >= 64:
        return 32
    elif args["N"] >= 32:
        return 4
    else:
        return 1


@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,
):
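    # Large-M path: `inp` holds N rows that are already sorted along the last
    # dimension. Each program covers a (BLOCK_N, BLOCK_Q) tile; every quantile
    # value is mapped to the fractional index q * (M - 1), and floor/ceil of
    # that index select the two bracketing order statistics.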

    pid_Q = tle.program_id(0)
    pid_N = tle.program_id(1)
    ctype = inp.dtype.element_ty

    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q

    offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_N = offsets_N < N

    out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]

    q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1)
    q_lower = tl.floor(q_block).to(tl.int32)
    q_upper = tl.ceil(q_block).to(tl.int32)

    inp_lower = tl.load(
        inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
    )
    inp_upper = tl.load(
        inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
    )

    if interpolation == "linear":
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)
    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)
    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)
    elif interpolation == "nearest":
        q_round = tl_extra_shim.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)
    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)


@libentry()
@triton.jit
def quantile_bitonic_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_M: tl.constexpr,
    interpolation: tl.constexpr,
):
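    # Small-M path (M <= MAX_BITONIC_M): one program per row. The row is
    # sorted in-kernel with the bitonic `argsort` imported from topk, padding
    # out-of-range lanes with the dtype maximum so they sort to the end.
    # Sorting runs in float32, except float64 inputs, which keep full
    # precision.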

    pid = tle.program_id(0)
    ctype = inp.dtype.element_ty

    cols = tl.arange(0, BLOCK_M)
    mask_M = cols < M
    row_ptr = inp + pid * M
    mask_val = _get_finfo_val(ctype, return_max=True)
    vals = tl.load(row_ptr + cols, mask=mask_M, other=mask_val)
    vals = tl.where(vals.dtype.is_fp64(), vals, vals.to(tl.float32))
    ids = tl.arange(0, BLOCK_M)
    sorted_vals, _ = argsort(vals, ids, 0, descending=False)

    offsets_Q = tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_vals = tl.load(q + offsets_Q, mask=mask_Q, other=0.0).to(tl.float32)
    q_scaled = q_vals * (M - 1)
    q_lower = tl.floor(q_scaled).to(tl.int32)
    q_upper = tl.ceil(q_scaled).to(tl.int32)

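    # `sorted_vals` lives in registers, so it cannot be indexed by a runtime
    # index directly; instead, build one-hot masks over the M axis and reduce
    # with tl.sum to gather the lower/upper bracketing values per quantile.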

    idx = tl.arange(0, BLOCK_M)[:, None]
    mask_lower = idx == q_lower[None, :]
    mask_upper = idx == q_upper[None, :]
    mask_lower_f = mask_lower.to(tl.float32)
    mask_upper_f = mask_upper.to(tl.float32)
    lower_vals = tl.sum(sorted_vals[:, None] * mask_lower_f, axis=0)
    upper_vals = tl.sum(sorted_vals[:, None] * mask_upper_f, axis=0)

    if interpolation == "linear":
        q_frac = q_scaled - q_lower
        out_vals = lower_vals + (upper_vals - lower_vals) * q_frac
    elif interpolation == "lower":
        out_vals = lower_vals
    elif interpolation == "higher":
        out_vals = upper_vals
    elif interpolation == "nearest":
        q_round = tl_extra_shim.rint(q_scaled).to(tl.int32)
        out_vals = tl.where(q_round == q_upper, upper_vals, lower_vals)
    elif interpolation == "midpoint":
        out_vals = (lower_vals + upper_vals) * 0.5

    out_ptr = out + pid * Q + offsets_Q
    tl.store(out_ptr, out_vals.to(ctype), mask=mask_Q)



def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
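    """Compute the q-th quantiles of `inp` along `dim`, following the
    semantics of torch.quantile (five interpolation modes, `keepdim`,
    and an optional `out` tensor)."""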

    logger.debug("GEMS QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD

    # Handle dim
    if dim is None:
        inp = inp.ravel()
        dim = 0
    if dim < 0:
        dim = dim + inp.ndim

    # Handle q
    q_all_ones = False
    q_all_zeros = False
    if isinstance(q, float):
        q_all_ones = q == 1.0
        q_all_zeros = q == 0.0
        q = torch.tensor(q, device=inp.device, dtype=inp.dtype)
        Q = 1
    else:
        q = q.to(device=inp.device, dtype=inp.dtype)
        Q = 1 if q.numel() == 1 else len(q)

    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    # Fast path: q == 0.0 -> min, q == 1.0 -> max (no sort needed)
    if q_all_ones or q_all_zeros:
        reduce_fn = torch.amax if q_all_ones else torch.amin
        if out is not None and Q == 1:
            reduce_fn(inp, dim=dim, keepdim=keepdim, out=out)
            return out
        output = reduce_fn(inp, dim=dim, keepdim=keepdim)
        if Q > 1:
            output = output.unsqueeze(0).expand(Q, *output.shape)
        if out is not None:
            out.copy_(output)
            return out
        return output

    # Handle input tensor: move the reduction dim to the innermost position
    if dim != inp.ndim - 1:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()

    M = inp.size(-1)
    N = inp.numel() // M

    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)
    if M <= MAX_BITONIC_M:
        BLOCK_M = triton.next_power_of_2(M)
        # One program per row (grid over N only), so BLOCK_Q must span the
        # whole q vector; capping it would silently skip quantiles past the cap.
        BLOCK_Q = triton.next_power_of_2(Q)
        grid = (N,)
        with torch_device_fn.device(inp.device):
            quantile_bitonic_kernel[grid](
                inp,
                q,
                output,
                N,
                M,
                Q,
                BLOCK_Q=BLOCK_Q,
                BLOCK_M=BLOCK_M,
                interpolation=interpolation,
            )
    else:
        sorted_vals, _ = inp.sort(dim=-1)
        grid = lambda meta: (
            triton.cdiv(Q, meta["BLOCK_Q"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )
        with torch_device_fn.device(inp.device):
            quantile_kernel[grid](
                sorted_vals, q, output, N, M, Q, interpolation=interpolation
            )

    if Q == 1:
        output = output.squeeze(-1)
    else:
        output = output.movedim(-1, 0)
    if keepdim:
        output = output.unsqueeze(dim + (1 if Q != 1 else 0))

    if out is not None:
        out.copy_(output)
        return out
    return output
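

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module. It assumes a
    # CUDA device; on other flag_gems backends, substitute the matching device
    # string. Both kernel paths are exercised: M <= MAX_BITONIC_M takes the
    # in-kernel bitonic sort, larger M sorts with torch and uses the tiled
    # kernel.
    qs = torch.tensor([0.25, 0.5, 0.75], device="cuda")
    for m in (500, 2048):  # small-M and large-M paths
        x = torch.randn(8, m, device="cuda")
        res = quantile(x, qs, dim=-1, interpolation="linear")
        ref = torch.quantile(x, qs, dim=-1, interpolation="linear")
        assert torch.allclose(res, ref, rtol=1e-4, atol=1e-5)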