Coverage for src/flag_gems/runtime/backend/_cambricon/ops/per_token_group_quant_fp8.py: 0%

80 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils.device_info import get_device_capability 

10 

11from ..utils import MAX_GRID_SIZE_X 

12 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Default output dtype used by per_token_group_quant_fp8 when the caller passes
# no `dtype`: native float8_e4m3fn only if a device is available and it reports
# a capability of at least (9, 0); otherwise float32 is used instead.
SUPPORTED_FP8_DTYPE = (
    torch.float8_e4m3fn
    if torch_device_fn.is_available() and get_device_capability() >= (9, 0)
    else torch.float32
)

19 

20 

@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
    M: tl.constexpr,
):
    """Quantize groups of `group_size` elements with one scale per group.

    Each program starts at its program id and advances by the number of
    launched programs until the flat group index reaches `M`. For group `gid`:
      * input: row ``gid // groups_per_row`` of `y_ptr`, columns
        ``(gid % groups_per_row) * group_size`` onward (`y_row_stride` apart),
      * quantized output: ``y_q_ptr[gid * group_size + i]`` (flat layout),
      * scale: ``y_s_ptr[gid]`` (row-major scale layout).
    Scale = max(absmax(group), eps) / fp8_max, optionally rounded up to a
    power of two when `scale_ue8m0` is set; values are clamped to
    [fp8_min, fp8_max] before the cast to the output element type.

    NOTE(review): `M` is a tl.constexpr, so every distinct group count is a
    separate specialization/compile; confirm the backend requires this.
    """
    # Loop-invariant lane offsets and the mask for BLOCK > group_size padding.
    offs = tl.arange(0, BLOCK)
    in_group = offs < group_size
    groups_per_row = y_num_columns // group_size
    step = tl.num_programs(0)

    gid = tl.program_id(0)
    while gid < M:
        src_row = gid // groups_per_row
        grp = gid % groups_per_row
        src_offset = (src_row * y_row_stride) + (grp * group_size)

        vals = tl.load(y_ptr + offs + src_offset, mask=in_group, other=0.0).to(
            tl.float32
        )
        amax = tl.maximum(tl.max(tl.abs(vals)), eps)
        scale = amax / fp8_max
        if scale_ue8m0:
            # Round the scale up to the next power of two (UE8M0 scales).
            scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))
        quant = tl.clamp(vals / scale, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + offs + gid * group_size, quant, mask=in_group)
        tl.store(y_s_ptr + gid, scale)
        gid += step

61 

62 

@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
    M: tl.constexpr,
):
    """Same quantization as the row-major kernel, but with column-major scales.

    Each program starts at its program id and advances by the number of
    launched programs until the flat group index reaches `M`. For group `gid`
    in source row ``r = gid // groups_per_row`` and column group
    ``g = gid % groups_per_row``:
      * input: row `r` of `y_ptr` (`y_row_stride` apart), columns
        ``g * group_size`` onward,
      * quantized output: ``y_q_ptr[gid * group_size + i]`` (flat layout),
      * scale: ``y_s_ptr[g * y_s_col_stride + r]`` (column-major layout).
    Scale = max(absmax(group), eps) / fp8_max, optionally rounded up to a
    power of two when `scale_ue8m0` is set; values are clamped to
    [fp8_min, fp8_max] before the cast to the output element type.

    NOTE(review): `M` is a tl.constexpr, so every distinct group count is a
    separate specialization/compile; confirm the backend requires this.
    """
    # Loop-invariant lane offsets and the mask for BLOCK > group_size padding.
    offs = tl.arange(0, BLOCK)
    in_group = offs < group_size
    groups_per_row = y_num_columns // group_size
    step = tl.num_programs(0)

    gid = tl.program_id(0)
    while gid < M:
        src_row = gid // groups_per_row
        grp = gid % groups_per_row
        src_offset = src_row * y_row_stride + grp * group_size

        vals = tl.load(y_ptr + offs + src_offset, mask=in_group, other=0.0).to(
            tl.float32
        )
        amax = tl.maximum(tl.max(tl.abs(vals)), eps)
        scale = amax / fp8_max
        if scale_ue8m0:
            # Round the scale up to the next power of two (UE8M0 scales).
            scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))
        quant = tl.clamp(vals / scale, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + offs + gid * group_size, quant, mask=in_group)
        # Column-major scales: the group index picks the column, the row the offset.
        tl.store(y_s_ptr + (grp * y_s_col_stride + src_row), scale)
        gid += step

103 

104 

def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` with one scale per group of `group_size` elements.

    The last dimension of `x` is split into consecutive groups of `group_size`
    elements. Each group gets the scale ``max(absmax(group), eps) / finfo.max``
    (rounded up to a power of two when `scale_ue8m0` is set), and its values
    are divided by that scale and clamped to ``[finfo.min, finfo.max]`` of the
    output dtype.

    Args:
        x: Input tensor of rank >= 1. Its last dimension must be divisible by
            `group_size` and have unit stride; all leading dimensions are
            flattened into rows.
        group_size: Number of consecutive elements sharing one scale.
        eps: Lower bound applied to each group's absmax before scaling.
        dtype: Dtype of the quantized output. Defaults to
            ``SUPPORTED_FP8_DTYPE`` (``torch.float8_e4m3fn`` when the device
            check passes, otherwise ``torch.float32``).
        column_major_scales: If True, the scales keep the logical shape
            ``x.shape[:-1] + (num_groups,)`` but are stored with the group
            dimension outermost in memory (for 2-D `x`: the transpose of a
            contiguous ``(num_groups, rows)`` tensor).
        scale_ue8m0: If True, round each scale up to a power of two.

    Returns:
        ``(x_q, x_s)``: a contiguous tensor of ``x.shape`` and the output
        dtype, and float32 scales of shape
        ``x.shape[:-1] + (x.shape[-1] // group_size,)``.

    Raises:
        AssertionError: if the last dimension is not divisible by
            `group_size` or does not have unit stride.
    """
    logger.debug("GEMS_CAMBRICON PER_TOKEN_GROUP_QUANT_FP8")
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    num_columns = x.shape[-1]
    num_groups_per_row = num_columns // group_size

    # The kernels address input rows as `row * row_stride`, so present `x` as
    # 2-D (rows, columns). Using x.shape[1] / x.stride(0) directly was only
    # correct for 2-D input. `reshape` keeps a view when the leading dims can
    # be merged (always the case for 2-D input) and copies otherwise.
    # `x.shape[:-1].numel()` is 1 for 1-D input and avoids reshape(-1, 0).
    x_2d = x.reshape(x.shape[:-1].numel(), num_columns)

    # The kernels write x_q at flat offsets (group_index * group_size), so it
    # must be contiguous; empty_like could preserve a permuted dense layout.
    x_q = torch.empty(x.shape, device=x.device, dtype=fp8_dtype)
    M = x.numel() // group_size
    N = group_size

    if column_major_scales:
        # Groups outermost in memory, then moved to the last logical dim, so
        # x_s[..., g] lives at g * rows + flat_row. For 2-D x this is exactly
        # the former permute(-1, -2), and it also works for other ranks.
        shape = (num_groups_per_row,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).movedim(0, -1)
    else:
        shape = x.shape[:-1] + (num_groups_per_row,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    if M == 0:
        # Nothing to quantize; avoid launching a kernel with an empty grid.
        return x_q, x_s

    BLOCK = triton.next_power_of_2(N)
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    # Kernels loop over groups in steps of the launched grid size, so the grid
    # may be capped below M.
    grid = min(M, MAX_GRID_SIZE_X // 4)
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(grid,)](
            x_2d,
            x_q,
            x_s,
            group_size,
            num_columns,
            x_2d.stride(0),
            x_s.stride(-1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
            M=M,
        )
    else:
        _per_token_group_quant_fp8[(grid,)](
            x_2d,
            x_q,
            x_s,
            group_size,
            num_columns,
            x_2d.stride(0),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
            M=M,
        )

    return x_q, x_s