Coverage for src/flag_gems/runtime/backend/_tsingmicro/fused/cross_entropy_loss.py: 0%

506 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch.nn import _reduction as _Reduction 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11 

# Module-level logger used for kernel-dispatch debug messages.
logger = logging.getLogger(__name__)

# Number of compute cores (multiprocessors) on the current device.  Used to
# size 1-D launch grids so each program handles a contiguous slice of rows.
TOTAL_CORE_NUM = torch_device_fn.get_device_properties().multi_processor_count

15 

16 

# Softmax statistics pass: for every (n, d) position of an (N, C, D) logits
# tensor, compute max over the class axis C and sum of exp(x - max), storing
# both into (N, D) buffers.  Downstream kernels use these to form
# log-softmax as log(sum) + max - x without re-reading the logits twice.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_C": 2**n}, num_warps=1, num_stages=3)
        for n in range(10, 17, 2)
    ],
    key=["C"],
)
@triton.jit
def softmax_forward_kernel(
    inp_ptr,  # logits, contiguous (N, C, D)
    final_max_ptr,  # out: per-(n, d) max over C, shape (N, D), float32
    final_sum_ptr,  # out: per-(n, d) sum of exp(x - max), shape (N, D), float32
    N,  # number of batch rows
    C: tl.constexpr,  # number of classes (reduction axis)
    D: tl.constexpr,  # inner spatial size
    BLOCK_C: tl.constexpr,  # tile size along C (autotuned)
):
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)

    # Split the N rows across programs as evenly as possible: the first
    # `job_remain_batch` programs process one extra row each.
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job

    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx

        if C <= BLOCK_C:
            # Whole class axis fits in one tile: single load, direct reduction.
            offset_d = tl.arange(0, D)
            offset_c = tl.arange(0, C)

            inp_ptrs = (
                inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
            )
            inp = tl.load(inp_ptrs).to(tl.float32)
            final_max = tl.max(inp, axis=0)
            final_sum = tl.sum(tl.exp(inp - final_max[None, :]), axis=0)

            final_max_ptrs = final_max_ptr + pid_n * D + offset_d
            final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d

            tl.store(final_max_ptrs, final_max)
            tl.store(final_sum_ptrs, final_sum)
        else:
            # Online (streaming) softmax across C tiles: keep a running max
            # and a running sum that is rescaled whenever the max grows.
            # NOTE: tmp_max starts at 0, not -inf; this is still exact because
            # logsumexp(x) = log(sum exp(x - m)) + m for ANY m, and the same
            # tmp_max/final_max used for the shift is added back downstream.
            tmp_max = tl.zeros([BLOCK_C, D], dtype=tl.float32)
            tmp_sum = tl.zeros([BLOCK_C, D], dtype=tl.float32)
            offset_d = tl.arange(0, D)

            for off in range(0, C, BLOCK_C):
                offset_c = off + tl.arange(0, BLOCK_C)
                inp_ptrs = (
                    inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
                )
                inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
                inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(
                    tl.float32
                )
                cur_max = tl.maximum(tmp_max, inp)
                cur_exp = tl.exp(inp - cur_max)
                # Rescale the accumulated sum to the new max before adding.
                tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
                tmp_max = cur_max

            # Collapse per-lane partials into one max/sum per (n, d) column.
            final_max = tl.max(tmp_max, axis=0)
            tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])
            final_sum = tl.sum(tmp_sum, axis=0)

            final_max_ptrs = final_max_ptr + pid_n * D + offset_d
            final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d

            tl.store(final_max_ptrs, final_max)
            tl.store(final_sum_ptrs, final_sum)

94 

95 

# Parallel max over the class axis: the C axis is split across C_TILE_NUM
# programs (grid axis 1); each program reduces its slice and merges the
# result into final_max via tl.atomic_max.  final_max must be pre-filled
# with a very small value by the caller.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"C_TILE_NUM": num}, num_warps=1, num_stages=s)
        for num in [4, 8, 16, 48]
        for s in [0, 3]
    ],
    key=["C"],
    # atomic_max mutates final_max_ptr in place, so the autotuner must restore
    # it between timing runs to keep the trials independent.
    restore_value=["final_max_ptr"],
)
@triton.jit
def max_kernel(
    inp_ptr,  # logits, contiguous (N, C, D)
    final_max_ptr,  # in/out: per-(n, d) running max, shape (N, D)
    N,  # number of batch rows
    C: tl.constexpr,  # number of classes
    D: tl.constexpr,  # inner spatial size
    C_TILE_NUM: tl.constexpr,  # how many programs share one row's C axis
):
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)

    # Even split of the N rows across axis-0 programs (first `job_remain_batch`
    # programs take one extra row).
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job

    core_id = tl.program_id(1)  # which C tile this program owns
    offset_d = tl.arange(0, D)
    # Ceil-divide C into C_TILE_NUM contiguous slices.
    BLOCK_C: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM

    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx
        offset_c = core_id * BLOCK_C + tl.arange(0, BLOCK_C)

        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C
        inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)

        # Local max over this tile, then merge across tiles atomically.
        final_max = tl.max(inp, axis=0)
        final_max_ptrs = final_max_ptr + pid_n * D + offset_d
        tl.atomic_max(final_max_ptrs, final_max)

142 

143 

# Second phase of the split softmax-statistics computation: given the
# already-computed per-(n, d) max (see max_kernel), each program computes
# sum(exp(x - max)) over its C slice and accumulates into final_sum with
# tl.atomic_add.  final_sum must start at zero.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"C_TILE_NUM": num}, num_warps=1, num_stages=s)
        for num in [4, 8, 16, 48]
        for s in [0, 3]
    ],
    key=["C"],
    # atomic_add accumulates into final_sum_ptr, so the autotuner zeroes it
    # between timing runs to keep the trials independent.
    reset_to_zero=["final_sum_ptr"],
)
@triton.jit
def softmax_forward_with_max_kernel(
    inp_ptr,  # logits, contiguous (N, C, D)
    final_max_ptr,  # in: per-(n, d) max over C, shape (N, D)
    final_sum_ptr,  # in/out: per-(n, d) accumulated sum of exp(x - max)
    N,  # number of batch rows
    C: tl.constexpr,  # number of classes
    D: tl.constexpr,  # inner spatial size
    C_TILE_NUM: tl.constexpr,  # how many programs share one row's C axis
):
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)

    # Even split of the N rows across axis-0 programs (first `job_remain_batch`
    # programs take one extra row).
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job

    core_id = tl.program_id(1)  # which C tile this program owns
    offset_d = tl.arange(0, D)
    # Ceil-divide C into C_TILE_NUM contiguous slices.
    BLOCK_C: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM

    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx
        offset_c = core_id * BLOCK_C + tl.arange(0, BLOCK_C)

        final_max_ptrs = final_max_ptr + pid_n * D + offset_d
        final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
        final_max = tl.load(final_max_ptrs)

        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C
        # Out-of-range lanes load -inf so exp(...) contributes 0 to the sum.
        inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)

        final_sum = tl.sum(tl.exp(inp - final_max[None, :]), axis=0)
        tl.atomic_add(final_sum_ptrs, final_sum)

194 

195 

# NLL loss (no class weights) from precomputed softmax statistics:
# loss[n, d] = log(sum) + max - x[n, tgt, d], i.e. logsumexp minus the
# target logit.  Positions whose target equals ignore_index are not
# written (out must be pre-zeroed by the caller).
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 2**n}, num_warps=4, num_stages=0)
        for n in range(4, 11, 2)
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["ignore_index"])
def nllloss_without_weight_kernel(
    inp_ptr,  # logits, contiguous (N, C, D)
    tgt_ptr,  # class-index targets, shape (N, D)
    final_max_ptr,  # per-(n, d) max over C
    final_sum_ptr,  # per-(n, d) sum of exp(x - max)
    out_ptr,  # out: per-(n, d) loss, shape (N, D)
    ignore_index,  # target value to skip
    N,
    C,
    D: tl.constexpr,
    BLOCK_N: tl.constexpr,  # rows handled per program (autotuned)
):
    core_id = tl.program_id(0)
    offset_n = core_id * BLOCK_N + tl.arange(0, BLOCK_N)
    offset_d = tl.arange(0, D)

    # NOTE(review): offset_n and offset_d are both 1-D, so `offset_n * D +
    # offset_d` only broadcasts when D == 1 — confirm this kernel is launched
    # solely for D == 1 inputs.
    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt_mask = offset_n < N
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # Per-element mask of positions whose target is NOT ignore_index
    # (Triton applies `not` to the tensor elementwise).
    ignore_mask = not (tgt == ignore_index)

    final_max_ptrs = final_max_ptr + offset_n * D + offset_d
    final_sum_ptrs = final_sum_ptr + offset_n * D + offset_d
    final_max = tl.load(final_max_ptrs, mask=tgt_mask, other=0)
    # `other=1` keeps log(final_sum) finite for masked-out lanes.
    final_sum = tl.load(final_sum_ptrs, mask=tgt_mask, other=1)

    # Gather the logit of the target class for each (n, d).
    inp_tgt_ptrs = inp_ptr + offset_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)

    # log2(s) * ln(2) == ln(s); loge2 approximates ln(2).
    loge2 = 0.693147
    out = tl.log2(final_sum) * loge2 + final_max - inp_tgt

    out_ptrs = out_ptr + offset_n * D + offset_d
    # Skip out-of-range rows and ignored targets; their slots keep the
    # caller's initial value.
    tl.store(out_ptrs, out, mask=tgt_mask and ignore_mask)

240 

241 

# NLL loss with optional per-class weights, one program per batch row.
# Also records the effective weight of each target (w_tgt) so the host can
# normalize a "mean" reduction by the total weight of non-ignored targets.
@libentry()
@triton.heuristics(
    values={
        "num_warps": lambda args: 1,
        "num_stages": lambda args: 0,
    },
)
@triton.jit(do_not_specialize=["ignore_index"])
def nllloss_with_weight_kernel(
    inp_ptr,  # logits, contiguous (N, C, D)
    tgt_ptr,  # class-index targets, shape (N, D)
    w_ptr,  # per-class weights (length C) or None
    w_tgt_ptr,  # out: weight of each target position, shape (N, D)
    final_max_ptr,  # per-(n, d) max over C
    final_sum_ptr,  # per-(n, d) sum of exp(x - max)
    out_ptr,  # out: per-(n, d) weighted loss
    ignore_index,  # target value to skip
    N,
    C,
    D: tl.constexpr,
):
    pid_n = tl.program_id(0)
    offset_d = tl.arange(0, D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt = tl.load(tgt_ptrs)

    # Per-element mask of positions whose target is NOT ignore_index.
    ignore_mask = not (tgt == ignore_index)

    if w_ptr is None:
        # Unweighted: weight is 1 for kept targets, 0 for ignored ones.
        w_tgt = ignore_mask
    else:
        w_ptrs = w_ptr + tgt
        w_tgt = tl.load(w_ptrs).to(tl.float32)
    # Record the target weight; masked so ignored positions keep the
    # caller's zero initialization.
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=ignore_mask)

    final_max_ptrs = final_max_ptr + pid_n * D + offset_d
    final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
    final_max = tl.load(final_max_ptrs)
    final_sum = tl.load(final_sum_ptrs)

    # Gather the logit of the target class for each d.
    inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs).to(tl.float32)

    # log2(s) * ln(2) == ln(s); loge2 approximates ln(2).
    loge2 = 0.693147
    out = (tl.log2(final_sum) * loge2 + final_max - inp_tgt) * w_tgt

    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=ignore_mask)

292 

293 

# Cross entropy with probability-valued targets (tgt has the same shape as
# the logits).  Grid: (cdiv(D, BLOCK_D), N).  Two passes over the C axis:
# first an online softmax to obtain logsumexp statistics, then an
# accumulation of (logZ - x_c) * smoothed_target_c * weight_c.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,  # logits, contiguous (N, C, D)
    tgt_ptr,  # target probabilities, contiguous (N, C, D)
    w_ptr,  # per-class weights (length C) or None
    out_ptr,  # out: per-(n, d) loss, shape (N, D)
    label_smoothing,  # smoothing factor in [0, 1)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    # Pass 1: online softmax statistics (running max + rescaled running sum).
    # Zero-init of tmp_max is safe: logsumexp is shift-invariant as long as
    # the same shift (final_max) is added back below.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    # final_sum holds log(sum(exp(x - final_max))), so final_sum + final_max
    # is the logsumexp (log partition) per (n, d).
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]

    # Pass 2: accumulate the weighted, label-smoothed cross entropy.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Blend target probabilities with the uniform distribution.
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        # -log softmax(x)_c == logZ - x_c
        log = final_sum + final_max - inp
        w_mask = offset_c < C
        if w_ptr is None:
            # Weight defaults to 1 on valid classes (mask acts as 0/1).
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[:, None]

    out = tl.sum(_sum, axis=0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)

351 

352 

# Cross entropy with class-index targets and label smoothing.
# Grid: (cdiv(D, BLOCK_D), N).  Also records the per-target weight (w_tgt)
# so the host can normalize a "mean" reduction by total target weight.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,  # logits, contiguous (N, C, D)
    tgt_ptr,  # class-index targets, shape (N, D)
    w_ptr,  # per-class weights (length C) or None
    out_ptr,  # out: per-(n, d) loss, shape (N, D)
    w_tgt_ptr,  # out: weight of each target position, shape (N, D)
    ignore_index,  # target value to skip
    label_smoothing,  # smoothing factor in [0, 1)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # In-range positions whose target is NOT ignore_index.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    if w_ptr is None:
        # Unweighted: weight is 1 for kept targets, 0 for ignored ones.
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0)
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)

    # Pass 1: online softmax statistics (running max + rescaled running sum).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    # logZ per (n, d): log(sum(exp(x - max))) + max.
    final_sum_max = final_sum + final_max

    # Pass 2: accumulate smoothed, weighted negative log-likelihood.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot: 1 - eps + eps/C at the target class, eps/C elsewhere.
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)

        # -log softmax(x)_c == logZ - x_c
        log = final_sum_max - inp
        _sum += log * smooth * w[:, None]

    out = tl.sum(_sum, axis=0)
    # Zero the loss at ignored positions before writing.
    out = tl.where(ignore_mask, out, 0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)

432 

433 

# Helper invoked by celoss_indice_bwd_with_saved_sum_kernel: writes the
# logits gradient for one (row, class-tile):
#   d loss / d x = (softmax(x) - onehot(tgt)) * weight[tgt] * out_grad * mean_num
# using the saved per-(n, d) max and sum(exp(x - max)).
@triton.jit
def single_celoss_indice_bwd(
    pid_n,  # batch row index
    offset_c,  # class offsets covered by this call
    offset_d,  # inner-dimension offsets
    final_max,  # saved per-(n, d) max over C
    final_sum,  # saved per-(n, d) sum of exp(x - max)
    tgt,  # target class index per (n, d)
    w_tgt,  # weight of the target class (1 when unweighted)
    out_grad,  # upstream gradient per (n, d)
    mean_num,  # reduction normalization factor
    inp_ptr,  # logits, contiguous (N, C, D)
    inp_grad_ptr,  # out: logits gradient, same layout
    ignore_mask,  # False where target == ignore_index (grad left as 0)
    C,
    D,
):
    inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
    inp_mask = offset_c[:, None] < C
    inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)

    # 1 exactly at the target class, 0 elsewhere (the "onehot" term).
    minus_one = offset_c[:, None] == tgt[None, :]
    # exp(x - max) / sum == softmax(x); subtract onehot, then scale.
    inp_grad = (
        (tl.exp(inp - final_max[None, :]) / final_sum[None, :] - minus_one)
        * w_tgt
        * out_grad
        * mean_num
    )
    inp_grad_ptrs = (
        inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
    )
    # Ignored positions keep the caller's zero-initialized gradient.
    tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)

466 

467 

def config_prune(configs, named_args, **kwargs):
    """Early-prune autotune configs whose tiling parameters conflict.

    Keeps a config only when its parameters are mutually consistent:
    single-program tiling (TILE_MODE 0) must use exactly one C tile, and
    multi-program tiling (TILE_MODE 1) needs at least 4 tiles with a block
    small enough (<= 1024) to be worth splitting.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: kernel arguments supplied by the autotuner (unused).
        **kwargs: extra autotuner context (unused).

    Returns:
        The surviving configs, in their original order.
    """

    def _is_valid(kw):
        # One-line check mirroring the two legal (TILE_MODE, C_TILE_NUM,
        # BLOCK_C) combinations described above.
        mode, tiles, block_c = kw["TILE_MODE"], kw["C_TILE_NUM"], kw["BLOCK_C"]
        return (mode == 0 and tiles == 1) or (
            mode == 1 and tiles >= 4 and block_c <= 1024
        )

    return [config for config in configs if _is_valid(config.kwargs)]

477 

478 

# Backward pass for index-target cross entropy using the max/sum saved by the
# forward pass.  Two tiling strategies, chosen by the autotuner:
#   TILE_MODE 0 — each program covers a whole row's C axis (one tile or a
#                 serial loop of BLOCK_C tiles);
#   TILE_MODE 1 — the C axis is split across C_TILE_NUM programs (grid axis 1).
@libentry()
@triton.autotune(
    configs=[
        triton.Config(
            {
                "TILE_MODE": mode,
                "C_TILE_NUM": num,
                "BLOCK_C": 2**n,
            },
            num_warps=1,
            num_stages=s,
        )
        for mode in [0, 1]
        for num in [1, 4, 8, 16, 48]
        for n in range(10, 17, 2)
        for s in [0, 3]
    ],
    key=["C"],
    # config_prune drops (mode, num, BLOCK_C) combinations that make no sense
    # together, shrinking the search space before benchmarking.
    prune_configs_by={
        "early_config_prune": config_prune,
    },
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indice_bwd_with_saved_sum_kernel(
    out_grad_ptr,  # upstream gradient, shape (N, D)
    inp_ptr,  # forward logits, contiguous (N, C, D)
    tgt_ptr,  # class-index targets, shape (N, D)
    w_ptr,  # per-class weights or None (see is_has_weight)
    inp_grad_ptr,  # out: logits gradient, zero-initialized by caller
    final_max_ptr,  # saved per-(n, d) max over C
    final_sum_ptr,  # saved per-(n, d) sum of exp(x - max)
    ignore_index,
    mean_num,  # reduction normalization factor
    N,
    C: tl.constexpr,
    D: tl.constexpr,
    is_has_weight: tl.constexpr,  # True when w_ptr is a real tensor
    is_has_ignore_index: tl.constexpr,  # True when ignore_index is in [0, C)
    is_tgt_in_i32: tl.constexpr,  # True when targets fit in int32
    TILE_MODE: tl.constexpr,
    C_TILE_NUM: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)

    # Even split of the N rows across axis-0 programs (first `job_remain_batch`
    # programs take one extra row).
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job

    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx
        offset_d = tl.arange(0, D)

        tgt_ptrs = tgt_ptr + pid_n * D + offset_d
        if is_tgt_in_i32:
            # Narrow to int32 when C allows it (cheaper index arithmetic).
            tgt = tl.load(tgt_ptrs).to(tl.int32)
        else:
            tgt = tl.load(tgt_ptrs)

        out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
        out_grad = tl.load(out_grad_ptrs).to(tl.float32)[None, :]

        if is_has_weight:
            w_ptrs = w_ptr + tgt
            w_tgt = tl.load(w_ptrs).to(tl.float32)[None, :]
        else:
            w_tgt = 1

        if is_has_ignore_index:
            ignore_mask = (tgt != ignore_index)[None, :]
        else:
            # No valid ignore_index: every position contributes.
            ignore_mask = True

        final_max_ptrs = final_max_ptr + pid_n * D + offset_d
        final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
        final_max = tl.load(final_max_ptrs)
        final_sum = tl.load(final_sum_ptrs)

        if TILE_MODE == 0:
            if C <= BLOCK_C:
                # Whole class axis in one call.
                offset_c = tl.arange(0, C)
                single_celoss_indice_bwd(
                    pid_n,
                    offset_c,
                    offset_d,
                    final_max,
                    final_sum,
                    tgt,
                    w_tgt,
                    out_grad,
                    mean_num,
                    inp_ptr,
                    inp_grad_ptr,
                    ignore_mask,
                    C,
                    D,
                )
            else:
                # Serial sweep over C in BLOCK_C-sized tiles.
                for off in range(0, C, BLOCK_C):
                    offset_c = off + tl.arange(0, BLOCK_C)
                    single_celoss_indice_bwd(
                        pid_n,
                        offset_c,
                        offset_d,
                        final_max,
                        final_sum,
                        tgt,
                        w_tgt,
                        out_grad,
                        mean_num,
                        inp_ptr,
                        inp_grad_ptr,
                        ignore_mask,
                        C,
                        D,
                    )
        else:
            # TILE_MODE 1: this program handles only its own slice of C;
            # sibling programs on grid axis 1 cover the rest.
            core_id = tl.program_id(1)
            C_TILE_SIZE: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM
            offset_c = core_id * C_TILE_SIZE + tl.arange(0, C_TILE_SIZE)

            single_celoss_indice_bwd(
                pid_n,
                offset_c,
                offset_d,
                final_max,
                final_sum,
                tgt,
                w_tgt,
                out_grad,
                mean_num,
                inp_ptr,
                inp_grad_ptr,
                ignore_mask,
                C,
                D,
            )

622 

623 

# Backward pass for probability-target cross entropy.  Recomputes the softmax
# statistics online (they were not saved in this path), along with the total
# smoothed target weight, then emits
#   d loss / d x = (sum_c w_c * tgt_c) * softmax(x) - w * tgt,
# scaled by out_grad and the reduction factor mean_num.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,  # upstream gradient, shape (N, D)
    inp_ptr,  # forward logits, contiguous (N, C, D)
    tgt_ptr,  # target probabilities, contiguous (N, C, D)
    w_ptr,  # per-class weights (length C) or None
    inp_grad_ptr,  # out: logits gradient, same layout as inp
    label_smoothing,  # smoothing factor in [0, 1)
    mean_num,  # reduction normalization factor
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[
        None, :
    ]

    # Pass 1: online softmax statistics plus the weighted-target total.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)

        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Blend target probabilities with the uniform distribution.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        w_tgt_sum += tgt * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    # Total smoothed target mass weighted by class weights, per (n, d).
    w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]

    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_ptrs = inp_ptr + offset
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # exp(x - max) / sum == softmax(x).
        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[:, None]
        inp_grad = grad * out_grad * mean_num

        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)

707 

708 

# Backward pass for index-target cross entropy with label smoothing.
# Recomputes the softmax statistics online together with the total smoothed
# weight, then emits
#   d loss / d x = (sum_c w_c * smooth_c) * softmax(x) - w * smooth,
# scaled by out_grad and mean_num; ignored positions keep zero gradient.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,  # upstream gradient, shape (N, D)
    inp_ptr,  # forward logits, contiguous (N, C, D)
    tgt_ptr,  # class-index targets, shape (N, D)
    w_ptr,  # per-class weights (length C) or None
    inp_grad_ptr,  # out: logits gradient, zero-initialized by caller
    ignore_index,  # target value whose positions get zero gradient
    label_smoothing,  # smoothing factor in [0, 1)
    mean_num,  # reduction normalization factor
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    # Positions whose target is NOT ignore_index (broadcast over C below).
    ignore_mask = (tgt != ignore_index)[None, :]

    # Pass 1: online softmax statistics plus the total smoothed weight.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot: 1 - eps + eps/C at the target class, eps/C elsewhere.
        smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )

        w_sum += smooth * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    # Total smoothed target mass weighted by class weights, per (n, d).
    w_sum = tl.sum(w_sum, axis=0)[None, :]

    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )

        # exp(x - max) / sum == softmax(x).
        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        # Ignored positions keep the caller's zero-initialized gradient.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)

801 

802 

class CrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper dispatching cross-entropy forward/backward to the
    Triton kernels above.

    The input is viewed as (N, C, D): N batch rows, C classes, D trailing
    spatial elements.  Three target flavors are handled:
      * probability targets (tgt.ndim == inp.ndim),
      * index targets without smoothing,
      * index targets with label smoothing.
    `reduction` is the torch enum: 0 = none, 1 = mean, 2 = sum.
    """

    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        """Compute the loss; returns a (N, D)-shaped tensor for reduction
        'none', otherwise a scalar."""
        logger.debug("GEMS_TSINGMICRO CrossEntropyLoss")

        # Factor the input shape into (N, C, D); a 1-D input is a single row.
        shape = list(inp.shape)
        dim = inp.ndim
        N = 1 if dim == 1 else shape[0]
        C = shape[0] if dim == 1 else shape[1]
        D = inp.numel() // N // C
        axis = 0 if dim == 1 else 1
        # `shape` now describes the per-position loss layout (class axis removed).
        del shape[axis]

        inp = inp.contiguous()
        tgt = target.contiguous()

        # Stash sizes/hyper-parameters for backward.
        ctx.N = N
        ctx.C = C
        ctx.D = D
        ctx.ignore_index = ignore_index
        ctx.label_smoothing = label_smoothing
        ctx.shape = shape

        final_max = None
        final_sum = None

        mean_num = 1
        # Probability targets + mean reduction: normalize by element count.
        if reduction == 1 and tgt.ndim == dim:
            mean_num = 1 / (N * D)
        out = torch.empty(shape, dtype=torch.float32, device=inp.device)

        def get_result(inp, tgt, out, reduction, mean_num):
            # Apply the requested reduction and cast back to the input dtype.
            if reduction == 0:  # NONE
                return out.to(inp.dtype)
            elif reduction == 1:  # MEAN
                return (torch.sum(out) * mean_num).to(inp.dtype)
            else:  # SUM
                return torch.sum(out).to(inp.dtype)

        # Fast path: unweighted index targets, no smoothing.
        if weight is None and tgt.ndim != dim and label_smoothing == 0:
            # Seed for atomic_max / atomic_add accumulation.
            final_max = torch.full(
                shape,
                torch.finfo(torch.float32).min,
                dtype=torch.float32,
                device=inp.device,
            )
            final_sum = torch.zeros(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                # For mid-sized C, splitting the class axis over several
                # programs (max_kernel + atomic sum) beats the fused kernel.
                if C <= (32 * 1000) or C > (2048 * 1000):
                    softmax_forward_kernel[(TOTAL_CORE_NUM,)](
                        inp, final_max, final_sum, N, C, D
                    )
                else:
                    grid = lambda meta: (
                        triton.cdiv(TOTAL_CORE_NUM, meta["C_TILE_NUM"]),
                        meta["C_TILE_NUM"],
                    )
                    max_kernel[grid](inp, final_max, N, C, D)
                    softmax_forward_with_max_kernel[grid](
                        inp, final_max, final_sum, N, C, D
                    )

                grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
                nllloss_without_weight_kernel[grid](
                    inp, tgt, final_max, final_sum, out, ignore_index, N, C, D
                )
            if reduction == 1:
                # NOTE(review): the mean here divides by C (or C-1 when
                # ignore_index is a valid class), not by the count of
                # non-ignored targets — verify against
                # torch.nn.functional.cross_entropy semantics.
                if ignore_index < 0 or ignore_index >= C:
                    mean_num = 1 / C
                else:
                    mean_num = 1 / (C - 1)
            ctx.mean_num = mean_num

            ctx.save_for_backward(inp, tgt, weight, final_max, final_sum)
            return get_result(inp, tgt, out, reduction, mean_num)

        weight = weight.contiguous() if weight is not None else None
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)

        if tgt.ndim == dim:
            # target probabilities
            with torch_device_fn.device(inp.device):
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
        elif label_smoothing == 0:
            # target indices (weighted, no smoothing): softmax stats first,
            # then the weighted NLL kernel which also fills w_tgt.
            w_tgt = torch.zeros(shape, dtype=torch.float32, device=inp.device)
            final_max = torch.empty(shape, dtype=torch.float32, device=inp.device)
            final_sum = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                softmax_forward_kernel[(TOTAL_CORE_NUM,)](
                    inp, final_max, final_sum, N, C, D
                )
                nllloss_with_weight_kernel[(N,)](
                    inp,
                    tgt,
                    weight,
                    w_tgt,
                    final_max,
                    final_sum,
                    out,
                    ignore_index,
                    N,
                    C,
                    D,
                )
        else:
            # target indices with label smoothing.
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_smooth_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    label_smoothing,
                    C,
                    D,
                )
        ctx.save_for_backward(inp, tgt, weight, final_max, final_sum)
        ctx.mean_num = 1

        # Index targets + mean reduction: normalize by the total weight of
        # the non-ignored targets collected in w_tgt.
        if reduction == 1 and tgt.ndim != dim:
            mean_num = 1 / torch.sum(w_tgt).item()
            ctx.mean_num = mean_num
        return get_result(inp, tgt, out, reduction, mean_num)

    @staticmethod
    def backward(ctx, out_grad):
        """Compute the gradient w.r.t. the logits; other inputs get None."""
        logger.debug("GEMS_TSINGMICRO CrossEntropyLoss VJP")

        inp, tgt, weight, final_max, final_sum = ctx.saved_tensors
        N = ctx.N
        C = ctx.C
        D = ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        mean_num = ctx.mean_num
        shape = ctx.shape

        # out_grad is a scalar for mean/sum reductions; expand to (N, D).
        out_grad = out_grad.broadcast_to(shape).contiguous()

        # Zero-init: kernels only write non-ignored positions.
        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
        if tgt.ndim == inp.ndim:
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
        elif label_smoothing == 0:
            # NOTE(review): if final_sum is None here no kernel runs and the
            # gradient stays all-zero — confirm every forward path that can
            # reach this branch saves final_max/final_sum.
            if final_sum is not None:
                is_has_weight = weight is not None
                is_has_ignore_index = ignore_index >= 0 and ignore_index < C
                is_tgt_in_i32 = C < (1 << 31)
                grid = lambda meta: (
                    triton.cdiv(TOTAL_CORE_NUM, meta["C_TILE_NUM"]),
                    meta["C_TILE_NUM"],
                )
                celoss_indice_bwd_with_saved_sum_kernel[grid](
                    out_grad,
                    inp,
                    tgt,
                    weight,
                    inp_grad,
                    final_max,
                    final_sum,
                    ignore_index,
                    mean_num,
                    N,
                    C,
                    D,
                    is_has_weight,
                    is_has_ignore_index,
                    is_tgt_in_i32,
                )
        else:
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
        # One gradient per forward argument; only inp is differentiable.
        return inp_grad, None, None, None, None, None

999 

1000 

def cross_entropy_loss(
    inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
):
    """Functional entry point mirroring ``torch.nn.functional.cross_entropy``.

    Translates the string ``reduction`` ("none" / "mean" / "sum") into the
    torch reduction enum and delegates to the ``CrossEntropyLoss`` autograd
    Function, which dispatches to the Triton kernels in this module.
    """
    reduction_enum = _Reduction.get_enum(reduction)
    return CrossEntropyLoss.apply(
        inp,
        target,
        weight,
        reduction_enum,
        ignore_index,
        label_smoothing,
    )