Coverage for src/flag_gems/fused/fused_inv_rope_fp8_quant.py: 14%
101 statements
Report generated by coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils.device_info import get_device_capability
# Choose the module-wide FP8 dtype once at import time: float8_e4m3fn when a
# device is available and reports capability >= (9, 0), float32 otherwise.
# The `and` short-circuits, so the capability query only runs when a device
# is available.
SUPPORTED_FP8_DTYPE = (
    torch.float8_e4m3fn
    if torch_device_fn.is_available() and get_device_capability() >= (9, 0)
    else torch.float32
)


logger = logging.getLogger(__name__)
def _get_tma_aligned_size(size: int, align: int) -> int:
    """Round ``size`` up to the nearest multiple of ``align``.

    Args:
        size: Value to round up.
        align: Alignment granularity; must be non-zero.

    Returns:
        The smallest multiple of ``align`` that is >= ``size`` (for positive
        ``align``).
    """
    # Bias by (align - 1), then drop the remainder; this is the same value as
    # floor-dividing the biased size by `align` and multiplying back.
    biased = size + align - 1
    return biased - biased % align
@triton.jit
def _fused_inv_rope_fp8_quant_per_head(
    o_ptr,
    positions_ptr,
    cos_sin_cache_ptr,
    fp8_ptr,
    scale_ptr,
    num_tokens,
    heads_per_group: tl.constexpr,
    o_stride_token,
    o_stride_head,
    cache_stride_pos,
    fp8_stride_group,
    fp8_stride_token,
    scale_stride_group,
    scale_stride_k,
    fp8_max: tl.constexpr,
    eps: tl.constexpr,
    QUANT_GROUP_SIZE: tl.constexpr,
    CHUNKS_PER_HEAD: tl.constexpr,
    ROPE_START: tl.constexpr,
    HALF_ROPE: tl.constexpr,
    TMA_ALIGNED_SCALES: tl.constexpr,
):
    """Inverse-rotate the RoPE tail of one head and quantize it to FP8.

    Each program handles one (token, head) pair: program axis 0 is the token
    index and axis 1 is the global head index. The head is split into
    CHUNKS_PER_HEAD chunks of QUANT_GROUP_SIZE elements, each with its own
    scale. Scales are stored either as one float32 per chunk or, when
    TMA_ALIGNED_SCALES is set, as one int32 per head holding the float32
    exponent byte of every chunk scale.
    """
    token_idx = tl.program_id(0).to(tl.int64)
    head_idx = tl.program_id(1).to(tl.int64)

    group_idx = head_idx // heads_per_group
    local_head = head_idx % heads_per_group
    first_chunk = local_head * CHUNKS_PER_HEAD

    # Scale slot of this (group, token); the token axis is addressed with
    # unit stride and the chunk/head axis with `scale_stride_k`.
    scale_row = scale_ptr + group_idx * scale_stride_group + token_idx

    if token_idx >= num_tokens:
        # Programs past the last real token only zero their scale slots.
        if TMA_ALIGNED_SCALES:
            tl.store(
                scale_row + local_head * scale_stride_k,
                tl.zeros((), dtype=tl.int32),
            )
        else:
            pad_chunks = first_chunk + tl.arange(0, CHUNKS_PER_HEAD)
            tl.store(
                scale_row + pad_chunks * scale_stride_k,
                tl.zeros((CHUNKS_PER_HEAD,), dtype=tl.float32),
            )
        return

    HEAD_DIM: tl.constexpr = CHUNKS_PER_HEAD * QUANT_GROUP_SIZE
    # The rotated span starts ROPE_START elements into the last chunk.
    ROPE_BEGIN: tl.constexpr = (CHUNKS_PER_HEAD - 1) * QUANT_GROUP_SIZE + ROPE_START

    src = o_ptr + token_idx * o_stride_token + head_idx * o_stride_head
    lane = tl.arange(0, HEAD_DIM)
    vals = tl.load(src + lane).to(tl.float32)

    position = tl.load(positions_ptr + token_idx)
    cache_row = cos_sin_cache_ptr + position * cache_stride_pos
    in_rope = lane >= ROPE_BEGIN
    rope_lane = lane - ROPE_BEGIN

    # `lane ^ 1` pairs each element with its even/odd neighbour.
    partner = tl.load(src + (lane ^ 1), mask=in_rope, other=0.0).to(tl.float32)
    # Clamp to 0 so lanes before the rotated span still form a valid index;
    # their loads are masked off and fall back to cos=1, sin=0.
    pair_idx = tl.maximum(rope_lane >> 1, 0)
    cos_val = tl.load(cache_row + pair_idx, mask=in_rope, other=1.0)
    sin_val = tl.load(cache_row + HALF_ROPE + pair_idx, mask=in_rope, other=0.0)

    direct = vals * cos_val
    cross = partner * sin_val
    # Even lanes add the cross term, odd lanes subtract it.
    even_lane = (rope_lane & 1) == 0
    rotated = tl.where(even_lane, direct + cross, direct - cross)
    vals = tl.where(in_rope, rotated, vals)

    # Per-chunk scale: absmax floored at eps, divided by the FP8 maximum.
    abs_2d = tl.reshape(tl.abs(vals), (CHUNKS_PER_HEAD, QUANT_GROUP_SIZE))
    chunk_absmax = tl.maximum(tl.max(abs_2d, axis=1), eps)
    chunk_scale = chunk_absmax * (1.0 / fp8_max)
    if TMA_ALIGNED_SCALES:
        # Round each scale up to a power of two so that its float32 exponent
        # field alone represents it.
        log_scale = tl.log2(tl.maximum(tl.abs(chunk_scale), 1e-10))
        chunk_scale = tl.math.exp2(tl.ceil(log_scale))

    # Expand the per-chunk scales back to one scale per element.
    scale_col = tl.reshape(chunk_scale, (CHUNKS_PER_HEAD, 1))
    scale_2d = tl.broadcast_to(scale_col, (CHUNKS_PER_HEAD, QUANT_GROUP_SIZE))
    lane_scale = tl.reshape(scale_2d, (HEAD_DIM,))
    quantized = tl.clamp(vals / lane_scale, -fp8_max, fp8_max).to(tl.float8e4nv)

    dst = (
        fp8_ptr
        + group_idx * fp8_stride_group
        + token_idx * fp8_stride_token
        + first_chunk * QUANT_GROUP_SIZE
    )
    tl.store(dst + lane, quantized)

    chunk_lane = tl.arange(0, CHUNKS_PER_HEAD)
    if TMA_ALIGNED_SCALES:
        # Bits 23..30 of a float32 are its exponent field; pack one such
        # byte per chunk into a single int32, chunk i occupying byte i.
        exponent_byte = (chunk_scale.to(tl.int32, bitcast=True) >> 23) & 0xFF
        packed = tl.sum(exponent_byte << (chunk_lane * 8))
        tl.store(scale_row + local_head * scale_stride_k, packed)
    else:
        tl.store(
            scale_row + (first_chunk + chunk_lane) * scale_stride_k,
            chunk_scale,
        )
def fused_inv_rope_fp8_quant(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    n_groups: int,
    heads_per_group: int,
    nope_dim: int = 448,
    rope_dim: int = 64,
    quant_group_size: int = 128,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    tma_aligned_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Triton draft of DeepSeek-V4 fused inverse-RoPE + FP8 group quant.

    Args:
        o: [num_tokens, num_heads, head_dim]; the last dim must have stride 1.
        positions: [num_tokens]; copied to a contiguous tensor if strided.
        cos_sin_cache: [max_position, rope_dim] float32, laid out as
            cos || sin; copied to a contiguous tensor if its last dim is
            strided.
        n_groups: number of head groups; n_groups * heads_per_group must
            equal num_heads.
        heads_per_group: number of heads in each group.
        nope_dim: leading, non-rotated part of head_dim.
        rope_dim: trailing, rotated part of head_dim (must be even).
        quant_group_size: number of elements sharing one scale.
        eps: lower bound applied to each chunk's absolute maximum.
        dtype: FP8 output dtype; defaults to SUPPORTED_FP8_DTYPE. Only
            torch.float8_e4m3fn is accepted.
        tma_aligned_scales: store power-of-two scales as packed exponent
            bytes (int32) instead of float32 values.

    Returns:
        o_fp8: [num_tokens, n_groups, heads_per_group * head_dim]
        o_scale: [num_tokens, n_groups, num_scale_blocks] or packed UE8M0 view

        Both are transposed views of group-major buffers.

    Raises:
        AssertionError: if any shape, dtype or layout precondition fails.
    """
    logger.debug("GEMS FUSED INV ROPE FP8 QUANT")

    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert fp8_dtype == torch.float8_e4m3fn, "only torch.float8_e4m3fn is supported"
    assert o.ndim == 3, "`o` must be [num_tokens, num_heads, head_dim]"
    assert positions.ndim == 1, "`positions` must be 1D"
    assert cos_sin_cache.ndim == 2, "`cos_sin_cache` must be 2D"
    assert o.stride(-1) == 1, "head_dim must be contiguous"
    assert positions.shape[0] == o.shape[0], "positions and o token count mismatch"

    num_tokens, num_heads, head_dim = o.shape
    assert num_heads == n_groups * heads_per_group
    assert head_dim == nope_dim + rope_dim
    assert head_dim % quant_group_size == 0
    assert nope_dim % quant_group_size == (quant_group_size - rope_dim)
    assert rope_dim % 2 == 0
    assert cos_sin_cache.shape[-1] == rope_dim
    assert cos_sin_cache.dtype == torch.float32

    # The kernel addresses `positions[token]` and the elements of a cache row
    # with unit stride and receives no stride for either, so strided views
    # would be read incorrectly. `contiguous()` returns the tensor itself
    # when no copy is needed.
    positions = positions.contiguous()
    if cos_sin_cache.stride(-1) != 1:
        cos_sin_cache = cos_sin_cache.contiguous()

    chunks_per_head = head_dim // quant_group_size
    if tma_aligned_scales:
        assert (
            chunks_per_head <= 4
        ), "packed UE8M0 path currently expects at most 4 scale blocks per head"

    d = heads_per_group * head_dim
    num_scale_blocks = d // quant_group_size
    # The token axis of the scale buffer is padded to a multiple of 4.
    tma_aligned_t = _get_tma_aligned_size(num_tokens, 4)

    if tma_aligned_scales:
        scale_inner = (num_scale_blocks + 3) // 4
        scale_dtype = torch.int32
        # The kernel stores one packed int32 per head at index
        # `head_in_group` along the last scale axis; a smaller buffer would
        # be written out of bounds.
        assert heads_per_group <= scale_inner, (
            "packed UE8M0 path stores one int32 per head; the scale buffer is "
            "too small for this heads_per_group / chunks_per_head combination"
        )
    else:
        scale_inner = num_scale_blocks
        scale_dtype = torch.float32

    finfo = torch.finfo(fp8_dtype)
    fp8_q = torch.empty((n_groups, num_tokens, d), dtype=fp8_dtype, device=o.device)
    # Scale storage is laid out [group][k][padded token]: the token axis has
    # stride 1 and the k axis has stride `tma_aligned_t`.
    scale = torch.empty(
        n_groups * scale_inner * tma_aligned_t,
        dtype=scale_dtype,
        device=o.device,
    ).as_strided(
        (n_groups, num_tokens, scale_inner),
        (scale_inner * tma_aligned_t, 1, tma_aligned_t),
    )

    # Launch over the padded token count so the kernel also zeroes the scale
    # slots of padding tokens.
    grid = (tma_aligned_t, n_groups * heads_per_group)
    _fused_inv_rope_fp8_quant_per_head[grid](
        o,
        positions,
        cos_sin_cache,
        fp8_q,
        scale,
        num_tokens,
        heads_per_group=heads_per_group,
        o_stride_token=o.stride(0),
        o_stride_head=o.stride(1),
        cache_stride_pos=cos_sin_cache.stride(0),
        fp8_stride_group=fp8_q.stride(0),
        fp8_stride_token=fp8_q.stride(1),
        scale_stride_group=scale.stride(0),
        scale_stride_k=scale.stride(2),
        fp8_max=finfo.max,
        eps=eps,
        QUANT_GROUP_SIZE=quant_group_size,
        CHUNKS_PER_HEAD=chunks_per_head,
        ROPE_START=nope_dim % quant_group_size,
        HALF_ROPE=rope_dim // 2,
        TMA_ALIGNED_SCALES=tma_aligned_scales,
        num_warps=1,
        num_stages=1,
    )

    return fp8_q.transpose(0, 1), scale.transpose(0, 1)