Coverage for src/flag_gems/runtime/backend/_cambricon/ops/per_token_group_quant_fp8.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from ..utils import MAX_GRID_SIZE_X 

9 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Pick the module-wide quantization output dtype once at import time: use the
# native FP8 dtype only when a CUDA device with compute capability >= (9, 0)
# is present, otherwise fall back to plain float32. The FP8 attribute is only
# touched on the supported path, mirroring the lazy access of a plain if/else.
if not (torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)):
    SUPPORTED_FP8_DTYPE = torch.float32
else:
    SUPPORTED_FP8_DTYPE = torch.float8_e4m3fn

16 

17 

@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,  # *input* tensor; values are read and promoted to float32
    y_q_ptr,  # *output* quantized tensor (dtype taken from the pointer)
    y_s_ptr,  # *output* scales, one float per group, contiguous in g_id
    group_size,  # number of consecutive elements sharing one scale
    y_num_columns,  # elements per input row; assumed divisible by group_size
    y_row_stride,  # stride (in elements) between consecutive input rows
    eps,  # lower bound on a group's absmax; avoids a zero divisor
    fp8_min,  # minimum representable value of the output dtype
    fp8_max,  # maximum representable value of the output dtype
    scale_ue8m0,  # if truthy, round each scale up to a power of two
    BLOCK: tl.constexpr,  # power-of-two tile width, >= group_size
    M: tl.constexpr,  # total number of groups across the whole tensor
):
    # Per-token-group dynamic quantization: every group of `group_size`
    # consecutive elements gets scale = max(absmax, eps) / fp8_max, is divided
    # by that scale, clamped into [fp8_min, fp8_max], and cast to the output
    # dtype. Scales are stored row-major (flat, one per group).
    groups_per_row = y_num_columns // group_size

    # Grid-stride loop: each program handles groups g_id, g_id + grid, ...,
    # so a launch grid smaller than M still covers every group.
    grid_0 = tl.num_programs(0)
    g_id = tl.program_id(0)
    while g_id < M:
        row = g_id // groups_per_row
        row_g_id = g_id % groups_per_row

        y_ptr_offset = (row * y_row_stride) + (row_g_id * group_size)
        y_q_ptr_offset = g_id * group_size  # quantized output is packed densely
        y_s_ptr_offset = g_id

        cols = tl.arange(0, BLOCK)
        mask = cols < group_size  # BLOCK may exceed group_size; mask the tail

        y = tl.load(y_ptr + cols + y_ptr_offset, mask=mask, other=0.0).to(tl.float32)
        # eps floor keeps the scale strictly positive even for an all-zero group.
        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
        y_s = _absmax / fp8_max
        if scale_ue8m0:
            # Round the scale up to the next power of two (UE8M0-style scale);
            # the 1e-10 floor guards log2 against a zero argument.
            y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10))))
        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + cols + y_q_ptr_offset, y_q, mask=mask)
        tl.store(y_s_ptr + y_s_ptr_offset, y_s)
        g_id += grid_0

58 

59 

@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,  # *input* tensor; values are read and promoted to float32
    y_q_ptr,  # *output* quantized tensor (dtype taken from the pointer)
    y_s_ptr,  # *output* scales, stored column-major (see offset below)
    group_size,  # number of consecutive elements sharing one scale
    y_num_columns,  # elements per input row; assumed divisible by group_size
    y_row_stride,  # stride (in elements) between consecutive input rows
    y_s_col_stride,  # stride (in elements) between scale columns
    eps,  # lower bound on a group's absmax; avoids a zero divisor
    fp8_min,  # minimum representable value of the output dtype
    fp8_max,  # maximum representable value of the output dtype
    scale_ue8m0,  # if truthy, round each scale up to a power of two
    BLOCK: tl.constexpr,  # power-of-two tile width, >= group_size
    M: tl.constexpr,  # total number of groups across the whole tensor
):
    # Same per-group quantization as `_per_token_group_quant_fp8`, except the
    # scale for (row, group_id) is written at group_id * y_s_col_stride + row,
    # i.e. the scales tensor is laid out column-major.
    groups_per_row = y_num_columns // group_size
    # Grid-stride loop so a launch grid smaller than M still covers every group.
    grid_0 = tl.num_programs(0)
    g_id = tl.program_id(0)
    while g_id < M:
        row = g_id // groups_per_row
        group_id = g_id % groups_per_row

        y_ptr_offset = row * y_row_stride + group_id * group_size
        y_q_ptr_offset = g_id * group_size  # quantized output is packed densely
        y_s_ptr_offset = group_id * y_s_col_stride + row  # column-major scales

        cols = tl.arange(0, BLOCK)
        mask = cols < group_size  # BLOCK may exceed group_size; mask the tail

        y = tl.load(y_ptr + cols + y_ptr_offset, mask=mask, other=0.0).to(tl.float32)
        # eps floor keeps the scale strictly positive even for an all-zero group.
        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
        y_s = _absmax / fp8_max
        if scale_ue8m0:
            # Round the scale up to the next power of two (UE8M0-style scale);
            # the 1e-10 floor guards log2 against a zero argument.
            y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10))))
        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + cols + y_q_ptr_offset, y_q, mask=mask)
        tl.store(y_s_ptr + y_s_ptr_offset, y_s)
        g_id += grid_0

100 

101 

def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dynamically quantize `x` per group of `group_size` elements.

    Groups run along the last dimension. Each group gets the scale
    ``max(absmax, eps) / fp8_max``; the group is divided by its scale and
    clamped into the output dtype's representable range.

    Args:
        x: input tensor. The last dimension must be divisible by
            `group_size` and have stride 1. 2D inputs may have an arbitrary
            row stride; inputs of rank != 2 must be fully contiguous.
        group_size: number of consecutive elements sharing one scale.
        eps: lower bound on a group's absmax, avoiding division by zero.
        dtype: output dtype. Defaults to `SUPPORTED_FP8_DTYPE`
            (`torch.float8_e4m3fn` on capable hardware, else float32).
        column_major_scales: if True, return the scales as a
            (rows, groups_per_row) view over column-major storage, where
            `rows` is the product of `x`'s leading dimensions.
        scale_ue8m0: round each scale up to a power of two.

    Returns:
        (x_q, x_s): the quantized tensor with `x`'s shape, and the float32
        scales tensor (one scale per group; row-major with shape
        ``x.shape[:-1] + (groups_per_row,)`` unless `column_major_scales`).
    """
    logger.debug("GEMS_CAMBRICON PER_TOKEN_GROUP_QUANT_FP8")
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
    M = x.numel() // group_size  # total number of groups
    N = group_size
    num_columns = x.shape[-1]
    groups_per_row = num_columns // group_size
    num_rows = x.numel() // num_columns  # rows after flattening leading dims

    # Row stride (in elements) between consecutive flattened rows. The kernel
    # indexes rows flat, so only 2D inputs may carry a non-trivial stride;
    # other ranks must be contiguous (fixes silent misindexing via x.shape[1]
    # / x.stride(0) for rank > 2, and an IndexError for rank 1).
    if x.dim() == 2:
        row_stride = x.stride(0)
    else:
        assert x.is_contiguous(), "`x` must be contiguous when x.dim() != 2"
        row_stride = num_columns

    if column_major_scales:
        # Allocate (groups_per_row, num_rows) and view it transposed so the
        # result is (num_rows, groups_per_row) over column-major storage —
        # identical to the previous 2D behavior, but valid for any rank.
        x_s = torch.empty(
            (groups_per_row, num_rows), device=x.device, dtype=torch.float32
        ).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (groups_per_row,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    # The kernels use a grid-stride loop, so cap the grid at a quarter of the
    # hardware maximum and let each program process several groups.
    grid = min(M, MAX_GRID_SIZE_X // 4)
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(grid,)](
            x,
            x_q,
            x_s,
            group_size,
            num_columns,
            row_stride,
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
            M=M,
        )
    else:
        _per_token_group_quant_fp8[(grid,)](
            x,
            x_q,
            x_s,
            group_size,
            num_columns,
            row_stride,
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
            M=M,
        )

    return x_q, x_s