Coverage for src/flag_gems/fused/fused_inv_rope_fp8_quant.py: 14%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils.device_info import get_device_capability 

10 

# Default FP8 output dtype: float8_e4m3fn when a device is available and it
# reports compute capability >= (9, 0); float32 otherwise. The capability is
# only queried when a device is available (``and`` short-circuits).
SUPPORTED_FP8_DTYPE = (
    torch.float8_e4m3fn
    if torch_device_fn.is_available() and get_device_capability() >= (9, 0)
    else torch.float32
)


logger = logging.getLogger(__name__)

18 

19 

20def _get_tma_aligned_size(size: int, align: int) -> int: 

21 return ((size + align - 1) // align) * align 

22 

23 

# Per-(token, head) Triton kernel: undo the pairwise RoPE rotation on the tail
# features of one head of ``o``, then quantize the whole head to FP8 in
# QUANT_GROUP_SIZE-wide chunks with one scale per chunk.
#
# Launch layout (from the program_id reads below):
#   axis 0 -> token row ``pid_token``; rows >= num_tokens only write zero scales
#   axis 1 -> global head ``pid_gh``; its group is pid_gh // heads_per_group
#
# Addressing (all in elements):
#   o_ptr             token * o_stride_token + head * o_stride_head + feature
#                     (feature dim is read with unit stride)
#   positions_ptr     read at ``positions_ptr + pid_token`` -> token stride must be 1
#   cos_sin_cache_ptr row ``pos * cache_stride_pos``; cos read at column cs_idx,
#                     sin at column HALF_ROPE + cs_idx (unit column stride)
#   fp8_ptr           group * fp8_stride_group + token * fp8_stride_token
#                     + head_in_group * HEAD_DIM + feature
#   scale_ptr         group * scale_stride_group + token + column * scale_stride_k
#                     -> token stride of the scale buffer must be 1
#
# Constexprs:
#   HEAD_DIM = CHUNKS_PER_HEAD * QUANT_GROUP_SIZE. Both HEAD_DIM and
#   CHUNKS_PER_HEAD are used as tl.arange extents, which Triton requires to be
#   powers of two.
#   ROPE_START is where the rope features begin inside the last chunk.
#   TMA_ALIGNED_SCALES: round scales up to powers of two and store their
#   exponent bytes packed into one int32 per head (one byte per chunk), instead
#   of one float32 per chunk.
@triton.jit
def _fused_inv_rope_fp8_quant_per_head(
    o_ptr,
    positions_ptr,
    cos_sin_cache_ptr,
    fp8_ptr,
    scale_ptr,
    num_tokens,
    heads_per_group: tl.constexpr,
    o_stride_token,
    o_stride_head,
    cache_stride_pos,
    fp8_stride_group,
    fp8_stride_token,
    scale_stride_group,
    scale_stride_k,
    fp8_max: tl.constexpr,
    eps: tl.constexpr,
    QUANT_GROUP_SIZE: tl.constexpr,
    CHUNKS_PER_HEAD: tl.constexpr,
    ROPE_START: tl.constexpr,
    HALF_ROPE: tl.constexpr,
    TMA_ALIGNED_SCALES: tl.constexpr,
):
    # int64 program ids so the pointer offsets below cannot overflow int32.
    pid_token = tl.program_id(0).to(tl.int64)
    pid_gh = tl.program_id(1).to(tl.int64)

    g = pid_gh // heads_per_group
    head_in_group = pid_gh % heads_per_group
    # Axis 1 already enumerates heads across all groups.
    global_head = pid_gh
    # Index of this head's first quant chunk within its group's output row.
    qb_start = head_in_group * CHUNKS_PER_HEAD

    # Padding rows beyond the real token count: write zero scales for this
    # head's slots (packed int32 or per-chunk float32) and do nothing else.
    if pid_token >= num_tokens:
        if TMA_ALIGNED_SCALES:
            scale_addr = (
                scale_ptr
                + g * scale_stride_group
                + pid_token
                + head_in_group * scale_stride_k
            )
            tl.store(scale_addr, tl.zeros((), dtype=tl.int32))
        else:
            block_offsets = tl.arange(0, CHUNKS_PER_HEAD)
            qb_indices = qb_start + block_offsets
            scale_addrs = (
                scale_ptr
                + g * scale_stride_group
                + pid_token
                + qb_indices * scale_stride_k
            )
            tl.store(scale_addrs, tl.zeros((CHUNKS_PER_HEAD,), dtype=tl.float32))
        return

    input_base = o_ptr + pid_token * o_stride_token + global_head * o_stride_head

    # Load the full head in float32.
    HEAD_DIM: tl.constexpr = CHUNKS_PER_HEAD * QUANT_GROUP_SIZE
    offsets = tl.arange(0, HEAD_DIM)
    x = tl.load(input_base + offsets).to(tl.float32)

    # Rope features occupy [rope_abs_start, HEAD_DIM), i.e. the tail of the
    # last chunk.
    rope_abs_start: tl.constexpr = (CHUNKS_PER_HEAD - 1) * QUANT_GROUP_SIZE + ROPE_START
    pos = tl.load(positions_ptr + pid_token)
    cache_base = cos_sin_cache_ptr + pos * cache_stride_pos
    is_rope = offsets >= rope_abs_start
    rope_local = offsets - rope_abs_start

    # Partner of each lane in its (even, odd) pair. ``offsets ^ 1`` matches the
    # parity of ``rope_local`` only when rope_abs_start is even.
    # NOTE(review): relies on the caller keeping rope_abs_start even.
    x_partner = tl.load(input_base + (offsets ^ 1), mask=is_rope, other=0.0).to(
        tl.float32
    )
    # Both lanes of a pair share one (cos, sin) entry. Clamping at 0 keeps the
    # masked-off non-rope lanes at a non-negative index.
    cs_idx = tl.maximum(rope_local >> 1, 0)
    # Masked lanes get cos=1, sin=0 (identity); they are discarded below anyway.
    cos_v = tl.load(cache_base + cs_idx, mask=is_rope, other=1.0)
    sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope, other=0.0)
    # Even lane a' = a*cos + b*sin and odd lane b' = b*cos - a*sin. This is the
    # inverse of the rotation (a, b) -> (a*cos - b*sin, b*cos + a*sin).
    x_add = x * cos_v + x_partner * sin_v
    x_sub = x * cos_v - x_partner * sin_v
    is_even = (rope_local & 1) == 0
    rotated = tl.where(is_even, x_add, x_sub)
    # Only rope lanes take the de-rotated value; nope lanes keep their input.
    x = tl.where(is_rope, rotated, x)

    # Per-chunk absmax (floored at eps) -> scale = absmax / fp8_max.
    x_2d = tl.reshape(tl.abs(x), (CHUNKS_PER_HEAD, QUANT_GROUP_SIZE))
    block_absmax = tl.maximum(tl.max(x_2d, axis=1), eps)
    scales = block_absmax * (1.0 / fp8_max)
    if TMA_ALIGNED_SCALES:
        # Round each scale up to a power of two (2 ** ceil(log2(scale))) so
        # that its exponent byte alone represents it exactly.
        scales = tl.math.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scales), 1e-10))))

    # Broadcast each chunk's scale across its QUANT_GROUP_SIZE features.
    scales_exp = tl.reshape(
        tl.broadcast_to(
            tl.reshape(scales, (CHUNKS_PER_HEAD, 1)),
            (CHUNKS_PER_HEAD, QUANT_GROUP_SIZE),
        ),
        (HEAD_DIM,),
    )
    x_quant = tl.clamp(x / scales_exp, -fp8_max, fp8_max).to(tl.float8e4nv)

    # This head's HEAD_DIM-wide slice of the group's output row.
    fp8_base = (
        fp8_ptr
        + g * fp8_stride_group
        + pid_token * fp8_stride_token
        + qb_start * QUANT_GROUP_SIZE
    )
    tl.store(fp8_base + offsets, x_quant)

    block_offsets = tl.arange(0, CHUNKS_PER_HEAD)
    qb_indices = qb_start + block_offsets
    if TMA_ALIGNED_SCALES:
        # Take the biased exponent byte (bits 23..30) of each float32 scale,
        # and place chunk i's byte at bits [8*i, 8*i + 8) of one int32.
        # The shifted bytes do not overlap, so the sum equals their bitwise OR.
        # This assumes CHUNKS_PER_HEAD <= 4 so all bytes fit into one int32.
        scale_bits = scales.to(tl.int32, bitcast=True)
        ue8m0_bytes = (scale_bits >> 23) & 0xFF
        packed_val = tl.sum(ue8m0_bytes << (block_offsets * 8))
        # One int32 per head: column ``head_in_group``.
        scale_addr = (
            scale_ptr
            + g * scale_stride_group
            + pid_token
            + head_in_group * scale_stride_k
        )
        tl.store(scale_addr, packed_val)
    else:
        # One float32 per chunk: columns qb_start .. qb_start + CHUNKS_PER_HEAD - 1.
        scale_addrs = (
            scale_ptr + g * scale_stride_group + pid_token + qb_indices * scale_stride_k
        )
        tl.store(scale_addrs, scales)

142 

143 

def fused_inv_rope_fp8_quant(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    n_groups: int,
    heads_per_group: int,
    nope_dim: int = 448,
    rope_dim: int = 64,
    quant_group_size: int = 128,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    tma_aligned_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Triton draft of DeepSeek-V4 fused inverse-RoPE + FP8 group quant.

    For every token and head, the trailing ``rope_dim`` features of ``o`` are
    de-rotated using the ``cos_sin_cache`` row selected by ``positions``. The
    whole head is then quantized to FP8 in groups of ``quant_group_size``
    features, with one scale per group.

    Args:
        o: [num_tokens, num_heads, head_dim]; head_dim must be contiguous.
        positions: [num_tokens] row indices into ``cos_sin_cache``.
        cos_sin_cache: [max_position, rope_dim] laid out as cos || sin; float32
            with a contiguous last dim.
        n_groups: number of head groups (num_heads == n_groups * heads_per_group).
        heads_per_group: heads concatenated into one output row per group.
        nope_dim: leading, non-rotated part of head_dim.
        rope_dim: trailing, rotated part of head_dim (even, <= quant_group_size).
        quant_group_size: number of features sharing one scale.
        eps: lower bound on each group's absolute maximum before scaling.
        dtype: FP8 output dtype; defaults to ``SUPPORTED_FP8_DTYPE``. Only
            ``torch.float8_e4m3fn`` is accepted.
        tma_aligned_scales: if True, scales are rounded up to powers of two and
            their exponent bytes are packed into int32 words (UE8M0). If False,
            one float32 scale is stored per group.

    Returns:
        o_fp8: [num_tokens, n_groups, heads_per_group * head_dim]
        o_scale: [num_tokens, n_groups, num_scale_blocks] float32, or the packed
            UE8M0 int32 view [num_tokens, n_groups, ceil(num_scale_blocks / 4)].
            The token dim has stride 1; in storage it is padded to a multiple
            of 4 rows.

    Raises:
        AssertionError: on unsupported dtypes, shapes or memory layouts.
    """
    logger.debug("GEMS FUSED INV ROPE FP8 QUANT")

    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert fp8_dtype == torch.float8_e4m3fn, "only torch.float8_e4m3fn is supported"
    assert o.ndim == 3, "`o` must be [num_tokens, num_heads, head_dim]"
    assert positions.ndim == 1, "`positions` must be 1D"
    assert cos_sin_cache.ndim == 2, "`cos_sin_cache` must be 2D"
    assert o.stride(-1) == 1, "head_dim must be contiguous"
    # The kernel reads cache columns as ``row_ptr + col`` (no column stride).
    assert (
        cos_sin_cache.stride(-1) == 1
    ), "`cos_sin_cache` last dim must be contiguous"
    assert positions.shape[0] == o.shape[0], "positions and o token count mismatch"
    # The kernel reads ``positions_ptr + token`` without a stride, so a strided
    # view (e.g. ``positions[::2]``) would otherwise read the wrong entries.
    # This is a no-op for already-contiguous tensors.
    positions = positions.contiguous()

    num_tokens, num_heads, head_dim = o.shape
    assert num_heads == n_groups * heads_per_group
    assert head_dim == nope_dim + rope_dim
    assert head_dim % quant_group_size == 0
    # The rope features must lie entirely in the tail of the last quant chunk.
    assert nope_dim % quant_group_size == (quant_group_size - rope_dim)
    assert rope_dim % 2 == 0
    assert cos_sin_cache.shape[-1] == rope_dim
    assert cos_sin_cache.dtype == torch.float32

    chunks_per_head = head_dim // quant_group_size
    if tma_aligned_scales:
        # The kernel packs one head's exponent bytes into a single int32 at
        # column ``head_in_group``. The buffer below, however, has only
        # ceil(heads_per_group * chunks_per_head / 4) int32 columns. The two
        # layouts agree only when a head fills exactly one int32 (4 chunks) or
        # there is a single head per group. In every other case the kernel
        # would write past the buffer and clobber neighbouring groups' scales.
        assert chunks_per_head == 4 or (
            heads_per_group == 1 and chunks_per_head <= 4
        ), (
            "packed UE8M0 path expects exactly 4 scale blocks per head "
            "(or at most 4 when heads_per_group == 1)"
        )

    d = heads_per_group * head_dim
    num_scale_blocks = d // quant_group_size
    # Scale columns hold a token count padded to a multiple of 4.
    tma_aligned_t = _get_tma_aligned_size(num_tokens, 4)

    if tma_aligned_scales:
        # Four exponent bytes per int32 word.
        scale_inner = (num_scale_blocks + 3) // 4
        scale_dtype = torch.int32
    else:
        scale_inner = num_scale_blocks
        scale_dtype = torch.float32

    finfo = torch.finfo(fp8_dtype)
    fp8_q = torch.empty((n_groups, num_tokens, d), dtype=fp8_dtype, device=o.device)
    # The token dim is innermost: strides (group, token, column) are
    # (scale_inner * T, 1, T) with T = tma_aligned_t. The kernel relies on the
    # token stride being 1.
    scale = torch.empty(
        n_groups * scale_inner * tma_aligned_t,
        dtype=scale_dtype,
        device=o.device,
    ).as_strided(
        (n_groups, num_tokens, scale_inner),
        (scale_inner * tma_aligned_t, 1, tma_aligned_t),
    )

    # Axis 0 spans the padded token count so the padding rows of the scale
    # buffer are zeroed by the kernel.
    grid = (tma_aligned_t, n_groups * heads_per_group)
    _fused_inv_rope_fp8_quant_per_head[grid](
        o,
        positions,
        cos_sin_cache,
        fp8_q,
        scale,
        num_tokens,
        heads_per_group=heads_per_group,
        o_stride_token=o.stride(0),
        o_stride_head=o.stride(1),
        cache_stride_pos=cos_sin_cache.stride(0),
        fp8_stride_group=fp8_q.stride(0),
        fp8_stride_token=fp8_q.stride(1),
        scale_stride_group=scale.stride(0),
        scale_stride_k=scale.stride(2),
        fp8_max=finfo.max,
        eps=eps,
        QUANT_GROUP_SIZE=quant_group_size,
        CHUNKS_PER_HEAD=chunks_per_head,
        ROPE_START=nope_dim % quant_group_size,
        HALF_ROPE=rope_dim // 2,
        TMA_ALIGNED_SCALES=tma_aligned_scales,
        num_warps=1,
        num_stages=1,
    )

    return fp8_q.transpose(0, 1), scale.transpose(0, 1)