Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/per_token_group_quant_fp8.py: 0%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1from typing import Optional, Tuple 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): 

8 SUPPORTED_FP8_DTYPE = torch.float8_e4m3fn 

9else: 

10 SUPPORTED_FP8_DTYPE = torch.float32 

11 

12 

@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one `group_size`-wide group of `y` per program (row-major scales).

    Each program id owns exactly one group: it loads the group, derives a
    single float32 scale from the group's absolute maximum, then stores the
    clamped quantized values to `y_q_ptr` and the scale to `y_s_ptr`.
    Scales are laid out densely, one per group id.
    """
    # How many groups fit in one logical row of y.
    groups_per_row = y_num_columns // group_size

    g_id = tl.program_id(0)
    row = g_id // groups_per_row
    row_g_id = g_id % groups_per_row

    # Advance pointers to this program's group. The input row may be strided
    # (y_row_stride can differ from y_num_columns); y_q and y_s are dense.
    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
    y_q_ptr += g_id * group_size
    y_s_ptr += g_id

    # BLOCK is the next power of two >= group_size; mask off the tail.
    cols = tl.arange(0, BLOCK)
    mask = cols < group_size

    # Compute in float32 regardless of the input dtype for a stable scale.
    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # The eps floor keeps the scale strictly positive for an all-zero group.
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max

    if scale_ue8m0:
        # Round the scale up to a power of two (UE8M0-style scale);
        # the 1e-10 floor guards log2 against a zero argument.
        y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10))))

    # Quantize: divide by the scale, clamp into the representable fp8 range,
    # and cast to the output buffer's element type.
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)

51 

52 

@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one `group_size`-wide group of `y` per program (column-major scales).

    Identical math to `_per_token_group_quant_fp8`; the only difference is
    the scale layout: the scale for (row, group) is stored at
    `group_id * y_s_col_stride + row`, i.e. column-major over groups.
    """
    # How many groups fit in one logical row of y.
    groups_per_row = y_num_columns // group_size

    g_id = tl.program_id(0)
    row = g_id // groups_per_row
    group_id = g_id % groups_per_row

    # Advance pointers to this program's group; y_q is dense, y_s column-major.
    y_ptr += row * y_row_stride + group_id * group_size
    y_q_ptr += g_id * group_size
    y_s_ptr += group_id * y_s_col_stride + row

    # BLOCK is the next power of two >= group_size; mask off the tail.
    cols = tl.arange(0, BLOCK)
    mask = cols < group_size

    # Compute in float32 regardless of the input dtype for a stable scale.
    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # The eps floor keeps the scale strictly positive for an all-zero group.
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max

    if scale_ue8m0:
        # Round the scale up to a power of two (UE8M0-style scale);
        # the 1e-10 floor guards log2 against a zero argument.
        y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10))))

    # Quantize: divide by the scale, clamp into the representable fp8 range,
    # and cast to the output buffer's element type.
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)

92 

93 

def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` per group of `group_size` elements along the last dim.

    Each contiguous group gets its own float32 scale
    `s = max(absmax(group), eps) / fp8_max`; the quantized values are
    `clamp(x / s, fp8_min, fp8_max)` cast to the target fp8 dtype.

    Args:
        x: Input tensor with at least 2 dims; the last dim must be divisible
            by `group_size` and contiguous (stride 1).
        group_size: Number of elements sharing one scale.
        eps: Floor applied to each group's absmax so all-zero groups still
            yield a positive scale.
        dtype: Output dtype. Defaults to `SUPPORTED_FP8_DTYPE`
            (float8_e4m3fn on capable hardware, float32 otherwise).
        column_major_scales: If True, store scales column-major (the returned
            tensor is a permuted view); only meaningful for 2-D inputs, since
            the layout permute is over the last two dims.
        scale_ue8m0: If True, round each scale up to a power of two.

    Returns:
        Tuple `(x_q, x_s)`: the quantized tensor (same shape as `x`) and the
        per-group scales (one float32 scale per group).
    """
    # dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
    M = x.numel() // group_size  # total number of groups == launch grid size
    N = group_size

    if column_major_scales:
        # Allocate (groups_per_row, *leading_dims) and permute the last two
        # dims back, so the caller indexes scales row-major over
        # column-major storage.
        shape = (x.shape[-1] // group_size,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (x.shape[-1] // group_size,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    # Heuristic: one warp per 256 lanes, clamped to [1, 8].
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            # BUGFIX: was x.shape[1] / x.stride(0). Every other computation
            # here (asserts, M, scale shapes) uses last-dim semantics, so
            # the row width/stride must too; identical for 2-D inputs but
            # correct for higher-rank ones.
            x.shape[-1],
            x.stride(-2),
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            x.shape[-1],  # BUGFIX: was x.shape[1]; wrong row width for >2-D x
            x.stride(-2),  # BUGFIX: was x.stride(0); wrong row stride for >2-D x
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    return x_q, x_s