Coverage for src/flag_gems/runtime/backend/_cambricon/ops/per_token_group_quant_fp8.py: 0%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from ..utils import MAX_GRID_SIZE_X
# Child logger of the package-wide "flag_gems" logger.
# NOTE(review): `lstrip(".")` removes leading dots only; module names normally
# have none, so this is effectively a defensive no-op.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Choose the fp8 output dtype once at import time: float8_e4m3fn is used when
# a CUDA device with compute capability >= (9, 0) is available, otherwise the
# code falls back to float32 (quantization math still runs, but without any
# actual 8-bit storage savings).
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
    SUPPORTED_FP8_DTYPE = torch.float8_e4m3fn
else:
    SUPPORTED_FP8_DTYPE = torch.float32
@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
    M: tl.constexpr,
):
    """Quantize `y` in groups of `group_size` elements, row-major scales.

    For each of the `M` groups: scale = max(absmax(group), eps) / fp8_max
    (optionally rounded up to a power of two when `scale_ue8m0` is set), the
    group is divided by its scale, clamped to [fp8_min, fp8_max] and cast to
    the output element type. Group g writes its values at offset
    g * group_size of `y_q_ptr` and its scale at offset g of `y_s_ptr`.
    """
    num_groups_per_row = y_num_columns // group_size
    step = tl.num_programs(0)
    gid = tl.program_id(0)
    # Loop-invariant lane offsets and tail mask for a (possibly padded) group.
    lanes = tl.arange(0, BLOCK)
    lane_mask = lanes < group_size
    # Grid is capped by the launcher; each program strides over groups.
    while gid < M:
        row_idx = gid // num_groups_per_row
        col_group = gid % num_groups_per_row

        src = y_ptr + row_idx * y_row_stride + col_group * group_size + lanes
        vals = tl.load(src, mask=lane_mask, other=0.0).to(tl.float32)

        group_scale = tl.maximum(tl.max(tl.abs(vals)), eps) / fp8_max
        if scale_ue8m0:
            # UE8M0: round the scale up to the next power of two.
            group_scale = tl.exp2(
                tl.ceil(tl.log2(tl.maximum(tl.abs(group_scale), 1e-10)))
            )

        quantized = tl.clamp(vals / group_scale, fp8_min, fp8_max).to(
            y_q_ptr.dtype.element_ty
        )
        tl.store(y_q_ptr + gid * group_size + lanes, quantized, mask=lane_mask)
        tl.store(y_s_ptr + gid, group_scale)

        gid += step
@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
    M: tl.constexpr,
):
    """Quantize `y` in groups of `group_size` elements, column-major scales.

    Identical quantization math to `_per_token_group_quant_fp8`, but the
    per-group scale is stored at `group_index * y_s_col_stride + row` so the
    scales land in a column-major layout.
    """
    num_groups_per_row = y_num_columns // group_size
    step = tl.num_programs(0)
    gid = tl.program_id(0)
    # Loop-invariant lane offsets and tail mask for a (possibly padded) group.
    lanes = tl.arange(0, BLOCK)
    lane_mask = lanes < group_size
    # Grid is capped by the launcher; each program strides over groups.
    while gid < M:
        row_idx = gid // num_groups_per_row
        col_group = gid % num_groups_per_row

        src = y_ptr + row_idx * y_row_stride + col_group * group_size + lanes
        vals = tl.load(src, mask=lane_mask, other=0.0).to(tl.float32)

        group_scale = tl.maximum(tl.max(tl.abs(vals)), eps) / fp8_max
        if scale_ue8m0:
            # UE8M0: round the scale up to the next power of two.
            group_scale = tl.exp2(
                tl.ceil(tl.log2(tl.maximum(tl.abs(group_scale), 1e-10)))
            )

        quantized = tl.clamp(vals / group_scale, fp8_min, fp8_max).to(
            y_q_ptr.dtype.element_ty
        )
        tl.store(y_q_ptr + gid * group_size + lanes, quantized, mask=lane_mask)
        # Column-major scale layout: the group index is the strided dimension.
        tl.store(y_s_ptr + col_group * y_s_col_stride + row_idx, group_scale)

        gid += step
def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to fp8 in contiguous groups along its last dimension.

    Every `group_size` consecutive elements share one float32 scale
    (absmax / fp8_max, floored at `eps`); elements are divided by their scale
    and clamped into the fp8 range before the cast.

    Args:
        x: Input tensor. Its last dimension must be divisible by `group_size`
            and have stride 1.
        group_size: Number of consecutive elements that share one scale.
        eps: Lower bound on each group's absmax, avoiding division by zero.
        dtype: Output dtype; defaults to `SUPPORTED_FP8_DTYPE`
            (`torch.float8_e4m3fn` when available, otherwise float32).
        column_major_scales: If True, return the scales in a column-major
            (transposed) allocation. NOTE(review): this path assumes a 2D `x`
            — `permute(-1, -2)` below raises for other ranks.
        scale_ue8m0: If True, round every scale up to a power of two (UE8M0).

    Returns:
        Tuple of (x_q, x_s): the contiguous quantized tensor with the same
        shape as `x`, and the float32 per-group scales.
    """
    logger.debug("GEMS_CAMBRICON PER_TOKEN_GROUP_QUANT_FP8")
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    # The kernels write group g at flat offset g * group_size, i.e. a strictly
    # contiguous layout. Allocate a contiguous buffer explicitly: empty_like
    # could inherit a dense-but-non-contiguous stride order from `x`, which
    # would silently mismatch the kernels' write pattern.
    x_q = torch.empty(x.shape, device=x.device, dtype=fp8_dtype)
    M = x.numel() // group_size  # total number of quantization groups
    N = group_size

    # Fix: describe the input by its LAST dimension instead of
    # shape[1]/stride(0), so inputs that are not exactly 2D are handled
    # correctly (for 2D inputs the values are identical). The row stride for
    # a 1D input is irrelevant (row index is always 0), so any value works.
    y_num_columns = x.shape[-1]
    y_row_stride = x.stride(-2) if x.dim() > 1 else x.shape[-1]

    if column_major_scales:
        shape = (x.shape[-1] // group_size,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (x.shape[-1] // group_size,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    # Grid-stride launch: cap the grid size and let each program loop.
    grid = min(M, MAX_GRID_SIZE_X // 4)
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(grid,)](
            x,
            x_q,
            x_s,
            group_size,
            y_num_columns,
            y_row_stride,
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
            M=M,
        )
    else:
        _per_token_group_quant_fp8[(grid,)](
            x,
            x_q,
            x_s,
            group_size,
            y_num_columns,
            y_row_stride,
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
            M=M,
        )

    return x_q, x_s