Coverage for src/flag_gems/ops/per_token_group_quant_fp8.py: 47%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils.device_info import get_device_capability 

10 

# Choose the FP8 storage dtype once at import time. float8_e4m3fn is used
# only when a device is available and its compute capability is >= (9, 0)
# (presumably NVIDIA Hopper-class — confirm against the device_info helper);
# otherwise fall back to float32 so the quantization path still runs,
# just without real FP8 storage.
if torch_device_fn.is_available() and get_device_capability() >= (9, 0):
    SUPPORTED_FP8_DTYPE = torch.float8_e4m3fn
else:
    SUPPORTED_FP8_DTYPE = torch.float32


# Module-level logger following the package-wide logging convention.
logger = logging.getLogger(__name__)

18 

19 

@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one `group_size`-wide slice of the input to FP8.

    Each program instance owns exactly one group: it computes the group's
    absolute maximum, derives a single scale from it, quantizes the group,
    and writes the quantized values plus the scale (row-major scale layout,
    one scale per group id).
    """
    # Translate the flat program id into (row, group-within-row).
    groups_in_row = y_num_columns // group_size
    pid = tl.program_id(0)
    row_idx = pid // groups_in_row
    group_idx = pid % groups_in_row

    # Advance the raw pointers to this program's group.
    src = y_ptr + row_idx * y_row_stride + group_idx * group_size
    dst = y_q_ptr + pid * group_size
    scale_dst = y_s_ptr + pid

    offsets = tl.arange(0, BLOCK)
    in_bounds = offsets < group_size

    vals = tl.load(src + offsets, mask=in_bounds, other=0.0).to(tl.float32)
    # `eps` floors the absmax so an all-zero group never yields a zero scale.
    abs_max = tl.maximum(tl.max(tl.abs(vals)), eps)
    scale = abs_max / fp8_max

    if scale_ue8m0:
        # Round the scale up to the nearest power of two (UE8M0 encoding);
        # the 1e-10 floor guards the log2 against zero.
        scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))

    quantized = tl.clamp(vals / scale, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(dst + offsets, quantized, mask=in_bounds)
    tl.store(scale_dst, scale)

58 

59 

@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one group to FP8, emitting scales in column-major order.

    Same math as `_per_token_group_quant_fp8`; the only difference is that
    this kernel stores each group's scale at `group * y_s_col_stride + row`,
    so all scales belonging to one group column sit contiguously in memory.
    """
    n_groups = y_num_columns // group_size
    pid = tl.program_id(0)
    row_idx, col_idx = pid // n_groups, pid % n_groups

    src = y_ptr + row_idx * y_row_stride + col_idx * group_size
    dst = y_q_ptr + pid * group_size
    # Column-major scale location: column base plus row offset.
    scale_dst = y_s_ptr + col_idx * y_s_col_stride + row_idx

    offs = tl.arange(0, BLOCK)
    valid = offs < group_size

    vals = tl.load(src + offs, mask=valid, other=0.0).to(tl.float32)
    # `eps` keeps the scale strictly positive for all-zero groups.
    group_absmax = tl.maximum(tl.max(tl.abs(vals)), eps)
    scale = group_absmax / fp8_max

    if scale_ue8m0:
        # Snap the scale up to the next power of two (UE8M0 encoding);
        # the 1e-10 floor guards the log2 against zero.
        scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))

    quantized = tl.clamp(vals / scale, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(dst + offs, quantized, mask=valid)
    tl.store(scale_dst, scale)

99 

100 

def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to FP8 with one scale per `group_size` elements.

    Args:
        x: Input tensor of at least 2 dimensions. Its last dimension must be
            contiguous (`x.stride(-1) == 1`) and divisible by `group_size`.
        group_size: Number of consecutive elements that share one scale.
        eps: Lower bound applied to each group's absmax so the derived scale
            is never zero.
        dtype: Output dtype. Defaults to the device-dependent
            `SUPPORTED_FP8_DTYPE`; note that only `torch.float8_e4m3fn` is a
            true FP8 dtype — older devices fall back to `torch.float32`.
        column_major_scales: If True, return the scales tensor in a
            column-major layout (all scales of one group column contiguous).
        scale_ue8m0: If True, round every scale up to a power of two.

    Returns:
        Tuple `(x_q, x_s)`: the quantized tensor (same shape as `x`) and the
        per-group float32 scales.
    """
    logger.debug("GEMS PER TOKEN GROUP QUANT FP8")
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
    M = x.numel() // group_size  # total number of groups across ALL leading dims
    N = group_size

    if column_major_scales:
        # (groups_per_row, rows) allocated contiguously, then viewed
        # transposed so x_s.stride(1) is the column stride the kernel needs.
        shape = (x.shape[-1] // group_size,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (x.shape[-1] // group_size,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    # Use the trailing dims for the kernel's row geometry: M spans groups in
    # every leading dimension, so x.shape[1]/x.stride(0) would mis-index any
    # input with more than 2 dims. For 2-D input these are identical values.
    y_num_columns = x.shape[-1]
    y_row_stride = x.stride(-2)
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            y_num_columns,
            y_row_stride,
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            y_num_columns,
            y_row_stride,
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    return x_q, x_s