Coverage for src/flag_gems/ops/per_token_group_quant_fp8.py: 47%
72 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils.device_info import get_device_capability
# Pick the fp8 dtype once at import time.  Hardware fp8 (e4m3) requires
# compute capability >= 9.0 (Hopper-class); on older or unavailable devices
# we fall back to float32 so the op still runs (unquantized storage).
# NOTE(review): `get_device_capability()` is assumed to return a (major, minor)
# tuple comparable against (9, 0) — confirm against flag_gems.utils.device_info.
if torch_device_fn.is_available() and get_device_capability() >= (9, 0):
    SUPPORTED_FP8_DTYPE = torch.float8_e4m3fn
else:
    SUPPORTED_FP8_DTYPE = torch.float32

# Module-level logger, per project convention.
logger = logging.getLogger(__name__)
@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one `group_size`-wide group of `y` into fp8 per program.

    Each program handles a single group: it computes the group's absmax-based
    scale, writes the quantized values to `y_q_ptr` (row-major, groups packed
    contiguously) and the scale to `y_s_ptr[pid]`.
    """
    pid = tl.program_id(0)

    # Map the flat program id onto (row, group-within-row).
    num_groups_per_row = y_num_columns // group_size
    row_idx = pid // num_groups_per_row
    col_group = pid % num_groups_per_row

    src = y_ptr + row_idx * y_row_stride + col_group * group_size
    dst = y_q_ptr + pid * group_size

    offs = tl.arange(0, BLOCK)
    valid = offs < group_size

    vals = tl.load(src + offs, mask=valid, other=0.0).to(tl.float32)

    # Scale = absmax / fp8_max, floored by eps to avoid division by zero.
    group_amax = tl.maximum(tl.max(tl.abs(vals)), eps)
    scale = group_amax / fp8_max
    if scale_ue8m0:
        # Round the scale up to a power of two (UE8M0-style exponent scale).
        scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))

    quantized = tl.clamp(vals / scale, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(dst + offs, quantized, mask=valid)
    tl.store(y_s_ptr + pid, scale)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Per-group fp8 quantization with the scales stored column-major.

    Identical math to `_per_token_group_quant_fp8`; only the scale layout
    differs: scale for (row, group) lands at `group * y_s_col_stride + row`.
    """
    pid = tl.program_id(0)

    # Map the flat program id onto (row, group-within-row).
    num_groups_per_row = y_num_columns // group_size
    row_idx = pid // num_groups_per_row
    col_group = pid % num_groups_per_row

    src = y_ptr + row_idx * y_row_stride + col_group * group_size
    dst = y_q_ptr + pid * group_size
    # Column-major scale location for this (row, group).
    scale_dst = y_s_ptr + col_group * y_s_col_stride + row_idx

    offs = tl.arange(0, BLOCK)
    valid = offs < group_size

    vals = tl.load(src + offs, mask=valid, other=0.0).to(tl.float32)

    # Scale = absmax / fp8_max, floored by eps to avoid division by zero.
    group_amax = tl.maximum(tl.max(tl.abs(vals)), eps)
    scale = group_amax / fp8_max
    if scale_ue8m0:
        # Round the scale up to a power of two (UE8M0-style exponent scale).
        scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))

    quantized = tl.clamp(vals / scale, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(dst + offs, quantized, mask=valid)
    tl.store(scale_dst, scale)
def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to fp8 with one scale per `group_size`-wide group.

    Args:
        x: Input tensor; the last dimension must be divisible by `group_size`
            and have stride 1. Tensors with more than 2 dims must be
            contiguous (rows are addressed by a single flattened row stride).
        group_size: Number of consecutive elements sharing one scale.
        eps: Lower bound on the group absmax to avoid division by zero.
        dtype: Output dtype; defaults to `SUPPORTED_FP8_DTYPE`. Note that
            only `torch.float8_e4m3fn` is a true fp8 target.
        column_major_scales: Store scales column-major (requires 2-D `x`).
        scale_ue8m0: Round scales up to a power of two.

    Returns:
        (x_q, x_s): quantized tensor (same shape as `x`) and the scales.
    """
    logger.debug("GEMS PER TOKEN GROUP QUANT FP8")
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"
    if x.dim() > 2:
        # The kernel flattens all leading dims into rows and advances by a
        # single row stride, which is only valid when rows are evenly spaced.
        assert x.is_contiguous(), "`x` with more than 2 dims must be contiguous"

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
    M = x.numel() // group_size  # total number of groups == grid size
    N = group_size

    # BUGFIX: the kernels previously received `x.shape[1]` / `x.stride(0)`,
    # which is only correct for 2-D input. Use the last-dim width and the
    # row stride instead (identical for 2-D, correct for flattened >2-D;
    # 1-D input has a single row, so the row stride is never applied).
    num_columns = x.shape[-1]
    row_stride = x.stride(-2) if x.dim() >= 2 else num_columns

    if column_major_scales:
        # Scales laid out (groups_per_row, *leading) then viewed transposed;
        # `permute(-1, -2)` restricts this path to 2-D input.
        shape = (x.shape[-1] // group_size,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (x.shape[-1] // group_size,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            num_columns,
            row_stride,
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            num_columns,
            row_stride,
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    return x_q, x_s