Coverage for src/flag_gems/runtime/backend/_cambricon/ops/per_token_group_quant_fp8.py: 0%
80 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils.device_info import get_device_capability
11from ..utils import MAX_GRID_SIZE_X
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Output dtype used when the caller does not pass one: native FP8 (e4m3fn) only
# when a device is available and reports capability >= (9, 0); otherwise the
# quantized values are kept in float32. `get_device_capability()` is only
# queried after `is_available()` succeeds (short-circuit `and`).
SUPPORTED_FP8_DTYPE = (
    torch.float8_e4m3fn
    if torch_device_fn.is_available() and get_device_capability() >= (9, 0)
    else torch.float32
)
@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
    M: tl.constexpr,
):
    """Quantize groups of ``group_size`` contiguous elements, row-major scales.

    With ``G = y_num_columns // group_size``, group ``gid`` (``0 <= gid < M``)
    is the slice of row ``gid // G`` starting at column ``(gid % G) * group_size``.
    Its quantized values are written densely at ``y_q_ptr + gid * group_size``
    and its single float scale at ``y_s_ptr + gid``.

    Each program starts at its program id and advances by the number of
    programs, so any grid of at least one program visits all ``M`` groups.
    ``M`` is a ``tl.constexpr``, so each distinct value compiles its own kernel.
    """
    n_groups_in_row = y_num_columns // group_size
    step = tl.num_programs(0)

    # Lane offsets within one group; BLOCK may exceed group_size, so mask the tail.
    offs = tl.arange(0, BLOCK)
    in_group = offs < group_size

    gid = tl.program_id(0)
    while gid < M:
        src_row = gid // n_groups_in_row
        grp = gid % n_groups_in_row

        # Load the group (masked lanes read 0.0) and upcast to fp32.
        src_off = src_row * y_row_stride + grp * group_size
        vals = tl.load(y_ptr + src_off + offs, mask=in_group, other=0.0).to(tl.float32)

        # Per-group scale; eps keeps it strictly positive for all-zero groups.
        scale = tl.maximum(tl.max(tl.abs(vals)), eps) / fp8_max
        if scale_ue8m0:
            # Round the scale up to the nearest power of two.
            scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))

        q = tl.clamp(vals / scale, fp8_min, fp8_max)
        tl.store(
            y_q_ptr + gid * group_size + offs,
            q.to(y_q_ptr.dtype.element_ty),
            mask=in_group,
        )
        tl.store(y_s_ptr + gid, scale)

        gid += step
@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
    M: tl.constexpr,
):
    """Quantize groups of ``group_size`` contiguous elements, column-major scales.

    With ``G = y_num_columns // group_size``, group ``gid`` (``0 <= gid < M``)
    is the slice of row ``gid // G`` starting at column ``(gid % G) * group_size``.
    Quantized values are written densely at ``y_q_ptr + gid * group_size``.
    The scale is stored at ``y_s_ptr + (gid % G) * y_s_col_stride + gid // G``,
    so the scales of one group index sit next to each other across rows.

    Each program starts at its program id and advances by the number of
    programs, so any grid of at least one program visits all ``M`` groups.
    """
    n_groups_in_row = y_num_columns // group_size
    step = tl.num_programs(0)

    # Lane offsets within one group; BLOCK may exceed group_size, so mask the tail.
    offs = tl.arange(0, BLOCK)
    in_group = offs < group_size

    gid = tl.program_id(0)
    while gid < M:
        src_row = gid // n_groups_in_row
        grp = gid % n_groups_in_row

        # Load the group (masked lanes read 0.0) and upcast to fp32.
        src_off = src_row * y_row_stride + grp * group_size
        vals = tl.load(y_ptr + src_off + offs, mask=in_group, other=0.0).to(tl.float32)

        # Per-group scale; eps keeps it strictly positive for all-zero groups.
        scale = tl.maximum(tl.max(tl.abs(vals)), eps) / fp8_max
        if scale_ue8m0:
            # Round the scale up to the nearest power of two.
            scale = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(scale), 1e-10))))

        q = tl.clamp(vals / scale, fp8_min, fp8_max)
        tl.store(
            y_q_ptr + gid * group_size + offs,
            q.to(y_q_ptr.dtype.element_ty),
            mask=in_group,
        )
        tl.store(y_s_ptr + grp * y_s_col_stride + src_row, scale)

        gid += step
def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` with one float32 scale per group of ``group_size`` elements.

    The last dimension of ``x`` is split into consecutive groups of
    ``group_size`` elements. For each group the scale is
    ``max(absmax(group), eps) / finfo(dtype).max``; it is rounded up to a power
    of two when ``scale_ue8m0`` is set. The group is then stored as
    ``clamp(group / scale, finfo.min, finfo.max)`` cast to the output dtype.

    Args:
        x: Input tensor. Its last dimension must be divisible by ``group_size``
            and have unit stride. All leading dimensions are treated as rows.
        group_size: Number of consecutive elements that share one scale.
        eps: Lower bound on each group's absolute maximum, so all-zero groups
            still get a non-zero scale.
        dtype: Dtype of the quantized output. Defaults to
            ``SUPPORTED_FP8_DTYPE``: ``torch.float8_e4m3fn`` when the device
            reports capability >= (9, 0), otherwise ``torch.float32``.
        column_major_scales: If True, scales are returned as a transposed view
            of a ``(num_groups, rows)`` float32 buffer, so the scales of one
            group index are adjacent in memory across rows. The transpose uses
            ``permute(-1, -2)``, so this only works for 2-D ``x``.
        scale_ue8m0: If True, round every scale up to a power of two.

    Returns:
        ``(x_q, x_s)``:
        - ``x_q`` has ``x``'s shape and dtype ``dtype``, and is contiguous.
        - ``x_s`` is float32 with logical shape
          ``x.shape[:-1] + (x.shape[-1] // group_size,)``.
    """
    logger.debug("GEMS_CAMBRICON PER_TOKEN_GROUP_QUANT_FP8")
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    # The kernels write group g densely at offset g * group_size, so x_q must be
    # contiguous. empty_like(x) could inherit a permuted dense layout from x.
    x_q = torch.empty(x.shape, device=x.device, dtype=fp8_dtype)
    M = x.numel() // group_size
    N = group_size

    if column_major_scales:
        shape = (x.shape[-1] // group_size,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (x.shape[-1] // group_size,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    # Nothing to quantize: avoid launching with an empty grid. This must come
    # before the reshape below, which cannot infer -1 from a zero-size last dim.
    if M == 0:
        return x_q, x_s

    # The kernels address rows through one (num_columns, row_stride) pair, so
    # present x as 2-D (rows, last_dim).
    # - 2-D input: this is a view of x itself, so behavior is unchanged.
    # - N-D input: it is no longer misread through x.shape[1] / x.stride(0).
    # reshape may copy if the leading dims cannot be flattened as a view; x is
    # only read, so that copy is harmless.
    x_2d = x.reshape(-1, x.shape[-1])

    BLOCK = triton.next_power_of_2(N)
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    # Kernels loop over groups, so a capped grid still covers all M groups.
    grid = min(M, MAX_GRID_SIZE_X // 4)
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(grid,)](
            x_2d,
            x_q,
            x_s,
            group_size,
            x_2d.shape[1],
            x_2d.stride(0),
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
            M=M,
        )
    else:
        _per_token_group_quant_fp8[(grid,)](
            x_2d,
            x_q,
            x_s,
            group_size,
            x_2d.shape[1],
            x_2d.stride(0),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            scale_ue8m0=scale_ue8m0,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
            M=M,
        )

    return x_q, x_s