Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/quantile.py: 0%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger nested under the shared "flag_gems" logger; lstrip(".")
# guards against a leading dot if __name__ were ever a relative name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Interpolation modes accepted by quantile(); mirrors torch.quantile's choices.
INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"]

14 

15 

16# def heur_block_q(args): 

17# return triton.next_power_of_2(min(triton.cdiv(args["Q"], 8), 16)) 

18 

19 

20# def heur_block_n(args): 

21# if args["N"] >= 65536: 

22# return triton.next_power_of_2(triton.cdiv(args["N"], 512)) 

23# elif args["N"] >= 4096: 

24# return triton.next_power_of_2(triton.cdiv(args["N"], 128)) 

25# elif args["N"] >= 64: 

26# return 32 

27# elif args["N"] >= 32: 

28# return 4 

29# else: 

30# return 1 

31 

32 

def heur_block_q(args):
    """Heuristic for BLOCK_Q: next power of two covering Q, capped at 1024."""
    # builtins.min is used defensively — presumably to sidestep any shadowing
    # of `min` in this module's namespace; verify before simplifying.
    import builtins

    pow2 = triton.next_power_of_2(args["Q"])
    return builtins.min(1024, pow2)

37 

38 

def heur_block_n(args):
    """Heuristic for BLOCK_N: next power of two covering N, capped at 1024."""
    # Same defensive builtins.min as heur_block_q — presumably guards against
    # `min` shadowing; verify before simplifying.
    import builtins

    pow2 = triton.next_power_of_2(args["N"])
    return builtins.min(1024, pow2)

43 

44 

@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,  # (N, M) input; each row assumed sorted ascending — caller sorts
    q,  # (Q,) quantile fractions in [0, 1]
    out,  # (N, Q) output buffer, row-major
    N,  # number of independent rows
    M,  # length of the sorted (reduced) dimension
    Q,  # number of quantiles requested
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,  # compile-time string, one of INTERPOLATION_METHOD
):
    """Gather quantiles from row-sorted input.

    Each program instance handles a (BLOCK_N, BLOCK_Q) tile: BLOCK_N rows
    of `inp` crossed with BLOCK_Q quantile fractions. The grid is
    (cdiv(Q, BLOCK_Q), cdiv(N, BLOCK_N)).
    """
    pid_Q = tle.program_id(0)
    pid_N = tle.program_id(1)
    ctype = inp.dtype.element_ty

    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q

    offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_N = offsets_N < N

    out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]

    # Map each fraction in [0, 1] to a fractional index in [0, M-1].
    # Masked-out Q lanes load 0.0 and so index element 0 — in bounds, and the
    # store mask discards their results anyway.
    q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1)
    q_lower = tl.floor(q_block).to(tl.int32)
    q_upper = tl.ceil(q_block).to(tl.int32)

    # The two neighbouring order statistics bracketing each requested quantile.
    inp_lower = tl.load(
        inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
    )
    inp_upper = tl.load(
        inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
    )

    if interpolation == "linear":
        # Linear blend weighted by the fractional part of the index.
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)

    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)

    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)

    elif interpolation == "nearest":
        # Kunlunxin (xpu) libdevice rint — presumably rounds half to even like
        # C rint; TODO confirm against the xpu libdevice reference.
        q_round = tl.extra.xpu.libdevice.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)

    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)

101 

102 

def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    """Compute quantiles of `inp` along `dim`, mirroring torch.quantile.

    `q` may be a float or a tensor of fractions in [0, 1]; `interpolation`
    must be one of INTERPOLATION_METHOD. When `out` is given, the result is
    additionally copied into it. Returns a tensor with the quantile axis
    first (squeezed away when a single quantile is requested).
    """
    logger.debug("GEMS QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD

    total = inp.numel()
    if isinstance(q, float):
        q = torch.tensor(q, device=inp.device)
        num_q = 1
    else:
        # A 0-d / single-element tensor counts as one quantile.
        num_q = 1 if q.numel() == 1 else len(q)

    assert total > 0
    assert num_q > 0
    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    # dim=None means "reduce over the flattened tensor".
    if dim is None:
        inp = inp.ravel()
        dim = 0

    shape = list(inp.shape)
    dim %= inp.ndim
    # Make the reduced dimension the innermost, contiguous one.
    inp = dim_compress(inp, dim)
    axis_len = shape[dim]
    batch = inp.numel() // axis_len

    # The kernel reads order statistics, so each row must be sorted.
    inp, _ = inp.sort()
    output = torch.empty(
        inp.shape[:-1] + (num_q,), dtype=inp.dtype, device=inp.device
    )

    grid = lambda meta: (
        triton.cdiv(num_q, meta["BLOCK_Q"]),
        triton.cdiv(batch, meta["BLOCK_N"]),
    )

    with torch_device_fn.device(inp.device):
        quantile_kernel[grid](
            inp, q, output, batch, axis_len, num_q, interpolation=interpolation
        )

    # Move the quantile axis to the front, as torch.quantile does.
    output = output.permute((-1,) + tuple(range(0, inp.ndim - 1)))
    if keepdim:
        output = output.unsqueeze(dim + 1)
    if num_q == 1:
        output = output.squeeze(0)

    if out is not None:
        out.copy_(output)
    return output