Coverage for src/flag_gems/ops/per_token_group_quant_fp8.py: 12%

523 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils.device_info import get_device_capability 

10 

11if torch_device_fn.is_available() and get_device_capability() >= (9, 0): 

12 SUPPORTED_FP8_DTYPE = torch.float8_e4m3fn 

13else: 

14 SUPPORTED_FP8_DTYPE = torch.float32 

15 

16 

17logger = logging.getLogger(__name__) 

18 

19 

20@triton.jit 

21def _per_token_group_quant_fp8( 

22 y_ptr, 

23 y_q_ptr, 

24 y_s_ptr, 

25 group_size, 

26 y_num_columns, 

27 y_row_stride, 

28 eps, 

29 fp8_min, 

30 fp8_max, 

31 scale_ue8m0, 

32 BLOCK: tl.constexpr, 

33): 

34 groups_per_row = y_num_columns // group_size 

35 

36 g_id = tl.program_id(0) 

37 row = g_id // groups_per_row 

38 row_g_id = g_id % groups_per_row 

39 

40 y_ptr += (row * y_row_stride) + (row_g_id * group_size) 

41 y_q_ptr += g_id * group_size 

42 y_s_ptr += g_id 

43 

44 cols = tl.arange(0, BLOCK) 

45 mask = cols < group_size 

46 

47 y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) 

48 _absmax = tl.maximum(tl.max(tl.abs(y)), eps) 

49 y_s = _absmax / fp8_max 

50 

51 if scale_ue8m0: 

52 y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10)))) 

53 

54 y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

55 

56 tl.store(y_q_ptr + cols, y_q, mask=mask) 

57 tl.store(y_s_ptr, y_s) 

58 

59 

60@triton.jit 

61def _per_token_group_quant_fp8_colmajor( 

62 y_ptr, 

63 y_q_ptr, 

64 y_s_ptr, 

65 group_size, 

66 y_num_columns, 

67 y_row_stride, 

68 y_s_col_stride, 

69 eps, 

70 fp8_min, 

71 fp8_max, 

72 scale_ue8m0, 

73 BLOCK: tl.constexpr, 

74): 

75 groups_per_row = y_num_columns // group_size 

76 

77 g_id = tl.program_id(0) 

78 row = g_id // groups_per_row 

79 group_id = g_id % groups_per_row 

80 

81 y_ptr += row * y_row_stride + group_id * group_size 

82 y_q_ptr += g_id * group_size 

83 y_s_ptr += group_id * y_s_col_stride + row 

84 

85 cols = tl.arange(0, BLOCK) 

86 mask = cols < group_size 

87 

88 y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) 

89 _absmax = tl.maximum(tl.max(tl.abs(y)), eps) 

90 y_s = _absmax / fp8_max 

91 

92 if scale_ue8m0: 

93 y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10)))) 

94 

95 y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

96 

97 tl.store(y_q_ptr + cols, y_q, mask=mask) 

98 tl.store(y_s_ptr, y_s) 

99 

100 

101@triton.jit 

102def _per_token_group_quant_fp8_m2( 

103 y_ptr, 

104 y_q_ptr, 

105 y_s_ptr, 

106 group_size, 

107 y_num_columns, 

108 y_row_stride, 

109 eps, 

110 fp8_min, 

111 fp8_max, 

112 scale_ue8m0, 

113 BLOCK: tl.constexpr, 

114): 

115 groups_per_row = y_num_columns // group_size 

116 pid = tl.program_id(0) 

117 pairs_per_row = groups_per_row // 2 

118 row = pid // pairs_per_row 

119 pair_id = pid % pairs_per_row 

120 

121 group0 = pair_id * 2 

122 group1 = group0 + 1 

123 

124 g0 = row * groups_per_row + group0 

125 g1 = g0 + 1 

126 

127 base = y_ptr + row * y_row_stride 

128 

129 y_ptr0 = base + group0 * group_size 

130 y_ptr1 = base + group1 * group_size 

131 

132 y_q_ptr0 = y_q_ptr + g0 * group_size 

133 y_q_ptr1 = y_q_ptr + g1 * group_size 

134 

135 y_s_ptr0 = y_s_ptr + g0 

136 y_s_ptr1 = y_s_ptr + g1 

137 

138 cols = tl.arange(0, BLOCK) 

139 mask = cols < group_size 

140 

141 y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32) 

142 y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32) 

143 

144 abs0 = tl.abs(y0) 

145 abs1 = tl.abs(y1) 

146 

147 max0 = tl.max(abs0) 

148 max1 = tl.max(abs1) 

149 

150 y_s0 = tl.maximum(max0, eps) / fp8_max 

151 y_s1 = tl.maximum(max1, eps) / fp8_max 

152 

153 if scale_ue8m0: 

154 y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10)))) 

155 y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10)))) 

156 

157 y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

158 y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

159 

160 tl.store(y_q_ptr0 + cols, y_q0, mask=mask) 

161 tl.store(y_s_ptr0, y_s0) 

162 tl.store(y_q_ptr1 + cols, y_q1, mask=mask) 

163 tl.store(y_s_ptr1, y_s1) 

164 

165 

166@triton.jit 

167def _per_token_group_quant_fp8_colmajor_m2( 

168 y_ptr, 

169 y_q_ptr, 

170 y_s_ptr, 

171 group_size, 

172 y_num_columns, 

173 y_row_stride, 

174 y_s_col_stride, 

175 eps, 

176 fp8_min, 

177 fp8_max, 

178 scale_ue8m0, 

179 BLOCK: tl.constexpr, 

180): 

181 groups_per_row = y_num_columns // group_size 

182 pid = tl.program_id(0) 

183 pairs_per_row = groups_per_row // 2 

184 row = pid // pairs_per_row 

185 pair_id = pid % pairs_per_row 

186 

187 group0 = pair_id * 2 

188 group1 = group0 + 1 

189 

190 g0 = row * groups_per_row + group0 

191 g1 = g0 + 1 

192 

193 base = y_ptr + row * y_row_stride 

194 

195 y_ptr0 = base + group0 * group_size 

196 y_ptr1 = base + group1 * group_size 

197 

198 y_q_ptr0 = y_q_ptr + g0 * group_size 

199 y_q_ptr1 = y_q_ptr + g1 * group_size 

200 

201 y_s_ptr0 = y_s_ptr + group0 * y_s_col_stride + row 

202 y_s_ptr1 = y_s_ptr + group1 * y_s_col_stride + row 

203 

204 cols = tl.arange(0, BLOCK) 

205 mask = cols < group_size 

206 

207 y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32) 

208 y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32) 

209 

210 abs0 = tl.abs(y0) 

211 abs1 = tl.abs(y1) 

212 

213 max0 = tl.max(abs0) 

214 max1 = tl.max(abs1) 

215 

216 y_s0 = tl.maximum(max0, eps) / fp8_max 

217 y_s1 = tl.maximum(max1, eps) / fp8_max 

218 

219 if scale_ue8m0: 

220 y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10)))) 

221 y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10)))) 

222 

223 y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

224 y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

225 

226 tl.store(y_q_ptr0 + cols, y_q0, mask=mask) 

227 tl.store(y_s_ptr0, y_s0) 

228 tl.store(y_q_ptr1 + cols, y_q1, mask=mask) 

229 tl.store(y_s_ptr1, y_s1) 

230 

231 

232@triton.jit 

233def _per_token_group_quant_fp8_m4( 

234 y_ptr, 

235 y_q_ptr, 

236 y_s_ptr, 

237 group_size, 

238 y_num_columns, 

239 y_row_stride, 

240 eps, 

241 fp8_min, 

242 fp8_max, 

243 scale_ue8m0, 

244 BLOCK: tl.constexpr, 

245): 

246 groups_per_row = y_num_columns // group_size 

247 pid = tl.program_id(0) 

248 pairs_per_row = groups_per_row // 4 

249 row = pid // pairs_per_row 

250 pair_id = pid % pairs_per_row 

251 

252 group0 = pair_id * 4 

253 group1 = group0 + 1 

254 group2 = group0 + 2 

255 group3 = group0 + 3 

256 

257 g0 = row * groups_per_row + group0 

258 g1 = g0 + 1 

259 g2 = g1 + 1 

260 g3 = g2 + 1 

261 

262 base = y_ptr + row * y_row_stride 

263 

264 y_ptr0 = base + group0 * group_size 

265 y_ptr1 = base + group1 * group_size 

266 y_ptr2 = base + group2 * group_size 

267 y_ptr3 = base + group3 * group_size 

268 

269 y_q_ptr0 = y_q_ptr + g0 * group_size 

270 y_q_ptr1 = y_q_ptr + g1 * group_size 

271 y_q_ptr2 = y_q_ptr + g2 * group_size 

272 y_q_ptr3 = y_q_ptr + g3 * group_size 

273 

274 y_s_ptr0 = y_s_ptr + g0 

275 y_s_ptr1 = y_s_ptr + g1 

276 y_s_ptr2 = y_s_ptr + g2 

277 y_s_ptr3 = y_s_ptr + g3 

278 

279 cols = tl.arange(0, BLOCK) 

280 mask = cols < group_size 

281 

282 y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32) 

283 y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32) 

284 y2 = tl.load(y_ptr2 + cols, mask=mask, other=0.0).to(tl.float32) 

285 y3 = tl.load(y_ptr3 + cols, mask=mask, other=0.0).to(tl.float32) 

286 

287 abs0 = tl.abs(y0) 

288 abs1 = tl.abs(y1) 

289 abs2 = tl.abs(y2) 

290 abs3 = tl.abs(y3) 

291 

292 max0 = tl.max(abs0) 

293 max1 = tl.max(abs1) 

294 max2 = tl.max(abs2) 

295 max3 = tl.max(abs3) 

296 

297 y_s0 = tl.maximum(max0, eps) / fp8_max 

298 y_s1 = tl.maximum(max1, eps) / fp8_max 

299 y_s2 = tl.maximum(max2, eps) / fp8_max 

300 y_s3 = tl.maximum(max3, eps) / fp8_max 

301 

302 if scale_ue8m0: 

303 y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10)))) 

304 y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10)))) 

305 y_s2 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s2), 1e-10)))) 

306 y_s3 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s3), 1e-10)))) 

307 

308 y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

309 y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

310 y_q2 = tl.clamp(y2 / y_s2, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

311 y_q3 = tl.clamp(y3 / y_s3, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

312 

313 tl.store(y_q_ptr0 + cols, y_q0, mask=mask) 

314 tl.store(y_s_ptr0, y_s0) 

315 tl.store(y_q_ptr1 + cols, y_q1, mask=mask) 

316 tl.store(y_s_ptr1, y_s1) 

317 tl.store(y_q_ptr2 + cols, y_q2, mask=mask) 

318 tl.store(y_s_ptr2, y_s2) 

319 tl.store(y_q_ptr3 + cols, y_q3, mask=mask) 

320 tl.store(y_s_ptr3, y_s3) 

321 

322 

323@triton.jit 

324def _per_token_group_quant_fp8_colmajor_m4( 

325 y_ptr, 

326 y_q_ptr, 

327 y_s_ptr, 

328 group_size, 

329 y_num_columns, 

330 y_row_stride, 

331 y_s_col_stride, 

332 eps, 

333 fp8_min, 

334 fp8_max, 

335 scale_ue8m0, 

336 BLOCK: tl.constexpr, 

337): 

338 groups_per_row = y_num_columns // group_size 

339 pid = tl.program_id(0) 

340 pairs_per_row = groups_per_row // 4 

341 row = pid // pairs_per_row 

342 pair_id = pid % pairs_per_row 

343 

344 group0 = pair_id * 4 

345 group1 = group0 + 1 

346 group2 = group1 + 1 

347 group3 = group2 + 1 

348 

349 g0 = row * groups_per_row + group0 

350 g1 = g0 + 1 

351 g2 = g1 + 1 

352 g3 = g2 + 1 

353 

354 base = y_ptr + row * y_row_stride 

355 

356 y_ptr0 = base + group0 * group_size 

357 y_ptr1 = base + group1 * group_size 

358 y_ptr2 = base + group2 * group_size 

359 y_ptr3 = base + group3 * group_size 

360 

361 y_q_ptr0 = y_q_ptr + g0 * group_size 

362 y_q_ptr1 = y_q_ptr + g1 * group_size 

363 y_q_ptr2 = y_q_ptr + g2 * group_size 

364 y_q_ptr3 = y_q_ptr + g3 * group_size 

365 

366 y_s_ptr0 = y_s_ptr + group0 * y_s_col_stride + row 

367 y_s_ptr1 = y_s_ptr + group1 * y_s_col_stride + row 

368 y_s_ptr2 = y_s_ptr + group2 * y_s_col_stride + row 

369 y_s_ptr3 = y_s_ptr + group3 * y_s_col_stride + row 

370 

371 cols = tl.arange(0, BLOCK) 

372 mask = cols < group_size 

373 

374 y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32) 

375 y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32) 

376 y2 = tl.load(y_ptr2 + cols, mask=mask, other=0.0).to(tl.float32) 

377 y3 = tl.load(y_ptr3 + cols, mask=mask, other=0.0).to(tl.float32) 

378 

379 abs0 = tl.abs(y0) 

380 abs1 = tl.abs(y1) 

381 abs2 = tl.abs(y2) 

382 abs3 = tl.abs(y3) 

383 

384 max0 = tl.max(abs0) 

385 max1 = tl.max(abs1) 

386 max2 = tl.max(abs2) 

387 max3 = tl.max(abs3) 

388 

389 y_s0 = tl.maximum(max0, eps) / fp8_max 

390 y_s1 = tl.maximum(max1, eps) / fp8_max 

391 y_s2 = tl.maximum(max2, eps) / fp8_max 

392 y_s3 = tl.maximum(max3, eps) / fp8_max 

393 

394 if scale_ue8m0: 

395 y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10)))) 

396 y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10)))) 

397 y_s2 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s2), 1e-10)))) 

398 y_s3 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s3), 1e-10)))) 

399 

400 y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

401 y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

402 y_q2 = tl.clamp(y2 / y_s2, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

403 y_q3 = tl.clamp(y3 / y_s3, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

404 

405 tl.store(y_q_ptr0 + cols, y_q0, mask=mask) 

406 tl.store(y_s_ptr0, y_s0) 

407 tl.store(y_q_ptr1 + cols, y_q1, mask=mask) 

408 tl.store(y_s_ptr1, y_s1) 

409 tl.store(y_q_ptr2 + cols, y_q2, mask=mask) 

410 tl.store(y_s_ptr2, y_s2) 

411 tl.store(y_q_ptr3 + cols, y_q3, mask=mask) 

412 tl.store(y_s_ptr3, y_s3) 

413 

414 

415@triton.jit 

416def _per_token_group_quant_fp8_m8( 

417 y_ptr, 

418 y_q_ptr, 

419 y_s_ptr, 

420 group_size, 

421 y_num_columns, 

422 y_row_stride, 

423 eps, 

424 fp8_min, 

425 fp8_max, 

426 scale_ue8m0, 

427 BLOCK: tl.constexpr, 

428): 

429 groups_per_row = y_num_columns // group_size 

430 pid = tl.program_id(0) 

431 pairs_per_row = groups_per_row // 8 

432 row = pid // pairs_per_row 

433 pair_id = pid % pairs_per_row 

434 

435 group0 = pair_id * 8 

436 group1 = group0 + 1 

437 group2 = group0 + 2 

438 group3 = group0 + 3 

439 group4 = group0 + 4 

440 group5 = group0 + 5 

441 group6 = group0 + 6 

442 group7 = group0 + 7 

443 

444 g0 = row * groups_per_row + group0 

445 g1 = g0 + 1 

446 g2 = g1 + 1 

447 g3 = g2 + 1 

448 g4 = g3 + 1 

449 g5 = g4 + 1 

450 g6 = g5 + 1 

451 g7 = g6 + 1 

452 

453 base = y_ptr + row * y_row_stride 

454 

455 y_ptr0 = base + group0 * group_size 

456 y_ptr1 = base + group1 * group_size 

457 y_ptr2 = base + group2 * group_size 

458 y_ptr3 = base + group3 * group_size 

459 y_ptr4 = base + group4 * group_size 

460 y_ptr5 = base + group5 * group_size 

461 y_ptr6 = base + group6 * group_size 

462 y_ptr7 = base + group7 * group_size 

463 

464 y_q_ptr0 = y_q_ptr + g0 * group_size 

465 y_q_ptr1 = y_q_ptr + g1 * group_size 

466 y_q_ptr2 = y_q_ptr + g2 * group_size 

467 y_q_ptr3 = y_q_ptr + g3 * group_size 

468 y_q_ptr4 = y_q_ptr + g4 * group_size 

469 y_q_ptr5 = y_q_ptr + g5 * group_size 

470 y_q_ptr6 = y_q_ptr + g6 * group_size 

471 y_q_ptr7 = y_q_ptr + g7 * group_size 

472 

473 y_s_ptr0 = y_s_ptr + g0 

474 y_s_ptr1 = y_s_ptr + g1 

475 y_s_ptr2 = y_s_ptr + g2 

476 y_s_ptr3 = y_s_ptr + g3 

477 y_s_ptr4 = y_s_ptr + g4 

478 y_s_ptr5 = y_s_ptr + g5 

479 y_s_ptr6 = y_s_ptr + g6 

480 y_s_ptr7 = y_s_ptr + g7 

481 

482 cols = tl.arange(0, BLOCK) 

483 mask = cols < group_size 

484 

485 y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32) 

486 y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32) 

487 y2 = tl.load(y_ptr2 + cols, mask=mask, other=0.0).to(tl.float32) 

488 y3 = tl.load(y_ptr3 + cols, mask=mask, other=0.0).to(tl.float32) 

489 y4 = tl.load(y_ptr4 + cols, mask=mask, other=0.0).to(tl.float32) 

490 y5 = tl.load(y_ptr5 + cols, mask=mask, other=0.0).to(tl.float32) 

491 y6 = tl.load(y_ptr6 + cols, mask=mask, other=0.0).to(tl.float32) 

492 y7 = tl.load(y_ptr7 + cols, mask=mask, other=0.0).to(tl.float32) 

493 

494 abs0 = tl.abs(y0) 

495 abs1 = tl.abs(y1) 

496 abs2 = tl.abs(y2) 

497 abs3 = tl.abs(y3) 

498 abs4 = tl.abs(y4) 

499 abs5 = tl.abs(y5) 

500 abs6 = tl.abs(y6) 

501 abs7 = tl.abs(y7) 

502 

503 max0 = tl.max(abs0) 

504 max1 = tl.max(abs1) 

505 max2 = tl.max(abs2) 

506 max3 = tl.max(abs3) 

507 max4 = tl.max(abs4) 

508 max5 = tl.max(abs5) 

509 max6 = tl.max(abs6) 

510 max7 = tl.max(abs7) 

511 y_s0 = tl.maximum(max0, eps) / fp8_max 

512 y_s1 = tl.maximum(max1, eps) / fp8_max 

513 y_s2 = tl.maximum(max2, eps) / fp8_max 

514 y_s3 = tl.maximum(max3, eps) / fp8_max 

515 y_s4 = tl.maximum(max4, eps) / fp8_max 

516 y_s5 = tl.maximum(max5, eps) / fp8_max 

517 y_s6 = tl.maximum(max6, eps) / fp8_max 

518 y_s7 = tl.maximum(max7, eps) / fp8_max 

519 

520 if scale_ue8m0: 

521 y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10)))) 

522 y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10)))) 

523 y_s2 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s2), 1e-10)))) 

524 y_s3 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s3), 1e-10)))) 

525 y_s4 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s4), 1e-10)))) 

526 y_s5 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s5), 1e-10)))) 

527 y_s6 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s6), 1e-10)))) 

528 y_s7 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s7), 1e-10)))) 

529 

530 y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

531 y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

532 y_q2 = tl.clamp(y2 / y_s2, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

533 y_q3 = tl.clamp(y3 / y_s3, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

534 y_q4 = tl.clamp(y4 / y_s4, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

535 y_q5 = tl.clamp(y5 / y_s5, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

536 y_q6 = tl.clamp(y6 / y_s6, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

537 y_q7 = tl.clamp(y7 / y_s7, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

538 

539 tl.store(y_q_ptr0 + cols, y_q0, mask=mask) 

540 tl.store(y_s_ptr0, y_s0) 

541 tl.store(y_q_ptr1 + cols, y_q1, mask=mask) 

542 tl.store(y_s_ptr1, y_s1) 

543 tl.store(y_q_ptr2 + cols, y_q2, mask=mask) 

544 tl.store(y_s_ptr2, y_s2) 

545 tl.store(y_q_ptr3 + cols, y_q3, mask=mask) 

546 tl.store(y_s_ptr3, y_s3) 

547 tl.store(y_q_ptr4 + cols, y_q4, mask=mask) 

548 tl.store(y_s_ptr4, y_s4) 

549 tl.store(y_q_ptr5 + cols, y_q5, mask=mask) 

550 tl.store(y_s_ptr5, y_s5) 

551 tl.store(y_q_ptr6 + cols, y_q6, mask=mask) 

552 tl.store(y_s_ptr6, y_s6) 

553 tl.store(y_q_ptr7 + cols, y_q7, mask=mask) 

554 tl.store(y_s_ptr7, y_s7) 

555 

556 

557@triton.jit 

558def _per_token_group_quant_fp8_colmajor_m8( 

559 y_ptr, 

560 y_q_ptr, 

561 y_s_ptr, 

562 group_size, 

563 y_num_columns, 

564 y_row_stride, 

565 y_s_col_stride, 

566 eps, 

567 fp8_min, 

568 fp8_max, 

569 scale_ue8m0, 

570 BLOCK: tl.constexpr, 

571): 

572 groups_per_row = y_num_columns // group_size 

573 pid = tl.program_id(0) 

574 pairs_per_row = groups_per_row // 8 

575 row = pid // pairs_per_row 

576 pair_id = pid % pairs_per_row 

577 

578 group0 = pair_id * 8 

579 group1 = group0 + 1 

580 group2 = group1 + 1 

581 group3 = group2 + 1 

582 group4 = group3 + 1 

583 group5 = group4 + 1 

584 group6 = group5 + 1 

585 group7 = group6 + 1 

586 

587 g0 = row * groups_per_row + group0 

588 g1 = g0 + 1 

589 g2 = g1 + 1 

590 g3 = g2 + 1 

591 g4 = g3 + 1 

592 g5 = g4 + 1 

593 g6 = g5 + 1 

594 g7 = g6 + 1 

595 

596 base = y_ptr + row * y_row_stride 

597 

598 y_ptr0 = base + group0 * group_size 

599 y_ptr1 = base + group1 * group_size 

600 y_ptr2 = base + group2 * group_size 

601 y_ptr3 = base + group3 * group_size 

602 y_ptr4 = base + group4 * group_size 

603 y_ptr5 = base + group5 * group_size 

604 y_ptr6 = base + group6 * group_size 

605 y_ptr7 = base + group7 * group_size 

606 

607 y_q_ptr0 = y_q_ptr + g0 * group_size 

608 y_q_ptr1 = y_q_ptr + g1 * group_size 

609 y_q_ptr2 = y_q_ptr + g2 * group_size 

610 y_q_ptr3 = y_q_ptr + g3 * group_size 

611 y_q_ptr4 = y_q_ptr + g4 * group_size 

612 y_q_ptr5 = y_q_ptr + g5 * group_size 

613 y_q_ptr6 = y_q_ptr + g6 * group_size 

614 y_q_ptr7 = y_q_ptr + g7 * group_size 

615 

616 y_s_ptr0 = y_s_ptr + group0 * y_s_col_stride + row 

617 y_s_ptr1 = y_s_ptr + group1 * y_s_col_stride + row 

618 y_s_ptr2 = y_s_ptr + group2 * y_s_col_stride + row 

619 y_s_ptr3 = y_s_ptr + group3 * y_s_col_stride + row 

620 y_s_ptr4 = y_s_ptr + group4 * y_s_col_stride + row 

621 y_s_ptr5 = y_s_ptr + group5 * y_s_col_stride + row 

622 y_s_ptr6 = y_s_ptr + group6 * y_s_col_stride + row 

623 y_s_ptr7 = y_s_ptr + group7 * y_s_col_stride + row 

624 

625 cols = tl.arange(0, BLOCK) 

626 mask = cols < group_size 

627 

628 y0 = tl.load(y_ptr0 + cols, mask=mask, other=0.0).to(tl.float32) 

629 y1 = tl.load(y_ptr1 + cols, mask=mask, other=0.0).to(tl.float32) 

630 y2 = tl.load(y_ptr2 + cols, mask=mask, other=0.0).to(tl.float32) 

631 y3 = tl.load(y_ptr3 + cols, mask=mask, other=0.0).to(tl.float32) 

632 y4 = tl.load(y_ptr4 + cols, mask=mask, other=0.0).to(tl.float32) 

633 y5 = tl.load(y_ptr5 + cols, mask=mask, other=0.0).to(tl.float32) 

634 y6 = tl.load(y_ptr6 + cols, mask=mask, other=0.0).to(tl.float32) 

635 y7 = tl.load(y_ptr7 + cols, mask=mask, other=0.0).to(tl.float32) 

636 

637 abs0 = tl.abs(y0) 

638 abs1 = tl.abs(y1) 

639 abs2 = tl.abs(y2) 

640 abs3 = tl.abs(y3) 

641 abs4 = tl.abs(y4) 

642 abs5 = tl.abs(y5) 

643 abs6 = tl.abs(y6) 

644 abs7 = tl.abs(y7) 

645 

646 max0 = tl.max(abs0) 

647 max1 = tl.max(abs1) 

648 max2 = tl.max(abs2) 

649 max3 = tl.max(abs3) 

650 max4 = tl.max(abs4) 

651 max5 = tl.max(abs5) 

652 max6 = tl.max(abs6) 

653 max7 = tl.max(abs7) 

654 

655 y_s0 = tl.maximum(max0, eps) / fp8_max 

656 y_s1 = tl.maximum(max1, eps) / fp8_max 

657 y_s2 = tl.maximum(max2, eps) / fp8_max 

658 y_s3 = tl.maximum(max3, eps) / fp8_max 

659 y_s4 = tl.maximum(max4, eps) / fp8_max 

660 y_s5 = tl.maximum(max5, eps) / fp8_max 

661 y_s6 = tl.maximum(max6, eps) / fp8_max 

662 y_s7 = tl.maximum(max7, eps) / fp8_max 

663 

664 if scale_ue8m0: 

665 y_s0 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s0), 1e-10)))) 

666 y_s1 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s1), 1e-10)))) 

667 y_s2 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s2), 1e-10)))) 

668 y_s3 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s3), 1e-10)))) 

669 y_s4 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s4), 1e-10)))) 

670 y_s5 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s5), 1e-10)))) 

671 y_s6 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s6), 1e-10)))) 

672 y_s7 = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s7), 1e-10)))) 

673 

674 y_q0 = tl.clamp(y0 / y_s0, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

675 y_q1 = tl.clamp(y1 / y_s1, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

676 y_q2 = tl.clamp(y2 / y_s2, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

677 y_q3 = tl.clamp(y3 / y_s3, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

678 y_q4 = tl.clamp(y4 / y_s4, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

679 y_q5 = tl.clamp(y5 / y_s5, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

680 y_q6 = tl.clamp(y6 / y_s6, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

681 y_q7 = tl.clamp(y7 / y_s7, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) 

682 

683 tl.store(y_q_ptr0 + cols, y_q0, mask=mask) 

684 tl.store(y_s_ptr0, y_s0) 

685 tl.store(y_q_ptr1 + cols, y_q1, mask=mask) 

686 tl.store(y_s_ptr1, y_s1) 

687 tl.store(y_q_ptr2 + cols, y_q2, mask=mask) 

688 tl.store(y_s_ptr2, y_s2) 

689 tl.store(y_q_ptr3 + cols, y_q3, mask=mask) 

690 tl.store(y_s_ptr3, y_s3) 

691 tl.store(y_q_ptr4 + cols, y_q4, mask=mask) 

692 tl.store(y_s_ptr4, y_s4) 

693 tl.store(y_q_ptr5 + cols, y_q5, mask=mask) 

694 tl.store(y_s_ptr5, y_s5) 

695 tl.store(y_q_ptr6 + cols, y_q6, mask=mask) 

696 tl.store(y_s_ptr6, y_s6) 

697 tl.store(y_q_ptr7 + cols, y_q7, mask=mask) 

698 tl.store(y_s_ptr7, y_s7) 

699 

700 

701def Groups_per_program(x, group_size) -> int: 

702 if (x.shape[-1] // group_size) % 8 == 0: 

703 return 8 

704 elif (x.shape[-1] // group_size) % 4 == 0: 

705 return 4 

706 elif (x.shape[-1] // group_size) % 2 == 0: 

707 return 2 

708 else: 

709 return 1 

710 

711 

712def per_token_group_quant_fp8( 

713 x: torch.Tensor, 

714 group_size: int, 

715 eps: float = 1e-10, 

716 dtype: Optional[torch.dtype] = None, 

717 column_major_scales: bool = False, 

718 scale_ue8m0: bool = False, 

719) -> Tuple[torch.Tensor, torch.Tensor]: 

720 logger.debug("GEMS PER TOKEN GROUP QUANT FP8") 

721 # dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` 

722 fp8_dtype = SUPPORTED_FP8_DTYPE if dtype is None else dtype 

723 assert x.shape[-1] % group_size == 0, ( 

724 f"the last dimension of `x` {x.shape[-1]} must be divisible " 

725 f"by `group_size` {group_size}" 

726 ) 

727 assert x.stride(-1) == 1, "`x` groups must be contiguous" 

728 

729 finfo = torch.finfo(fp8_dtype) 

730 fp8_min = finfo.min 

731 fp8_max = finfo.max 

732 

733 x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) 

734 M = x.numel() // group_size 

735 N = group_size 

736 

737 if column_major_scales: 

738 shape = (x.shape[-1] // group_size,) + x.shape[:-1] 

739 x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) 

740 else: 

741 shape = x.shape[:-1] + (x.shape[-1] // group_size,) 

742 x_s = torch.empty(shape, device=x.device, dtype=torch.float32) 

743 

744 BLOCK = triton.next_power_of_2(N) 

745 num_warps = min(max(BLOCK // 256, 1), 8) 

746 num_stages = 1 

747 groups_per_program = Groups_per_program(x, group_size) 

748 if column_major_scales: 

749 if groups_per_program == 8: 

750 _per_token_group_quant_fp8_colmajor_m8[(M // 8,)]( 

751 x, 

752 x_q, 

753 x_s, 

754 group_size, 

755 x.shape[1], 

756 x.stride(0), 

757 x_s.stride(1), 

758 eps, 

759 fp8_min=fp8_min, 

760 fp8_max=fp8_max, 

761 scale_ue8m0=scale_ue8m0, 

762 BLOCK=BLOCK, 

763 num_warps=num_warps, 

764 num_stages=num_stages, 

765 ) 

766 elif groups_per_program == 4: 

767 _per_token_group_quant_fp8_colmajor_m4[(M // 4,)]( 

768 x, 

769 x_q, 

770 x_s, 

771 group_size, 

772 x.shape[1], 

773 x.stride(0), 

774 x_s.stride(1), 

775 eps, 

776 fp8_min=fp8_min, 

777 fp8_max=fp8_max, 

778 scale_ue8m0=scale_ue8m0, 

779 BLOCK=BLOCK, 

780 num_warps=num_warps, 

781 num_stages=num_stages, 

782 ) 

783 elif groups_per_program == 2: 

784 _per_token_group_quant_fp8_colmajor_m2[(M // 2,)]( 

785 x, 

786 x_q, 

787 x_s, 

788 group_size, 

789 x.shape[1], 

790 x.stride(0), 

791 x_s.stride(1), 

792 eps, 

793 fp8_min=fp8_min, 

794 fp8_max=fp8_max, 

795 scale_ue8m0=scale_ue8m0, 

796 BLOCK=BLOCK, 

797 num_warps=num_warps, 

798 num_stages=num_stages, 

799 ) 

800 else: 

801 _per_token_group_quant_fp8_colmajor[(M,)]( 

802 x, 

803 x_q, 

804 x_s, 

805 group_size, 

806 x.shape[1], 

807 x.stride(0), 

808 x_s.stride(1), 

809 eps, 

810 fp8_min=fp8_min, 

811 fp8_max=fp8_max, 

812 scale_ue8m0=scale_ue8m0, 

813 BLOCK=BLOCK, 

814 num_warps=num_warps, 

815 num_stages=num_stages, 

816 ) 

817 else: 

818 if groups_per_program == 8: 

819 _per_token_group_quant_fp8_m8[(M // 8,)]( 

820 x, 

821 x_q, 

822 x_s, 

823 group_size, 

824 x.shape[1], 

825 x.stride(0), 

826 eps, 

827 fp8_min=fp8_min, 

828 fp8_max=fp8_max, 

829 scale_ue8m0=scale_ue8m0, 

830 BLOCK=BLOCK, 

831 num_warps=num_warps, 

832 num_stages=num_stages, 

833 ) 

834 elif groups_per_program == 4: 

835 _per_token_group_quant_fp8_m4[(M // 4,)]( 

836 x, 

837 x_q, 

838 x_s, 

839 group_size, 

840 x.shape[1], 

841 x.stride(0), 

842 eps, 

843 fp8_min=fp8_min, 

844 fp8_max=fp8_max, 

845 scale_ue8m0=scale_ue8m0, 

846 BLOCK=BLOCK, 

847 num_warps=num_warps, 

848 num_stages=num_stages, 

849 ) 

850 elif groups_per_program == 2: 

851 _per_token_group_quant_fp8_m2[(M // 2,)]( 

852 x, 

853 x_q, 

854 x_s, 

855 group_size, 

856 x.shape[1], 

857 x.stride(0), 

858 eps, 

859 fp8_min=fp8_min, 

860 fp8_max=fp8_max, 

861 scale_ue8m0=scale_ue8m0, 

862 BLOCK=BLOCK, 

863 num_warps=num_warps, 

864 num_stages=num_stages, 

865 ) 

866 else: 

867 _per_token_group_quant_fp8[(M,)]( 

868 x, 

869 x_q, 

870 x_s, 

871 group_size, 

872 x.shape[1], 

873 x.stride(0), 

874 eps, 

875 fp8_min=fp8_min, 

876 fp8_max=fp8_max, 

877 scale_ue8m0=scale_ue8m0, 

878 BLOCK=BLOCK, 

879 num_warps=num_warps, 

880 num_stages=num_stages, 

881 ) 

882 

883 return x_q, x_s