Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/per_token_group_quant_fp8.py: 0%
67 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1from typing import Optional, Tuple
3import torch
4import triton
5import triton.language as tl
7if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
8 SUPPORTED_FP8_DTYPE = torch.float8_e4m3fn
9else:
10 SUPPORTED_FP8_DTYPE = torch.float32
@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one group of ``group_size`` values per program; scales row-major.

    Program ``pid`` handles group ``pid % (y_num_columns // group_size)`` of row
    ``pid // (y_num_columns // group_size)``. Input rows are addressed through
    ``y_row_stride``; quantized values are written at flat offset
    ``pid * group_size`` of ``y_q_ptr`` and the scale at flat offset ``pid`` of
    ``y_s_ptr``.
    """
    pid = tl.program_id(0)
    n_groups = y_num_columns // group_size
    row_idx = pid // n_groups
    grp_in_row = pid % n_groups

    # Lanes at or beyond group_size are masked off (loaded as 0, never stored).
    offs = tl.arange(0, BLOCK)
    in_group = offs < group_size

    src = y_ptr + (row_idx * y_row_stride + grp_in_row * group_size)
    vals = tl.load(src + offs, mask=in_group, other=0.0).to(tl.float32)

    # The scale maps the group's absolute maximum (floored at eps) onto fp8_max.
    amax = tl.maximum(tl.max(tl.abs(vals)), eps)
    scale = amax / fp8_max
    if scale_ue8m0:
        # Round the scale up to a power of two (exp2 of the ceiled log2).
        scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))

    quant = tl.clamp(vals / scale, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    tl.store(y_q_ptr + pid * group_size + offs, quant, mask=in_group)
    tl.store(y_s_ptr + pid, scale)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one group of ``group_size`` values per program; scales column-major.

    Identical to the row-major variant except for where the scale is stored:
    group ``g`` of row ``r`` writes its scale at ``g * y_s_col_stride + r``
    of ``y_s_ptr``. Quantized values still go to flat offset
    ``pid * group_size`` of ``y_q_ptr``.
    """
    pid = tl.program_id(0)
    n_groups = y_num_columns // group_size
    row_idx = pid // n_groups
    grp_in_row = pid % n_groups

    # Lanes at or beyond group_size are masked off (loaded as 0, never stored).
    offs = tl.arange(0, BLOCK)
    in_group = offs < group_size

    src = y_ptr + (row_idx * y_row_stride + grp_in_row * group_size)
    vals = tl.load(src + offs, mask=in_group, other=0.0).to(tl.float32)

    # The scale maps the group's absolute maximum (floored at eps) onto fp8_max.
    amax = tl.maximum(tl.max(tl.abs(vals)), eps)
    scale = amax / fp8_max
    if scale_ue8m0:
        # Round the scale up to a power of two (exp2 of the ceiled log2).
        scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))

    quant = tl.clamp(vals / scale, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    tl.store(y_q_ptr + pid * group_size + offs, quant, mask=in_group)
    tl.store(y_s_ptr + (grp_in_row * y_s_col_stride + row_idx), scale)
def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` with one float32 scale per group of ``group_size`` elements.

    Each run of ``group_size`` consecutive elements along the last dimension is
    divided by ``max(abs-max of the group, eps) / finfo(dtype).max`` and then
    clamped to ``dtype``'s finite range.

    Args:
        x: Input tensor. Its last dimension must have stride 1 and be divisible
            by ``group_size``; all leading dimensions are treated as rows.
        group_size: Number of consecutive elements sharing one scale (> 0).
        eps: Lower bound applied to each group's absolute maximum.
        dtype: Dtype of the quantized output. Defaults to
            ``SUPPORTED_FP8_DTYPE``.
        column_major_scales: If True, scales are stored in a (groups, rows)
            buffer and returned as its transposed (rows, groups) view. This
            path requires 2-D ``x``.
        scale_ue8m0: If True, each scale is rounded up to a power of two.

    Returns:
        ``(x_q, x_s)``: the quantized tensor (same shape as ``x``, dtype
        ``dtype``) and float32 scales of shape
        ``x.shape[:-1] + (x.shape[-1] // group_size,)``.

    Raises:
        AssertionError: if ``group_size`` is not positive, the last dimension
            is not divisible by ``group_size``, or its stride is not 1.
    """
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert group_size > 0, f"`group_size` must be positive, got {group_size}"
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    num_cols = x.shape[-1]
    num_groups = num_cols // group_size

    # Allocate contiguously: the kernels write group ``g`` at flat offset
    # ``g * group_size``. ``empty_like`` would keep the strides of a dense but
    # permuted input, so the data would not land in logical order.
    x_q = torch.empty(x.shape, device=x.device, dtype=fp8_dtype)
    if column_major_scales:
        # (groups, rows) buffer viewed as (rows, groups); permute(-1, -2) only
        # works on a 2-D buffer, so this path requires 2-D ``x``.
        shape = (num_groups,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (num_groups,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    # One kernel program per group; an empty input has nothing to launch.
    M = x.numel() // group_size
    if M == 0:
        return x_q, x_s

    # Collapse all leading dimensions into rows. For 2-D ``x`` this is a view
    # with x's own strides; for other ranks reshape may copy to obtain a single
    # uniform row stride for the kernel.
    x_rows = x.reshape(-1, num_cols)

    BLOCK = triton.next_power_of_2(group_size)
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M,)](
            x_rows,
            x_q,
            x_s,
            group_size,
            x_rows.shape[1],
            x_rows.stride(0),
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(M,)](
            x_rows,
            x_q,
            x_s,
            group_size,
            x_rows.shape[1],
            x_rows.stride(0),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    return x_q, x_s