Coverage for src/flag_gems/ops/per_token_group_quant_fp8.py: 12%
523 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils.device_info import get_device_capability
# Output dtype used when the caller does not pass one: the native FP8 format
# when a device is available and reports capability (9, 0) or newer, otherwise
# plain float32 as a stand-in.
SUPPORTED_FP8_DTYPE = (
    torch.float8_e4m3fn
    if torch_device_fn.is_available() and get_device_capability() >= (9, 0)
    else torch.float32
)


logger = logging.getLogger(__name__)
@triton.jit
def _per_token_group_quant_fp8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one group of ``group_size`` elements per program.

    Args:
        y_ptr: Input pointer; rows are ``y_row_stride`` elements apart.
        y_q_ptr: Quantized output pointer, addressed as
            ``flat_group_index * group_size`` (i.e. densely packed groups).
        y_s_ptr: Scale output pointer, one value per flat group index.
        group_size: Number of elements sharing one scale.
        y_num_columns: Number of elements per input row.
        y_row_stride: Element stride between consecutive input rows.
        eps: Lower bound applied to the group's absolute maximum.
        fp8_min: Lower clamp bound for the scaled values.
        fp8_max: Upper clamp bound; also the divisor that derives the scale.
        scale_ue8m0: When true, round the scale up to a power of two.
        BLOCK: Compile-time lane count; lanes at or beyond ``group_size``
            are masked out.
    """
    groups_per_row = y_num_columns // group_size

    # program_id is 32-bit; widen it so the element offsets below cannot wrap
    # when the tensor holds 2**31 or more elements.
    g_id = tl.program_id(0).to(tl.int64)
    row = g_id // groups_per_row
    row_g_id = g_id % groups_per_row

    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
    y_q_ptr += g_id * group_size
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)
    mask = cols < group_size

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # eps keeps the scale strictly positive for an all-zero group.
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max

    if scale_ue8m0:
        # Round the scale up to the next power of two.
        y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10))))

    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize one group per program, writing scales in column-major order.

    Identical to the row-major kernel except for the scale address: the scale
    of group ``g`` in row ``r`` is stored at ``g * y_s_col_stride + r``.

    Args:
        y_ptr: Input pointer; rows are ``y_row_stride`` elements apart.
        y_q_ptr: Quantized output pointer, addressed as
            ``flat_group_index * group_size``.
        y_s_ptr: Scale output pointer.
        group_size: Number of elements sharing one scale.
        y_num_columns: Number of elements per input row.
        y_row_stride: Element stride between consecutive input rows.
        y_s_col_stride: Element stride between consecutive group columns of
            the scale tensor.
        eps: Lower bound applied to the group's absolute maximum.
        fp8_min: Lower clamp bound for the scaled values.
        fp8_max: Upper clamp bound; also the divisor that derives the scale.
        scale_ue8m0: When true, round the scale up to a power of two.
        BLOCK: Compile-time lane count; lanes at or beyond ``group_size``
            are masked out.
    """
    groups_per_row = y_num_columns // group_size

    # program_id is 32-bit; widen it so the element offsets below cannot wrap
    # when the tensor holds 2**31 or more elements.
    g_id = tl.program_id(0).to(tl.int64)
    row = g_id // groups_per_row
    group_id = g_id % groups_per_row

    y_ptr += row * y_row_stride + group_id * group_size
    y_q_ptr += g_id * group_size
    y_s_ptr += group_id * y_s_col_stride + row

    cols = tl.arange(0, BLOCK)
    mask = cols < group_size

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # eps keeps the scale strictly positive for an all-zero group.
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max

    if scale_ue8m0:
        # Round the scale up to the next power of two.
        y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10))))

    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
@triton.jit
def _per_token_group_quant_fp8_m2(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize two adjacent groups of the same row per program.

    Each group gets its own scale; scales are stored at the flat group index.
    The number of groups per row must be a multiple of two, otherwise the
    trailing group of each row is never visited.

    Args:
        y_ptr: Input pointer; rows are ``y_row_stride`` elements apart.
        y_q_ptr: Quantized output pointer, addressed as
            ``flat_group_index * group_size``.
        y_s_ptr: Scale output pointer, one value per flat group index.
        group_size: Number of elements sharing one scale.
        y_num_columns: Number of elements per input row.
        y_row_stride: Element stride between consecutive input rows.
        eps: Lower bound applied to each group's absolute maximum.
        fp8_min: Lower clamp bound for the scaled values.
        fp8_max: Upper clamp bound; also the divisor that derives the scale.
        scale_ue8m0: When true, round each scale up to a power of two.
        BLOCK: Compile-time lane count; lanes at or beyond ``group_size``
            are masked out.
    """
    groups_per_row = y_num_columns // group_size
    pairs_per_row = groups_per_row // 2

    # program_id is 32-bit; widen it so the element offsets below cannot wrap
    # when the tensor holds 2**31 or more elements.
    pid = tl.program_id(0).to(tl.int64)
    row = pid // pairs_per_row
    group0 = (pid % pairs_per_row) * 2  # first group index within the row
    g0 = row * groups_per_row + group0  # flat index of that group

    # The two groups are adjacent, so each pointer is one step past the last.
    y_ptr0 = y_ptr + row * y_row_stride + group0 * group_size
    y_ptr1 = y_ptr0 + group_size

    y_q_ptr0 = y_q_ptr + g0 * group_size
    y_q_ptr1 = y_q_ptr0 + group_size

    y_s_ptr0 = y_s_ptr + g0
    y_s_ptr1 = y_s_ptr0 + 1

    cols = tl.arange(0, BLOCK)
    mask = cols < group_size

    y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32)
    y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32)

    # eps keeps each scale strictly positive for an all-zero group.
    y_s0 = tl.maximum(tl.max(tl.abs(y0)), eps) / fp8_max
    y_s1 = tl.maximum(tl.max(tl.abs(y1)), eps) / fp8_max

    if scale_ue8m0:
        # Round each scale up to the next power of two.
        y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10))))
        y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10))))

    y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr0 + cols, y_q0, mask=mask)
    tl.store(y_s_ptr0, y_s0)
    tl.store(y_q_ptr1 + cols, y_q1, mask=mask)
    tl.store(y_s_ptr1, y_s1)
@triton.jit
def _per_token_group_quant_fp8_colmajor_m2(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize two adjacent groups per program with column-major scales.

    The scale of group ``g`` in row ``r`` is stored at
    ``g * y_s_col_stride + r``. The number of groups per row must be a
    multiple of two, otherwise the trailing group of each row is never
    visited.

    Args:
        y_ptr: Input pointer; rows are ``y_row_stride`` elements apart.
        y_q_ptr: Quantized output pointer, addressed as
            ``flat_group_index * group_size``.
        y_s_ptr: Scale output pointer.
        group_size: Number of elements sharing one scale.
        y_num_columns: Number of elements per input row.
        y_row_stride: Element stride between consecutive input rows.
        y_s_col_stride: Element stride between consecutive group columns of
            the scale tensor.
        eps: Lower bound applied to each group's absolute maximum.
        fp8_min: Lower clamp bound for the scaled values.
        fp8_max: Upper clamp bound; also the divisor that derives the scale.
        scale_ue8m0: When true, round each scale up to a power of two.
        BLOCK: Compile-time lane count; lanes at or beyond ``group_size``
            are masked out.
    """
    groups_per_row = y_num_columns // group_size
    pairs_per_row = groups_per_row // 2

    # program_id is 32-bit; widen it so the element offsets below cannot wrap
    # when the tensor holds 2**31 or more elements.
    pid = tl.program_id(0).to(tl.int64)
    row = pid // pairs_per_row
    group0 = (pid % pairs_per_row) * 2  # first group index within the row
    g0 = row * groups_per_row + group0  # flat index of that group

    # The two groups are adjacent, so each pointer is one step past the last.
    y_ptr0 = y_ptr + row * y_row_stride + group0 * group_size
    y_ptr1 = y_ptr0 + group_size

    y_q_ptr0 = y_q_ptr + g0 * group_size
    y_q_ptr1 = y_q_ptr0 + group_size

    # Scales of neighbouring groups are one column stride apart.
    y_s_ptr0 = y_s_ptr + group0 * y_s_col_stride + row
    y_s_ptr1 = y_s_ptr0 + y_s_col_stride

    cols = tl.arange(0, BLOCK)
    mask = cols < group_size

    y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32)
    y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32)

    # eps keeps each scale strictly positive for an all-zero group.
    y_s0 = tl.maximum(tl.max(tl.abs(y0)), eps) / fp8_max
    y_s1 = tl.maximum(tl.max(tl.abs(y1)), eps) / fp8_max

    if scale_ue8m0:
        # Round each scale up to the next power of two.
        y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10))))
        y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10))))

    y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr0 + cols, y_q0, mask=mask)
    tl.store(y_s_ptr0, y_s0)
    tl.store(y_q_ptr1 + cols, y_q1, mask=mask)
    tl.store(y_s_ptr1, y_s1)
@triton.jit
def _per_token_group_quant_fp8_m4(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize four adjacent groups of the same row per program.

    Each group gets its own scale; scales are stored at the flat group index.
    The number of groups per row must be a multiple of four, otherwise the
    trailing groups of each row are never visited.

    Args:
        y_ptr: Input pointer; rows are ``y_row_stride`` elements apart.
        y_q_ptr: Quantized output pointer, addressed as
            ``flat_group_index * group_size``.
        y_s_ptr: Scale output pointer, one value per flat group index.
        group_size: Number of elements sharing one scale.
        y_num_columns: Number of elements per input row.
        y_row_stride: Element stride between consecutive input rows.
        eps: Lower bound applied to each group's absolute maximum.
        fp8_min: Lower clamp bound for the scaled values.
        fp8_max: Upper clamp bound; also the divisor that derives the scale.
        scale_ue8m0: When true, round each scale up to a power of two.
        BLOCK: Compile-time lane count; lanes at or beyond ``group_size``
            are masked out.
    """
    groups_per_row = y_num_columns // group_size
    quads_per_row = groups_per_row // 4

    # program_id is 32-bit; widen it so the element offsets below cannot wrap
    # when the tensor holds 2**31 or more elements.
    pid = tl.program_id(0).to(tl.int64)
    row = pid // quads_per_row
    group0 = (pid % quads_per_row) * 4  # first group index within the row
    g0 = row * groups_per_row + group0  # flat index of that group

    # The four groups are adjacent, so each pointer is one step past the last.
    y_ptr0 = y_ptr + row * y_row_stride + group0 * group_size
    y_ptr1 = y_ptr0 + group_size
    y_ptr2 = y_ptr1 + group_size
    y_ptr3 = y_ptr2 + group_size

    y_q_ptr0 = y_q_ptr + g0 * group_size
    y_q_ptr1 = y_q_ptr0 + group_size
    y_q_ptr2 = y_q_ptr1 + group_size
    y_q_ptr3 = y_q_ptr2 + group_size

    y_s_ptr0 = y_s_ptr + g0
    y_s_ptr1 = y_s_ptr0 + 1
    y_s_ptr2 = y_s_ptr1 + 1
    y_s_ptr3 = y_s_ptr2 + 1

    cols = tl.arange(0, BLOCK)
    mask = cols < group_size

    y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32)
    y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32)
    y2 = tl.load(y_ptr2 + cols, mask=mask, other=0.0).to(tl.float32)
    y3 = tl.load(y_ptr3 + cols, mask=mask, other=0.0).to(tl.float32)

    # eps keeps each scale strictly positive for an all-zero group.
    y_s0 = tl.maximum(tl.max(tl.abs(y0)), eps) / fp8_max
    y_s1 = tl.maximum(tl.max(tl.abs(y1)), eps) / fp8_max
    y_s2 = tl.maximum(tl.max(tl.abs(y2)), eps) / fp8_max
    y_s3 = tl.maximum(tl.max(tl.abs(y3)), eps) / fp8_max

    if scale_ue8m0:
        # Round each scale up to the next power of two.
        y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10))))
        y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10))))
        y_s2 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s2), 1e-10))))
        y_s3 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s3), 1e-10))))

    y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q2 = tl.clamp(y2 / y_s2, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q3 = tl.clamp(y3 / y_s3, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr0 + cols, y_q0, mask=mask)
    tl.store(y_s_ptr0, y_s0)
    tl.store(y_q_ptr1 + cols, y_q1, mask=mask)
    tl.store(y_s_ptr1, y_s1)
    tl.store(y_q_ptr2 + cols, y_q2, mask=mask)
    tl.store(y_s_ptr2, y_s2)
    tl.store(y_q_ptr3 + cols, y_q3, mask=mask)
    tl.store(y_s_ptr3, y_s3)
@triton.jit
def _per_token_group_quant_fp8_colmajor_m4(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize four adjacent groups per program with column-major scales.

    The scale of group ``g`` in row ``r`` is stored at
    ``g * y_s_col_stride + r``. The number of groups per row must be a
    multiple of four, otherwise the trailing groups of each row are never
    visited.

    Args:
        y_ptr: Input pointer; rows are ``y_row_stride`` elements apart.
        y_q_ptr: Quantized output pointer, addressed as
            ``flat_group_index * group_size``.
        y_s_ptr: Scale output pointer.
        group_size: Number of elements sharing one scale.
        y_num_columns: Number of elements per input row.
        y_row_stride: Element stride between consecutive input rows.
        y_s_col_stride: Element stride between consecutive group columns of
            the scale tensor.
        eps: Lower bound applied to each group's absolute maximum.
        fp8_min: Lower clamp bound for the scaled values.
        fp8_max: Upper clamp bound; also the divisor that derives the scale.
        scale_ue8m0: When true, round each scale up to a power of two.
        BLOCK: Compile-time lane count; lanes at or beyond ``group_size``
            are masked out.
    """
    groups_per_row = y_num_columns // group_size
    quads_per_row = groups_per_row // 4

    # program_id is 32-bit; widen it so the element offsets below cannot wrap
    # when the tensor holds 2**31 or more elements.
    pid = tl.program_id(0).to(tl.int64)
    row = pid // quads_per_row
    group0 = (pid % quads_per_row) * 4  # first group index within the row
    g0 = row * groups_per_row + group0  # flat index of that group

    # The four groups are adjacent, so each pointer is one step past the last.
    y_ptr0 = y_ptr + row * y_row_stride + group0 * group_size
    y_ptr1 = y_ptr0 + group_size
    y_ptr2 = y_ptr1 + group_size
    y_ptr3 = y_ptr2 + group_size

    y_q_ptr0 = y_q_ptr + g0 * group_size
    y_q_ptr1 = y_q_ptr0 + group_size
    y_q_ptr2 = y_q_ptr1 + group_size
    y_q_ptr3 = y_q_ptr2 + group_size

    # Scales of neighbouring groups are one column stride apart.
    y_s_ptr0 = y_s_ptr + group0 * y_s_col_stride + row
    y_s_ptr1 = y_s_ptr0 + y_s_col_stride
    y_s_ptr2 = y_s_ptr1 + y_s_col_stride
    y_s_ptr3 = y_s_ptr2 + y_s_col_stride

    cols = tl.arange(0, BLOCK)
    mask = cols < group_size

    y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32)
    y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32)
    y2 = tl.load(y_ptr2 + cols, mask=mask, other=0.0).to(tl.float32)
    y3 = tl.load(y_ptr3 + cols, mask=mask, other=0.0).to(tl.float32)

    # eps keeps each scale strictly positive for an all-zero group.
    y_s0 = tl.maximum(tl.max(tl.abs(y0)), eps) / fp8_max
    y_s1 = tl.maximum(tl.max(tl.abs(y1)), eps) / fp8_max
    y_s2 = tl.maximum(tl.max(tl.abs(y2)), eps) / fp8_max
    y_s3 = tl.maximum(tl.max(tl.abs(y3)), eps) / fp8_max

    if scale_ue8m0:
        # Round each scale up to the next power of two.
        y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10))))
        y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10))))
        y_s2 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s2), 1e-10))))
        y_s3 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s3), 1e-10))))

    y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q2 = tl.clamp(y2 / y_s2, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q3 = tl.clamp(y3 / y_s3, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr0 + cols, y_q0, mask=mask)
    tl.store(y_s_ptr0, y_s0)
    tl.store(y_q_ptr1 + cols, y_q1, mask=mask)
    tl.store(y_s_ptr1, y_s1)
    tl.store(y_q_ptr2 + cols, y_q2, mask=mask)
    tl.store(y_s_ptr2, y_s2)
    tl.store(y_q_ptr3 + cols, y_q3, mask=mask)
    tl.store(y_s_ptr3, y_s3)
@triton.jit
def _per_token_group_quant_fp8_m8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize eight adjacent groups of the same row per program.

    Each group gets its own scale; scales are stored at the flat group index.
    The number of groups per row must be a multiple of eight, otherwise the
    trailing groups of each row are never visited.

    Args:
        y_ptr: Input pointer; rows are ``y_row_stride`` elements apart.
        y_q_ptr: Quantized output pointer, addressed as
            ``flat_group_index * group_size``.
        y_s_ptr: Scale output pointer, one value per flat group index.
        group_size: Number of elements sharing one scale.
        y_num_columns: Number of elements per input row.
        y_row_stride: Element stride between consecutive input rows.
        eps: Lower bound applied to each group's absolute maximum.
        fp8_min: Lower clamp bound for the scaled values.
        fp8_max: Upper clamp bound; also the divisor that derives the scale.
        scale_ue8m0: When true, round each scale up to a power of two.
        BLOCK: Compile-time lane count; lanes at or beyond ``group_size``
            are masked out.
    """
    groups_per_row = y_num_columns // group_size
    octets_per_row = groups_per_row // 8

    # program_id is 32-bit; widen it so the element offsets below cannot wrap
    # when the tensor holds 2**31 or more elements.
    pid = tl.program_id(0).to(tl.int64)
    row = pid // octets_per_row
    group0 = (pid % octets_per_row) * 8  # first group index within the row
    g0 = row * groups_per_row + group0  # flat index of that group

    # The eight groups are adjacent, so each pointer is one step past the last.
    y_ptr0 = y_ptr + row * y_row_stride + group0 * group_size
    y_ptr1 = y_ptr0 + group_size
    y_ptr2 = y_ptr1 + group_size
    y_ptr3 = y_ptr2 + group_size
    y_ptr4 = y_ptr3 + group_size
    y_ptr5 = y_ptr4 + group_size
    y_ptr6 = y_ptr5 + group_size
    y_ptr7 = y_ptr6 + group_size

    y_q_ptr0 = y_q_ptr + g0 * group_size
    y_q_ptr1 = y_q_ptr0 + group_size
    y_q_ptr2 = y_q_ptr1 + group_size
    y_q_ptr3 = y_q_ptr2 + group_size
    y_q_ptr4 = y_q_ptr3 + group_size
    y_q_ptr5 = y_q_ptr4 + group_size
    y_q_ptr6 = y_q_ptr5 + group_size
    y_q_ptr7 = y_q_ptr6 + group_size

    y_s_ptr0 = y_s_ptr + g0
    y_s_ptr1 = y_s_ptr0 + 1
    y_s_ptr2 = y_s_ptr1 + 1
    y_s_ptr3 = y_s_ptr2 + 1
    y_s_ptr4 = y_s_ptr3 + 1
    y_s_ptr5 = y_s_ptr4 + 1
    y_s_ptr6 = y_s_ptr5 + 1
    y_s_ptr7 = y_s_ptr6 + 1

    cols = tl.arange(0, BLOCK)
    mask = cols < group_size

    y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32)
    y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32)
    y2 = tl.load(y_ptr2 + cols, mask=mask, other=0.0).to(tl.float32)
    y3 = tl.load(y_ptr3 + cols, mask=mask, other=0.0).to(tl.float32)
    y4 = tl.load(y_ptr4 + cols, mask=mask, other=0.0).to(tl.float32)
    y5 = tl.load(y_ptr5 + cols, mask=mask, other=0.0).to(tl.float32)
    y6 = tl.load(y_ptr6 + cols, mask=mask, other=0.0).to(tl.float32)
    y7 = tl.load(y_ptr7 + cols, mask=mask, other=0.0).to(tl.float32)

    # eps keeps each scale strictly positive for an all-zero group.
    y_s0 = tl.maximum(tl.max(tl.abs(y0)), eps) / fp8_max
    y_s1 = tl.maximum(tl.max(tl.abs(y1)), eps) / fp8_max
    y_s2 = tl.maximum(tl.max(tl.abs(y2)), eps) / fp8_max
    y_s3 = tl.maximum(tl.max(tl.abs(y3)), eps) / fp8_max
    y_s4 = tl.maximum(tl.max(tl.abs(y4)), eps) / fp8_max
    y_s5 = tl.maximum(tl.max(tl.abs(y5)), eps) / fp8_max
    y_s6 = tl.maximum(tl.max(tl.abs(y6)), eps) / fp8_max
    y_s7 = tl.maximum(tl.max(tl.abs(y7)), eps) / fp8_max

    if scale_ue8m0:
        # Round each scale up to the next power of two.
        y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10))))
        y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10))))
        y_s2 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s2), 1e-10))))
        y_s3 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s3), 1e-10))))
        y_s4 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s4), 1e-10))))
        y_s5 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s5), 1e-10))))
        y_s6 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s6), 1e-10))))
        y_s7 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s7), 1e-10))))

    y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q2 = tl.clamp(y2 / y_s2, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q3 = tl.clamp(y3 / y_s3, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q4 = tl.clamp(y4 / y_s4, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q5 = tl.clamp(y5 / y_s5, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q6 = tl.clamp(y6 / y_s6, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q7 = tl.clamp(y7 / y_s7, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr0 + cols, y_q0, mask=mask)
    tl.store(y_s_ptr0, y_s0)
    tl.store(y_q_ptr1 + cols, y_q1, mask=mask)
    tl.store(y_s_ptr1, y_s1)
    tl.store(y_q_ptr2 + cols, y_q2, mask=mask)
    tl.store(y_s_ptr2, y_s2)
    tl.store(y_q_ptr3 + cols, y_q3, mask=mask)
    tl.store(y_s_ptr3, y_s3)
    tl.store(y_q_ptr4 + cols, y_q4, mask=mask)
    tl.store(y_s_ptr4, y_s4)
    tl.store(y_q_ptr5 + cols, y_q5, mask=mask)
    tl.store(y_s_ptr5, y_s5)
    tl.store(y_q_ptr6 + cols, y_q6, mask=mask)
    tl.store(y_s_ptr6, y_s6)
    tl.store(y_q_ptr7 + cols, y_q7, mask=mask)
    tl.store(y_s_ptr7, y_s7)
@triton.jit
def _per_token_group_quant_fp8_colmajor_m8(
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    y_num_columns,
    y_row_stride,
    y_s_col_stride,
    eps,
    fp8_min,
    fp8_max,
    scale_ue8m0,
    BLOCK: tl.constexpr,
):
    """Quantize eight adjacent groups per program with column-major scales.

    The scale of group ``g`` in row ``r`` is stored at
    ``g * y_s_col_stride + r``. The number of groups per row must be a
    multiple of eight, otherwise the trailing groups of each row are never
    visited.

    Args:
        y_ptr: Input pointer; rows are ``y_row_stride`` elements apart.
        y_q_ptr: Quantized output pointer, addressed as
            ``flat_group_index * group_size``.
        y_s_ptr: Scale output pointer.
        group_size: Number of elements sharing one scale.
        y_num_columns: Number of elements per input row.
        y_row_stride: Element stride between consecutive input rows.
        y_s_col_stride: Element stride between consecutive group columns of
            the scale tensor.
        eps: Lower bound applied to each group's absolute maximum.
        fp8_min: Lower clamp bound for the scaled values.
        fp8_max: Upper clamp bound; also the divisor that derives the scale.
        scale_ue8m0: When true, round each scale up to a power of two.
        BLOCK: Compile-time lane count; lanes at or beyond ``group_size``
            are masked out.
    """
    groups_per_row = y_num_columns // group_size
    octets_per_row = groups_per_row // 8

    # program_id is 32-bit; widen it so the element offsets below cannot wrap
    # when the tensor holds 2**31 or more elements.
    pid = tl.program_id(0).to(tl.int64)
    row = pid // octets_per_row
    group0 = (pid % octets_per_row) * 8  # first group index within the row
    g0 = row * groups_per_row + group0  # flat index of that group

    # The eight groups are adjacent, so each pointer is one step past the last.
    y_ptr0 = y_ptr + row * y_row_stride + group0 * group_size
    y_ptr1 = y_ptr0 + group_size
    y_ptr2 = y_ptr1 + group_size
    y_ptr3 = y_ptr2 + group_size
    y_ptr4 = y_ptr3 + group_size
    y_ptr5 = y_ptr4 + group_size
    y_ptr6 = y_ptr5 + group_size
    y_ptr7 = y_ptr6 + group_size

    y_q_ptr0 = y_q_ptr + g0 * group_size
    y_q_ptr1 = y_q_ptr0 + group_size
    y_q_ptr2 = y_q_ptr1 + group_size
    y_q_ptr3 = y_q_ptr2 + group_size
    y_q_ptr4 = y_q_ptr3 + group_size
    y_q_ptr5 = y_q_ptr4 + group_size
    y_q_ptr6 = y_q_ptr5 + group_size
    y_q_ptr7 = y_q_ptr6 + group_size

    # Scales of neighbouring groups are one column stride apart.
    y_s_ptr0 = y_s_ptr + group0 * y_s_col_stride + row
    y_s_ptr1 = y_s_ptr0 + y_s_col_stride
    y_s_ptr2 = y_s_ptr1 + y_s_col_stride
    y_s_ptr3 = y_s_ptr2 + y_s_col_stride
    y_s_ptr4 = y_s_ptr3 + y_s_col_stride
    y_s_ptr5 = y_s_ptr4 + y_s_col_stride
    y_s_ptr6 = y_s_ptr5 + y_s_col_stride
    y_s_ptr7 = y_s_ptr6 + y_s_col_stride

    cols = tl.arange(0, BLOCK)
    mask = cols < group_size

    y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32)
    y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32)
    y2 = tl.load(y_ptr2 + cols, mask=mask, other=0.0).to(tl.float32)
    y3 = tl.load(y_ptr3 + cols, mask=mask, other=0.0).to(tl.float32)
    y4 = tl.load(y_ptr4 + cols, mask=mask, other=0.0).to(tl.float32)
    y5 = tl.load(y_ptr5 + cols, mask=mask, other=0.0).to(tl.float32)
    y6 = tl.load(y_ptr6 + cols, mask=mask, other=0.0).to(tl.float32)
    y7 = tl.load(y_ptr7 + cols, mask=mask, other=0.0).to(tl.float32)

    # eps keeps each scale strictly positive for an all-zero group.
    y_s0 = tl.maximum(tl.max(tl.abs(y0)), eps) / fp8_max
    y_s1 = tl.maximum(tl.max(tl.abs(y1)), eps) / fp8_max
    y_s2 = tl.maximum(tl.max(tl.abs(y2)), eps) / fp8_max
    y_s3 = tl.maximum(tl.max(tl.abs(y3)), eps) / fp8_max
    y_s4 = tl.maximum(tl.max(tl.abs(y4)), eps) / fp8_max
    y_s5 = tl.maximum(tl.max(tl.abs(y5)), eps) / fp8_max
    y_s6 = tl.maximum(tl.max(tl.abs(y6)), eps) / fp8_max
    y_s7 = tl.maximum(tl.max(tl.abs(y7)), eps) / fp8_max

    if scale_ue8m0:
        # Round each scale up to the next power of two.
        y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10))))
        y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10))))
        y_s2 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s2), 1e-10))))
        y_s3 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s3), 1e-10))))
        y_s4 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s4), 1e-10))))
        y_s5 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s5), 1e-10))))
        y_s6 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s6), 1e-10))))
        y_s7 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s7), 1e-10))))

    y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q2 = tl.clamp(y2 / y_s2, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q3 = tl.clamp(y3 / y_s3, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q4 = tl.clamp(y4 / y_s4, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q5 = tl.clamp(y5 / y_s5, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q6 = tl.clamp(y6 / y_s6, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
    y_q7 = tl.clamp(y7 / y_s7, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr0 + cols, y_q0, mask=mask)
    tl.store(y_s_ptr0, y_s0)
    tl.store(y_q_ptr1 + cols, y_q1, mask=mask)
    tl.store(y_s_ptr1, y_s1)
    tl.store(y_q_ptr2 + cols, y_q2, mask=mask)
    tl.store(y_s_ptr2, y_s2)
    tl.store(y_q_ptr3 + cols, y_q3, mask=mask)
    tl.store(y_s_ptr3, y_s3)
    tl.store(y_q_ptr4 + cols, y_q4, mask=mask)
    tl.store(y_s_ptr4, y_s4)
    tl.store(y_q_ptr5 + cols, y_q5, mask=mask)
    tl.store(y_s_ptr5, y_s5)
    tl.store(y_q_ptr6 + cols, y_q6, mask=mask)
    tl.store(y_s_ptr6, y_s6)
    tl.store(y_q_ptr7 + cols, y_q7, mask=mask)
    tl.store(y_s_ptr7, y_s7)
def Groups_per_program(x, group_size) -> int:
    """Return how many groups a single kernel program should process.

    The answer is the largest of 8, 4 and 2 that evenly divides the number of
    groups in the last dimension of ``x`` (``x.shape[-1] // group_size``), or
    1 when none of them does.
    """
    group_count = x.shape[-1] // group_size
    for width in (8, 4, 2):
        if group_count % width == 0:
            return width
    return 1
def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` group-by-group along its last dimension.

    Every run of ``group_size`` consecutive elements in the last dimension
    shares one float32 scale, computed as ``max(absmax, eps) / finfo.max`` of
    the output dtype. The elements are divided by that scale, clamped to the
    output dtype's finite range and cast.

    Args:
        x: Input tensor. Its last dimension must be divisible by
            ``group_size`` and have stride 1. Leading dimensions are treated
            as flattened rows spaced ``x.stride(-2)`` elements apart.
            NOTE(review): the column-major scale layout is addressed with
            ``x_s.stride(1)`` and looks 2-D only - confirm before passing
            higher-rank inputs with ``column_major_scales=True``.
        group_size: Number of elements that share one scale.
        eps: Lower bound applied to each group's absolute maximum so the
            scale is never zero.
        dtype: Output dtype of the quantized tensor. Defaults to
            ``SUPPORTED_FP8_DTYPE``.
        column_major_scales: Allocate the scale tensor with the group axis
            outermost in memory and return it as a permuted view.
        scale_ue8m0: Round each scale up to a power of two.

    Returns:
        ``(x_q, x_s)``: the quantized tensor (same shape as ``x``) and the
        float32 scales (one per group).

    Raises:
        AssertionError: If the last dimension is not divisible by
            ``group_size`` or is not contiguous.
    """
    logger.debug("GEMS PER TOKEN GROUP QUANT FP8")
    fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(fp8_dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
    M = x.numel() // group_size  # total number of groups
    N = group_size

    if column_major_scales:
        # Allocate with the group axis first, then hand back a permuted view
        # so the logical shape still ends with the group axis.
        shape = (x.shape[-1] // group_size,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (x.shape[-1] // group_size,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    # Heuristic: one warp per 256 lanes, between 1 and 8 warps.
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    groups_per_program = Groups_per_program(x, group_size)

    # Pick the kernel variant by scale layout and groups handled per program.
    # The column-major kernels take one extra positional argument (the scale
    # column stride) right before ``eps``.
    if column_major_scales:
        kernels = {
            8: _per_token_group_quant_fp8_colmajor_m8,
            4: _per_token_group_quant_fp8_colmajor_m4,
            2: _per_token_group_quant_fp8_colmajor_m2,
            1: _per_token_group_quant_fp8_colmajor,
        }
        layout_args = (x_s.stride(1),)
    else:
        kernels = {
            8: _per_token_group_quant_fp8_m8,
            4: _per_token_group_quant_fp8_m4,
            2: _per_token_group_quant_fp8_m2,
            1: _per_token_group_quant_fp8,
        }
        layout_args = ()

    kernel = kernels[groups_per_program]
    # Each program covers ``groups_per_program`` groups, so the grid shrinks
    # by the same factor. The row length and row stride use the trailing
    # dimensions so they agree with the shape checks above.
    kernel[(M // groups_per_program,)](
        x,
        x_q,
        x_s,
        group_size,
        x.shape[-1],
        x.stride(-2),
        *layout_args,
        eps,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        scale_ue8m0=scale_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s