Coverage for src/flag_gems/ops/per_token_group_quant_fp8.py: 19%
69 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1from typing import Optional, Tuple
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils.device_info import get_device_capability
# Default quantized-output dtype: torch.float8_e4m3fn when a device is
# available and reports compute capability >= (9, 0); torch.float32 otherwise.
# NOTE(review): sm_89 (Ada) also supports e4m3 in hardware/Triton -- confirm
# whether the (9, 0) threshold is intentional.
SUPPORTED_FP8_DTYPE = (
    torch.float8_e4m3fn
    if torch_device_fn.is_available() and get_device_capability() >= (9, 0)
    else torch.float32
)
@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one group of ``group_size`` input elements per program.

    Program ``g_id`` handles group ``g_id % groups_per_row`` of row
    ``g_id // groups_per_row``. Input rows are addressed through
    ``y_row_stride`` (elements within a group must be contiguous). The
    quantized values are written densely at ``y_q_ptr + g_id * group_size``
    and the scale at ``y_s_ptr + g_id``, so both outputs are assumed to be
    contiguous.
    """
    groups_per_row = y_num_columns // group_size

    # Widen to int64 before any offset arithmetic: program ids are int32, so
    # g_id * group_size and row * y_row_stride would overflow once the tensor
    # holds more than 2**31 elements.
    g_id = tl.program_id(0).to(tl.int64)
    row = g_id // groups_per_row
    row_g_id = g_id % groups_per_row

    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
    y_q_ptr += g_id * group_size
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)
    # BLOCK may exceed group_size (the caller rounds it up to a power of
    # two); masked lanes load 0.0 so they never affect the absmax.
    mask = cols < group_size

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Clamping the absmax to eps keeps the scale positive for all-zero
    # groups (when eps > 0), avoiding a division by zero below.
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max

    if scale_ue8m0:
        # Round the scale up to the next power of two (presumably for an
        # exponent-only UE8M0 scale format, per the flag name).
        y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10))))

    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one group of ``group_size`` input elements per program,
    storing scales in a column-major layout.

    Program ``g_id`` handles group ``g_id % groups_per_row`` of row
    ``g_id // groups_per_row``. Input rows are addressed through
    ``y_row_stride``; quantized values are written densely at
    ``y_q_ptr + g_id * group_size`` (contiguous output assumed). The scale
    for (row, group) is stored at ``y_s_ptr + group * y_s_col_stride + row``.
    """
    groups_per_row = y_num_columns // group_size

    # Widen to int64 before any offset arithmetic: program ids are int32, so
    # g_id * group_size, row * y_row_stride and group_id * y_s_col_stride
    # would overflow once the tensors hold more than 2**31 elements.
    g_id = tl.program_id(0).to(tl.int64)
    row = g_id // groups_per_row
    group_id = g_id % groups_per_row

    y_ptr += row * y_row_stride + group_id * group_size
    y_q_ptr += g_id * group_size
    y_s_ptr += group_id * y_s_col_stride + row

    cols = tl.arange(0, BLOCK)
    # BLOCK may exceed group_size (the caller rounds it up to a power of
    # two); masked lanes load 0.0 so they never affect the absmax.
    mask = cols < group_size

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Clamping the absmax to eps keeps the scale positive for all-zero
    # groups (when eps > 0), avoiding a division by zero below.
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max

    if scale_ue8m0:
        # Round the scale up to the next power of two (presumably for an
        # exponent-only UE8M0 scale format, per the flag name).
        y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10))))

    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` with one float32 scale per group of ``group_size``
    consecutive elements along the last dimension.

    For each group the scale is ``max(absmax(group), eps) / finfo(dtype).max``
    (rounded up to a power of two when ``scale_ue8m0`` is set) and the
    quantized values are ``clamp(x / scale, finfo.min, finfo.max)`` cast to
    ``dtype``.

    Args:
        x: Input tensor of rank >= 1. Its last dimension must be divisible
            by ``group_size`` and contiguous (``x.stride(-1) == 1``); all
            leading dimensions are treated as rows.
        group_size: Number of consecutive elements that share one scale.
        eps: Lower bound applied to each group's absmax.
        dtype: Dtype of the quantized output. Defaults to
            ``SUPPORTED_FP8_DTYPE`` (``torch.float8_e4m3fn`` on devices with
            compute capability >= (9, 0), ``torch.float32`` otherwise).
        column_major_scales: If True, the returned scale tensor is a view
            whose group dimension is outermost in memory (column-major for
            2-D input); otherwise it is contiguous.
        scale_ue8m0: If True, round every scale up to a power of two.

    Returns:
        ``(x_q, x_s)`` where ``x_q`` is a contiguous tensor of shape
        ``x.shape`` and dtype ``dtype``, and ``x_s`` is a float32 tensor of
        shape ``x.shape[:-1] + (x.shape[-1] // group_size,)``.

    Raises:
        AssertionError: if the last dimension is not divisible by
            ``group_size`` or is not contiguous.
    """
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    hidden_size = x.shape[-1]
    assert hidden_size % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    groups_per_row = hidden_size // group_size

    # The kernels write x_q densely at offset g_id * group_size, so the
    # output must be contiguous regardless of x's own strides.
    x_q = torch.empty(x.shape, device=x.device, dtype=fp8_dtype)

    if column_major_scales:
        # Allocate with the group dimension outermost, then move it last so
        # the logical shape matches the row-major case. For 2-D input this
        # is exactly permute(-1, -2); for other ranks it keeps the leading
        # dimensions densely packed, which the kernel's `group * col_stride
        # + row` addressing relies on.
        shape = (groups_per_row,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).movedim(0, -1)
    else:
        shape = x.shape[:-1] + (groups_per_row,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    if x.numel() == 0:
        # Nothing to quantize; also avoids reshaping a zero-width tensor.
        return x_q, x_s

    # Flatten all leading dimensions into rows so the kernels see a 2-D
    # problem. reshape is a view whenever the strides allow it (always for
    # 2-D input); rows are addressed through its row stride.
    x_2d = x.reshape(-1, hidden_size)
    num_groups = x.numel() // group_size

    BLOCK = triton.next_power_of_2(group_size)
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(num_groups,)](
            x_2d,
            x_q,
            x_s,
            group_size,
            hidden_size,
            x_2d.stride(0),
            x_s.stride(-1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(num_groups,)](
            x_2d,
            x_q,
            x_s,
            group_size,
            hidden_size,
            x_2d.stride(0),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    return x_q, x_s