Coverage for src/flag_gems/runtime/backend/_cambricon/fused/cross_entropy_loss.py: 0%

507 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch.nn import _reduction as _Reduction 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11 

12from ..ops import sum 

13from ..utils import TOTAL_CORE_NUM 

14 

15logger = logging.getLogger(__name__) 

16 

17 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_C": 2**n}, num_warps=1, num_stages=3)
        for n in range(10, 17, 2)
    ],
    key=["C"],
)
@triton.jit
def softmax_forward_kernel(
    inp_ptr,
    final_max_ptr,
    final_sum_ptr,
    N,
    C: tl.constexpr,
    D: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    # For each (n, d) position of an (N, C, D) input, compute the softmax
    # statistics over the class dimension C:
    #   final_max[n, d] = max offset used for stabilization
    #   final_sum[n, d] = sum_c exp(inp[n, c, d] - final_max[n, d])
    # Downstream kernels combine them as logsumexp = log(final_sum) + final_max.
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)

    # Distribute the N batch rows across the launched programs as evenly as
    # possible: the first `job_remain_batch` programs take one extra row.
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job

    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx

        if C <= BLOCK_C:
            # Whole class dimension fits in one tile: single-pass reduction.
            offset_d = tl.arange(0, D)
            offset_c = tl.arange(0, C)

            inp_ptrs = (
                inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
            )
            inp = tl.load(inp_ptrs).to(tl.float32)
            final_max = tl.max(inp, axis=0)
            final_sum = tl.sum(tl.exp(inp - final_max[None, :]), axis=0)

            final_max_ptrs = final_max_ptr + pid_n * D + offset_d
            final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d

            tl.store(final_max_ptrs, final_max)
            tl.store(final_sum_ptrs, final_sum)
        else:
            # Streaming (online) max/sum-of-exp update over BLOCK_C-sized
            # chunks of C.  NOTE: the running max starts at 0 rather than
            # -inf, so the stored "max" is really max(inp, 0); this is still
            # consistent because max and sum are always consumed jointly as
            # log(sum) + max, which is invariant to the offset used.
            tmp_max = tl.zeros([BLOCK_C, D], dtype=tl.float32)
            tmp_sum = tl.zeros([BLOCK_C, D], dtype=tl.float32)
            offset_d = tl.arange(0, D)

            for off in range(0, C, BLOCK_C):
                offset_c = off + tl.arange(0, BLOCK_C)
                inp_ptrs = (
                    inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
                )
                inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
                # -inf padding keeps masked lanes from contributing: exp(-inf)=0.
                inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(
                    tl.float32
                )
                cur_max = tl.maximum(tmp_max, inp)
                cur_exp = tl.exp(inp - cur_max)
                # Rescale the running sum to the new max before accumulating.
                tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
                tmp_max = cur_max

            # Collapse the per-lane partial results along the tile axis.
            final_max = tl.max(tmp_max, axis=0)
            tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])
            final_sum = tl.sum(tmp_sum, axis=0)

            final_max_ptrs = final_max_ptr + pid_n * D + offset_d
            final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d

            tl.store(final_max_ptrs, final_max)
            tl.store(final_sum_ptrs, final_sum)

95 

96 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"C_TILE_NUM": num}, num_warps=1, num_stages=s)
        for num in [4, 8, 16, 48]
        for s in [0, 3]
    ],
    key=["C"],
    # The kernel accumulates into final_max_ptr via atomic_max, so autotune
    # must restore the buffer between timing runs to keep results correct.
    restore_value=["final_max_ptr"],
)
@triton.jit
def max_kernel(
    inp_ptr,
    final_max_ptr,
    N,
    C: tl.constexpr,
    D: tl.constexpr,
    C_TILE_NUM: tl.constexpr,
):
    # Cooperative per-(n, d) max over the class dimension C: the C axis is
    # split into C_TILE_NUM tiles along grid axis 1, and each program folds
    # its tile into final_max via atomic_max.  The caller must pre-fill
    # final_max with a very small value (e.g. float32 min).
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)

    # Even distribution of the N rows over grid axis 0 (see
    # softmax_forward_kernel for the same partition scheme).
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job

    core_id = tl.program_id(1)
    offset_d = tl.arange(0, D)
    # Per-tile chunk size of the class dimension (ceil division).
    BLOCK_C: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM

    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx
        offset_c = core_id * BLOCK_C + tl.arange(0, BLOCK_C)

        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C
        # -inf padding is the identity element for max.
        inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)

        final_max = tl.max(inp, axis=0)
        final_max_ptrs = final_max_ptr + pid_n * D + offset_d
        tl.atomic_max(final_max_ptrs, final_max)

143 

144 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"C_TILE_NUM": num}, num_warps=1, num_stages=s)
        for num in [4, 8, 16, 48]
        for s in [0, 3]
    ],
    key=["C"],
    # The kernel accumulates into final_sum_ptr via atomic_add, so autotune
    # zeroes the buffer between timing runs.
    reset_to_zero=["final_sum_ptr"],
)
@triton.jit
def softmax_forward_with_max_kernel(
    inp_ptr,
    final_max_ptr,
    final_sum_ptr,
    N,
    C: tl.constexpr,
    D: tl.constexpr,
    C_TILE_NUM: tl.constexpr,
):
    # Second phase of the split softmax reduction: given the per-(n, d) max
    # already computed by max_kernel, accumulate
    #   final_sum[n, d] += sum_c exp(inp[n, c, d] - final_max[n, d])
    # with each grid-axis-1 program handling one tile of the C dimension.
    # final_sum must be zero-initialized by the caller.
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)

    # Even distribution of the N rows over grid axis 0 (same partition
    # scheme as max_kernel).
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job

    core_id = tl.program_id(1)
    offset_d = tl.arange(0, D)
    BLOCK_C: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM

    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx
        offset_c = core_id * BLOCK_C + tl.arange(0, BLOCK_C)

        final_max_ptrs = final_max_ptr + pid_n * D + offset_d
        final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
        final_max = tl.load(final_max_ptrs)

        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C
        # -inf padding contributes exp(-inf) = 0 to the sum.
        inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)

        final_sum = tl.sum(tl.exp(inp - final_max[None, :]), axis=0)
        tl.atomic_add(final_sum_ptrs, final_sum)

195 

196 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 2**n}, num_warps=4, num_stages=0)
        for n in range(4, 11, 2)
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["ignore_index"])
def nllloss_without_weight_kernel(
    inp_ptr,
    tgt_ptr,
    final_max_ptr,
    final_sum_ptr,
    out_ptr,
    ignore_index,
    N,
    C,
    D: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Unweighted NLL loss from precomputed softmax statistics:
    #   out[n, d] = logsumexp[n, d] - inp[n, tgt[n, d], d]
    #             = -log softmax(inp)[n, tgt, d]
    # Positions where tgt == ignore_index are left untouched in `out`
    # (the store mask excludes them), so the caller must pre-initialize out.
    core_id = tl.program_id(0)
    offset_n = core_id * BLOCK_N + tl.arange(0, BLOCK_N)
    offset_d = tl.arange(0, D)

    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt_mask = offset_n < N
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    ignore_mask = not (tgt == ignore_index)

    final_max_ptrs = final_max_ptr + offset_n * D + offset_d
    final_sum_ptrs = final_sum_ptr + offset_n * D + offset_d
    # other=0 / other=1 keep the out-of-range lanes numerically harmless
    # (log2(1) == 0); those lanes are masked off at the store anyway.
    final_max = tl.load(final_max_ptrs, mask=tgt_mask, other=0)
    final_sum = tl.load(final_sum_ptrs, mask=tgt_mask, other=1)

    # Gather the logit of the target class for each (n, d).
    inp_tgt_ptrs = inp_ptr + offset_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)

    # log(x) computed as log2(x) * ln(2).
    # NOTE(review): ln(2) truncated to 6 decimals; full precision would be
    # 0.6931471805599453 — confirm whether the reduced precision is intended.
    loge2 = 0.693147
    out = tl.log2(final_sum) * loge2 + final_max - inp_tgt

    out_ptrs = out_ptr + offset_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask and ignore_mask)

241 

242 

@libentry()
@triton.heuristics(
    values={
        "num_warps": lambda args: 1,
        "num_stages": lambda args: 0,
    },
)
@triton.jit(do_not_specialize=["ignore_index"])
def nllloss_with_weight_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    w_tgt_ptr,
    final_max_ptr,
    final_sum_ptr,
    out_ptr,
    ignore_index,
    N,
    C,
    D: tl.constexpr,
):
    # Class-weighted NLL loss from precomputed softmax statistics, one
    # program per batch row:
    #   out[n, d] = (logsumexp[n, d] - inp[n, tgt, d]) * weight[tgt]
    # The gathered per-target weight is also written to w_tgt_ptr so the
    # host can normalize a "mean" reduction by sum(w_tgt).
    pid_n = tl.program_id(0)
    offset_d = tl.arange(0, D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt = tl.load(tgt_ptrs)

    ignore_mask = not (tgt == ignore_index)

    if w_ptr is None:
        # No weight table: use the 0/1 ignore mask as an implicit weight.
        w_tgt = ignore_mask
    else:
        w_ptrs = w_ptr + tgt
        w_tgt = tl.load(w_ptrs).to(tl.float32)
    # Ignored positions keep whatever the caller initialized (zeros).
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=ignore_mask)

    final_max_ptrs = final_max_ptr + pid_n * D + offset_d
    final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
    final_max = tl.load(final_max_ptrs)
    final_sum = tl.load(final_sum_ptrs)

    # Gather the target-class logit for each d.
    inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs).to(tl.float32)

    # log(x) computed as log2(x) * ln(2) (ln 2 truncated to 6 decimals).
    loge2 = 0.693147
    out = (tl.log2(final_sum) * loge2 + final_max - inp_tgt) * w_tgt

    out_ptrs = out_ptr + pid_n * D + offset_d
    # Ignored positions are not written; caller pre-initializes `out`.
    tl.store(out_ptrs, out, mask=ignore_mask)

293 

294 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Cross-entropy loss for probability (soft) targets:
    #   out[n, d] = sum_c (logsumexp[n, d] - inp[n, c, d])
    #               * smoothed_tgt[n, c, d] * w[c]
    # where smoothed_tgt = tgt * (1 - ls) + ls / C.
    # Pass 1 computes logsumexp with a streaming max/sum update; pass 2
    # accumulates the weighted loss.
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    # Running max starts at 0 (not -inf) — safe because max and sum are only
    # used jointly as log(sum) + max, which is offset-invariant.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        # Rescale the running sum to the new running max before adding.
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    # final_sum holds log(sum(exp(...))), so final_sum + final_max below is
    # the logsumexp over C.
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]

    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Label smoothing blends the soft target with the uniform distribution.
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        # Per-class negative log-probability.
        log = final_sum + final_max - inp
        w_mask = offset_c < C
        if w_ptr is None:
            # No weight table: boolean mask doubles as a 0/1 weight.
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[:, None]

    out = tl.sum(_sum, axis=0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)

352 

353 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    w_tgt_ptr,
    ignore_index,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Cross-entropy loss for index targets WITH label smoothing:
    #   out[n, d] = sum_c (logsumexp - inp[n, c, d]) * smooth_c * w[c]
    # where smooth_c = 1 - ls + ls/C at c == tgt and ls/C elsewhere.
    # Also writes the gathered per-target weight to w_tgt_ptr so the host
    # can normalize a "mean" reduction; ignored positions produce 0 loss.
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    ignore_mask = not (tgt == ignore_index) and tgt_mask

    if w_ptr is None:
        # No weight table: the 0/1 ignore mask acts as the weight.
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0)
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)

    # Pass 1: streaming logsumexp over C (running max starts at 0, which is
    # safe because max/sum are only consumed jointly as log(sum) + max).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    # logsumexp over the class dimension.
    final_sum_max = final_sum + final_max

    # Pass 2: accumulate the smoothed, weighted loss.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot: main mass on the target class, ls/C elsewhere.
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)

        log = final_sum_max - inp
        _sum += log * smooth * w[:, None]

    out = tl.sum(_sum, axis=0)
    # Zero out ignored targets instead of skipping the store.
    out = tl.where(ignore_mask, out, 0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)

433 

434 

@triton.jit
def single_celoss_indice_bwd(
    pid_n,
    offset_c,
    offset_d,
    final_max,
    final_sum,
    tgt,
    w_tgt,
    out_grad,
    mean_num,
    inp_ptr,
    inp_grad_ptr,
    ignore_mask,
    C,
    D,
):
    # Gradient of the (unsmoothed, index-target) cross-entropy loss for one
    # tile of classes:
    #   dL/dinp[n, c, d] = (softmax(inp)[c] - 1{c == tgt})
    #                      * w_tgt * out_grad * mean_num
    # final_max / final_sum must come from the same forward pass so that
    # exp(inp - max) / sum is the correct softmax regardless of the offset.
    inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
    inp_mask = offset_c[:, None] < C
    inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)

    # 0/1 indicator of the target class (the "- 1" of softmax-minus-onehot).
    minus_one = offset_c[:, None] == tgt[None, :]
    inp_grad = (
        (tl.exp(inp - final_max[None, :]) / final_sum[None, :] - minus_one)
        * w_tgt
        * out_grad
        * mean_num
    )
    inp_grad_ptrs = (
        inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
    )
    # Ignored targets get no gradient (caller zero-initialized inp_grad).
    tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)

467 

468 

def config_prune(configs, named_args, **kwargs):
    """Early-prune autotune configs to the combinations that make sense.

    Keeps a config only when its tiling parameters are coherent:
    TILE_MODE 0 (one program covers all of C) pairs with C_TILE_NUM == 1,
    while TILE_MODE 1 (C split across programs) needs at least 4 tiles and
    a block no larger than 1024.

    Args:
        configs: iterable of triton.Config candidates.
        named_args: kernel argument mapping supplied by Triton (unused).
        **kwargs: extra autotuner context (unused).

    Returns:
        list of the surviving configs, in their original order.
    """

    def _is_coherent(kw):
        # One-line predicate per tile mode; anything else is rejected.
        mode = kw["TILE_MODE"]
        tiles = kw["C_TILE_NUM"]
        block_c = kw["BLOCK_C"]
        if mode == 0:
            return tiles == 1
        return mode == 1 and tiles >= 4 and block_c <= 1024

    return [cfg for cfg in configs if _is_coherent(cfg.kwargs)]

478 

479 

@libentry()
@triton.autotune(
    configs=[
        triton.Config(
            {
                "TILE_MODE": mode,
                "C_TILE_NUM": num,
                "BLOCK_C": 2**n,
            },
            num_warps=1,
            num_stages=s,
        )
        for mode in [0, 1]
        for num in [1, 4, 8, 16, 48]
        for n in range(10, 17, 2)
        for s in [0, 3]
    ],
    key=["C"],
    # config_prune drops parameter combinations that are incoherent for the
    # selected TILE_MODE (see config_prune above).
    prune_configs_by={
        "early_config_prune": config_prune,
    },
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indice_bwd_with_saved_sum_kernel(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    final_max_ptr,
    final_sum_ptr,
    ignore_index,
    mean_num,
    N,
    C: tl.constexpr,
    D: tl.constexpr,
    is_has_weight: tl.constexpr,
    is_has_ignore_index: tl.constexpr,
    is_tgt_in_i32: tl.constexpr,
    TILE_MODE: tl.constexpr,
    C_TILE_NUM: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    # Backward pass for index-target cross entropy, reusing the softmax
    # statistics (final_max / final_sum) saved by the forward pass.
    # TILE_MODE 0: each program covers the full C range, chunked by BLOCK_C.
    # TILE_MODE 1: the C range is split over grid axis 1 in C_TILE_NUM tiles.
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)

    # Even distribution of the N rows over grid axis 0 (same partition
    # scheme as the forward kernels).
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job

    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx
        offset_d = tl.arange(0, D)

        tgt_ptrs = tgt_ptr + pid_n * D + offset_d
        if is_tgt_in_i32:
            # Narrow the target indices when C fits in int32 (cheaper math).
            tgt = tl.load(tgt_ptrs).to(tl.int32)
        else:
            tgt = tl.load(tgt_ptrs)

        out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
        out_grad = tl.load(out_grad_ptrs).to(tl.float32)[None, :]

        if is_has_weight:
            # Gather the per-target class weight.
            w_ptrs = w_ptr + tgt
            w_tgt = tl.load(w_ptrs).to(tl.float32)[None, :]
        else:
            w_tgt = 1

        if is_has_ignore_index:
            ignore_mask = (tgt != ignore_index)[None, :]
        else:
            # ignore_index outside [0, C): nothing can be ignored.
            ignore_mask = True

        final_max_ptrs = final_max_ptr + pid_n * D + offset_d
        final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
        final_max = tl.load(final_max_ptrs)
        final_sum = tl.load(final_sum_ptrs)

        if TILE_MODE == 0:
            if C <= BLOCK_C:
                # One shot: the whole class dimension fits in a single tile.
                offset_c = tl.arange(0, C)
                single_celoss_indice_bwd(
                    pid_n,
                    offset_c,
                    offset_d,
                    final_max,
                    final_sum,
                    tgt,
                    w_tgt,
                    out_grad,
                    mean_num,
                    inp_ptr,
                    inp_grad_ptr,
                    ignore_mask,
                    C,
                    D,
                )
            else:
                # Chunked sweep over C within this single program.
                for off in range(0, C, BLOCK_C):
                    offset_c = off + tl.arange(0, BLOCK_C)
                    single_celoss_indice_bwd(
                        pid_n,
                        offset_c,
                        offset_d,
                        final_max,
                        final_sum,
                        tgt,
                        w_tgt,
                        out_grad,
                        mean_num,
                        inp_ptr,
                        inp_grad_ptr,
                        ignore_mask,
                        C,
                        D,
                    )
        else:
            # TILE_MODE 1: grid axis 1 assigns this program one tile of C.
            core_id = tl.program_id(1)
            C_TILE_SIZE: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM
            offset_c = core_id * C_TILE_SIZE + tl.arange(0, C_TILE_SIZE)

            single_celoss_indice_bwd(
                pid_n,
                offset_c,
                offset_d,
                final_max,
                final_sum,
                tgt,
                w_tgt,
                out_grad,
                mean_num,
                inp_ptr,
                inp_grad_ptr,
                ignore_mask,
                C,
                D,
            )

623 

624 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    label_smoothing,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward pass for probability (soft) targets:
    #   dL/dinp[c] = ( (sum_k w_k * smoothed_tgt_k) * softmax(inp)[c]
    #                  - w_c * smoothed_tgt_c ) * out_grad * mean_num
    # Pass 1 recomputes the softmax statistics (streaming max/sum) and the
    # weighted-target total; pass 2 writes the gradient per class tile.
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[
        None, :
    ]

    # Running max starts at 0 — consistent with the forward kernels; correct
    # because exp(inp - max)/sum uses the same offset in numerator and sum.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)

        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Label smoothing blends the soft target with the uniform distribution.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            # No weight table: boolean mask doubles as a 0/1 weight.
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        w_tgt_sum += tgt * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]

    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_ptrs = inp_ptr + offset
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # softmax term scaled by the total target weight, minus the direct
        # weighted-target term.
        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[:, None]
        inp_grad = grad * out_grad * mean_num

        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)

708 

709 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    ignore_index,
    label_smoothing,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward pass for index targets WITH label smoothing:
    #   dL/dinp[c] = ( (sum_k w_k * smooth_k) * softmax(inp)[c]
    #                  - w_c * smooth_c ) * out_grad * mean_num
    # where smooth_c = 1 - ls + ls/C at c == tgt and ls/C elsewhere.
    # Ignored targets produce no gradient (store is masked out).
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    ignore_mask = (tgt != ignore_index)[None, :]

    # Pass 1: streaming softmax statistics plus the weighted smooth-target
    # total (running max starts at 0; safe, same offset in num. and denom.).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            # No weight table: boolean mask doubles as a 0/1 weight.
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot: main mass on the target class, ls/C elsewhere.
        smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )

        w_sum += smooth * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_sum = tl.sum(w_sum, axis=0)[None, :]

    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )

        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        # Ignored targets keep the caller's zero-initialized gradient.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)

802 

803 

class CrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper dispatching cross-entropy loss to Cambricon Triton kernels.

    The input is viewed as (N, C, D): N batch rows, C classes, D trailing
    positions. Forward selects one of three kernel paths (fast unweighted
    index targets / weighted index targets / probability or smoothed
    targets); backward mirrors that choice. Note that `sum` below is the
    flag_gems reduction op imported from ..ops, not the Python builtin.
    """

    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        # reduction is the torch _Reduction enum: 0=none, 1=mean, 2=sum.
        logger.debug("GEMS_CAMBRICON CrossEntropyLoss")

        # Derive the (N, C, D) view: class axis is dim 0 for 1-D input,
        # dim 1 otherwise. `shape` becomes the target/output shape (C removed).
        shape = list(inp.shape)
        dim = inp.ndim
        N = 1 if dim == 1 else shape[0]
        C = shape[0] if dim == 1 else shape[1]
        D = inp.numel() // N // C
        axis = 0 if dim == 1 else 1
        del shape[axis]

        inp = inp.contiguous()
        tgt = target.contiguous()

        # Stash the geometry and hyper-parameters for backward.
        ctx.N = N
        ctx.C = C
        ctx.D = D
        ctx.ignore_index = ignore_index
        ctx.label_smoothing = label_smoothing
        ctx.shape = shape

        final_max = None
        final_sum = None

        mean_num = 1
        # Probability targets (tgt.ndim == dim): mean over all N*D positions.
        if reduction == 1 and tgt.ndim == dim:
            mean_num = 1 / (N * D)
        out = torch.empty(shape, dtype=torch.float32, device=inp.device)

        def get_result(inp, tgt, out, reduction, mean_num):
            # Apply the requested reduction; `sum` is the flag_gems op.
            if reduction == 0:  # NONE
                return out.to(inp.dtype)
            elif reduction == 1:  # MEAN
                return (sum(out) * mean_num).to(inp.dtype)
            else:  # SUM
                return sum(out).to(inp.dtype)

        # Fast path: unweighted index targets without label smoothing —
        # softmax statistics are materialized and saved for backward.
        if weight is None and tgt.ndim != dim and label_smoothing == 0:
            final_max = torch.full(
                shape,
                torch.finfo(torch.float32).min,
                dtype=torch.float32,
                device=inp.device,
            )
            final_sum = torch.zeros(shape, dtype=torch.float32, device=inp.device)
            # NOTE(review): this path uses torch.mlu.device while the other
            # paths use torch_device_fn.device — confirm whether intentional.
            with torch.mlu.device(inp.device):
                if C <= (32 * 1000) or C > (2048 * 1000):
                    # Single-kernel fused max+sum reduction.
                    softmax_forward_kernel[(TOTAL_CORE_NUM,)](
                        inp, final_max, final_sum, N, C, D
                    )
                else:
                    # Mid-sized C: split the reduction across cores using
                    # atomic max then atomic add.
                    grid = lambda meta: (
                        triton.cdiv(TOTAL_CORE_NUM, meta["C_TILE_NUM"]),
                        meta["C_TILE_NUM"],
                    )
                    max_kernel[grid](inp, final_max, N, C, D)
                    softmax_forward_with_max_kernel[grid](
                        inp, final_max, final_sum, N, C, D
                    )

                grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
                nllloss_without_weight_kernel[grid](
                    inp, tgt, final_max, final_sum, out, ignore_index, N, C, D
                )
            if reduction == 1:
                # NOTE(review): normalizing the mean by C (or C-1) looks
                # surprising — a mean over targets would normally divide by
                # the number of non-ignored elements. Confirm against the
                # reference torch implementation.
                if ignore_index < 0 or ignore_index >= C:
                    mean_num = 1 / C
                else:
                    mean_num = 1 / (C - 1)
            ctx.mean_num = mean_num

            ctx.save_for_backward(inp, tgt, weight, final_max, final_sum)
            return get_result(inp, tgt, out, reduction, mean_num)

        weight = weight.contiguous() if weight is not None else None
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)

        if tgt.ndim == dim:
            # target probabilities
            with torch_device_fn.device(inp.device):
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
        elif label_smoothing == 0:
            # target indices (weighted, no smoothing): reuse the fused
            # softmax reduction, then gather the weighted NLL per row.
            w_tgt = torch.zeros(shape, dtype=torch.float32, device=inp.device)
            final_max = torch.empty(shape, dtype=torch.float32, device=inp.device)
            final_sum = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                softmax_forward_kernel[(TOTAL_CORE_NUM,)](
                    inp, final_max, final_sum, N, C, D
                )
                nllloss_with_weight_kernel[(N,)](
                    inp,
                    tgt,
                    weight,
                    w_tgt,
                    final_max,
                    final_sum,
                    out,
                    ignore_index,
                    N,
                    C,
                    D,
                )
        else:
            # target indices with label smoothing: single fused kernel that
            # also records per-target weights for mean normalization.
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_smooth_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    label_smoothing,
                    C,
                    D,
                )
        ctx.save_for_backward(inp, tgt, weight, final_max, final_sum)
        # NOTE(review): for probability targets with mean reduction the
        # forward scales by mean_num = 1/(N*D) but ctx.mean_num stays 1, so
        # backward would not apply the same scale — verify this is intended.
        ctx.mean_num = 1

        if reduction == 1 and tgt.ndim != dim:
            # Mean over index targets: normalize by the total weight of the
            # non-ignored targets recorded by the kernel.
            mean_num = 1 / sum(w_tgt).item()
            ctx.mean_num = mean_num
        return get_result(inp, tgt, out, reduction, mean_num)

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS_CAMBRICON CrossEntropyLoss VJP")

        inp, tgt, weight, final_max, final_sum = ctx.saved_tensors
        N = ctx.N
        C = ctx.C
        D = ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        mean_num = ctx.mean_num
        shape = ctx.shape

        # A scalar (mean/sum) upstream gradient is expanded to one value per
        # (n, d) position; the kernels then scale by mean_num as needed.
        out_grad = out_grad.broadcast_to(shape).contiguous()

        # Zero-init so masked-out (ignored) positions keep zero gradient.
        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
        if tgt.ndim == inp.ndim:
            # Probability (soft) targets.
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
        elif label_smoothing == 0:
            # Index targets without smoothing: both forward paths for this
            # case saved final_max/final_sum, so the saved-sum kernel applies.
            if final_sum is not None:
                is_has_weight = weight is not None
                is_has_ignore_index = ignore_index >= 0 and ignore_index < C
                # int32 indices are sufficient whenever C fits in 31 bits.
                is_tgt_in_i32 = C < (1 << 31)
                grid = lambda meta: (
                    triton.cdiv(TOTAL_CORE_NUM, meta["C_TILE_NUM"]),
                    meta["C_TILE_NUM"],
                )
                celoss_indice_bwd_with_saved_sum_kernel[grid](
                    out_grad,
                    inp,
                    tgt,
                    weight,
                    inp_grad,
                    final_max,
                    final_sum,
                    ignore_index,
                    mean_num,
                    N,
                    C,
                    D,
                    is_has_weight,
                    is_has_ignore_index,
                    is_tgt_in_i32,
                )
        else:
            # Index targets with label smoothing.
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
        # Gradients only flow to `inp`; the remaining forward args get None.
        return inp_grad, None, None, None, None, None

1000 

1001 

def cross_entropy_loss(
    inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
):
    """Functional entry point mirroring ``torch.nn.functional.cross_entropy``.

    Translates the string ``reduction`` into torch's internal enum and
    delegates to the :class:`CrossEntropyLoss` autograd function.

    Args:
        inp: logits (or, when ``target`` has the same ndim, soft targets apply).
        target: class indices or per-class probabilities.
        weight: optional per-class rescaling weights.
        reduction: one of ``"none"``, ``"mean"``, ``"sum"``.
        ignore_index: target value excluded from loss and gradient.
        label_smoothing: smoothing factor in [0, 1).

    Returns:
        the (possibly reduced) loss tensor.
    """
    reduction_enum = _Reduction.get_enum(reduction)
    return CrossEntropyLoss.apply(
        inp,
        target,
        weight,
        reduction_enum,
        ignore_index,
        label_smoothing,
    )