Coverage for src/flag_gems/ops/per_token_group_quant_fp8.py: 45%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1from typing import Optional, Tuple 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils.device_info import get_device_capability 

9 

# Pick the FP8 storage dtype once at import time: use the hardware
# float8_e4m3fn type only when a device is present and reports compute
# capability >= (9, 0); otherwise fall back to float32 so the module
# still imports and runs everywhere.
SUPPORTED_FP8_DTYPE = (
    torch.float8_e4m3fn
    if torch_device_fn.is_available() and get_device_capability() >= (9, 0)
    else torch.float32
)

14 

15 

@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one `group_size`-wide group of `y` to FP8 per program.

    Each program handles a single group: it takes the group's absolute
    maximum (floored at `eps`), derives the scale `absmax / fp8_max`,
    optionally rounds the scale up to a power of two when `scale_ue8m0`
    is set, then stores the clamped quantized values and the scale.
    Scales are written row-major, one per group (`y_s_ptr[g_id]`).
    """
    # Number of whole groups per row of y.
    groups_in_row = y_num_columns // group_size

    pid = tl.program_id(0)
    row_idx = pid // groups_in_row
    group_in_row = pid % groups_in_row

    # Advance all pointers to this program's group.
    y_ptr += row_idx * y_row_stride + group_in_row * group_size
    y_q_ptr += pid * group_size
    y_s_ptr += pid

    offs = tl.arange(0, BLOCK)
    in_group = offs < group_size

    # Compute in float32 regardless of the input dtype.
    vals = tl.load(y_ptr + offs, mask=in_group, other=0.0).to(tl.float32)
    abs_max = tl.maximum(tl.max(tl.abs(vals)), eps)
    scale = abs_max / fp8_max

    if scale_ue8m0:
        # Round the scale up to the nearest power of two.
        scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))

    quantized = tl.clamp(vals / scale, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + offs, quantized, mask=in_group)
    tl.store(y_s_ptr, scale)

54 

55 

@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Per-group FP8 quantization with column-major scale layout.

    Identical math to `_per_token_group_quant_fp8` — absmax (floored at
    `eps`) over each `group_size`-wide group, scale `absmax / fp8_max`,
    optional power-of-two rounding under `scale_ue8m0`, clamp and cast —
    but each scale is stored at `group * y_s_col_stride + row` so the
    scale tensor is laid out column-major.
    """
    # Whole groups per row of y.
    n_groups = y_num_columns // group_size

    prog = tl.program_id(0)
    row = prog // n_groups
    grp = prog % n_groups

    # Position the pointers on this program's group; the scale pointer
    # uses the column-major offset.
    y_ptr += row * y_row_stride + grp * group_size
    y_q_ptr += prog * group_size
    y_s_ptr += grp * y_s_col_stride + row

    idx = tl.arange(0, BLOCK)
    valid = idx < group_size

    # All arithmetic in float32 for accuracy.
    data = tl.load(y_ptr + idx, mask=valid, other=0.0).to(tl.float32)
    peak = tl.maximum(tl.max(tl.abs(data)), eps)
    s = peak / fp8_max

    if scale_ue8m0:
        # Round the scale up to the nearest power of two.
        s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(s), 1e-10))))

    q = tl.clamp(data / s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + idx, q, mask=valid)
    tl.store(y_s_ptr, s)

95 

96 

def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to FP8 with one scale per `group_size` elements of the last dim.

    Args:
        x: input tensor; its last dimension must be contiguous and divisible
            by `group_size`. Inputs with more than 2 dims must be contiguous.
        group_size: number of consecutive elements sharing one scale.
        eps: floor applied to each group's absmax to avoid a zero scale.
        dtype: output dtype; defaults to `SUPPORTED_FP8_DTYPE` (the comment in
            the original noted only `torch.float8_e4m3fn` is the intended FP8
            dtype).
        column_major_scales: store scales column-major (2-D inputs only —
            `permute(-1, -2)` rejects higher ranks).
        scale_ue8m0: round each scale up to a power of two in the kernel.

    Returns:
        (x_q, x_s): quantized tensor with `x`'s shape, and the float32 scales
        (one per group).
    """
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"
    # The kernels flat-index rows as g_id // groups_per_row, which is only
    # valid for >2-D inputs when the whole tensor is contiguous.
    assert x.dim() <= 2 or x.is_contiguous(), (
        "`x` must be contiguous when it has more than 2 dimensions"
    )

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
    M = x.numel() // group_size  # one Triton program per group
    N = group_size

    if column_major_scales:
        # Allocate (groups, rows) contiguously, then view it transposed so
        # callers see (rows, groups) while storage stays column-major.
        shape = (x.shape[-1] // group_size,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (x.shape[-1] // group_size,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    # BUG FIX: pass the last-dim size and second-to-last stride instead of
    # x.shape[1] / x.stride(0). They are identical for 2-D inputs, but the
    # old values gave the kernel the wrong row geometry for >2-D inputs
    # (M counts ALL groups via numel(), so rows must be indexed flat).
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            x.shape[-1],
            x.stride(-2),
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            x.shape[-1],
            x.stride(-2),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    return x_q, x_s