Coverage for src/flag_gems/fused/cross_entropy_loss.py: 24%

376 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch.nn import _reduction as _Reduction 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger(__name__) 

14 

15 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index"])
def celoss_indices_kernel(
    inp_ptr,  # logits, viewed as (N, C, D), contiguous
    tgt_ptr,  # class-index targets, viewed as (N, D)
    w_ptr,    # optional per-class weights of length C, or None
    out_ptr,  # per-element loss output, (N, D), float32
    w_tgt_ptr,  # output: weight gathered at each target, (N, D)
    ignore_index,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Forward cross entropy for index targets without label smoothing.
    # Each program handles one batch row pid_n and a BLOCK_D slice of the
    # trailing dimension, reducing over classes C in BLOCK_C chunks.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # In-bounds positions whose target is not ignore_index.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    # Streaming (online) logsumexp over the class dimension: keep a running
    # max and a sum rescaled whenever the max grows.  NOTE(review): the
    # running max starts at 0 rather than -inf, so the reference point is
    # max(0, logits max); logsumexp itself is unaffected, barring overflow
    # for very negative logits.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        # Out-of-bounds lanes load -inf so they contribute exp(-inf)=0.
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))

    # Gather the logit at the target class for each d position.
    inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)

    # loss = logsumexp(inp) - inp[target]
    out = final_sum + final_max - inp_tgt
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d

    if w_ptr is None:
        # Unweighted: use the 0/1 ignore mask as the effective weight so
        # ignored positions zero out below.
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    # Persist per-position weights; the host uses them as the mean divisor.
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
    out *= w_tgt
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)

75 

76 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,  # logits, viewed as (N, C, D), contiguous
    tgt_ptr,  # target probabilities, same (N, C, D) layout as inp
    w_ptr,    # optional per-class weights of length C, or None
    out_ptr,  # per-element loss output, (N, D), float32
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Forward cross entropy for probability targets (soft labels), with
    # optional label smoothing folded into the target distribution.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    # Pass 1: streaming logsumexp over classes (see celoss_indices_kernel
    # for the rescaling recurrence; the running max also starts at 0 here).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]

    # Pass 2: accumulate sum_c  (logsumexp - inp_c) * smoothed_tgt_c * w_c,
    # i.e. the weighted negative log-likelihood against the soft target.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Mix target distribution with uniform 1/C per label smoothing.
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        # -log softmax = logsumexp - logit
        log = final_sum + final_max - inp
        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask  # unweighted: 0/1 validity mask acts as weight
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[:, None]

    out = tl.sum(_sum, axis=0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)

134 

135 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,  # logits, viewed as (N, C, D), contiguous
    tgt_ptr,  # class-index targets, viewed as (N, D)
    w_ptr,    # optional per-class weights of length C, or None
    out_ptr,  # per-element loss output, (N, D), float32
    w_tgt_ptr,  # output: weight gathered at each target, (N, D)
    ignore_index,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Forward cross entropy for index targets WITH label smoothing: the
    # one-hot target becomes (1 - s + s/C) at the target class and s/C
    # elsewhere, so the loss needs a full reduction over classes.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # In-bounds positions whose target is not ignore_index.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    if w_ptr is None:
        w_tgt = ignore_mask  # unweighted: 0/1 mask doubles as weight
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0)
    # Persist per-position target weights for the host-side mean divisor.
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)

    # Pass 1: streaming logsumexp over classes (running max starts at 0;
    # see celoss_indices_kernel).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    final_sum_max = final_sum + final_max  # logsumexp, broadcast over C

    # Pass 2: sum_c (logsumexp - inp_c) * smooth_c * w_c.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot target value per class.
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)

        log = final_sum_max - inp
        _sum += log * smooth * w[:, None]

    out = tl.sum(_sum, axis=0)
    # Ignored positions produce 0 loss.
    out = tl.where(ignore_mask, out, 0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)

215 

216 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indices_bwd(
    out_grad_ptr,  # upstream grad, (N, D)
    inp_ptr,       # saved logits, (N, C, D)
    tgt_ptr,       # class-index targets, (N, D)
    w_ptr,         # optional per-class weights, or None
    inp_grad_ptr,  # output: grad wrt logits, (N, C, D)
    ignore_index,
    mean_num,      # 1/divisor for mean reduction (1 otherwise)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward for index targets, no smoothing:
    # d loss / d inp_c = (softmax_c - [c == tgt]) * w[tgt] * out_grad * mean_num.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    if w_ptr is None:
        w_tgt = tgt_mask  # unweighted: 0/1 validity mask
    else:
        w_ptrs = w_ptr + tgt
        w_tgt = tl.load(w_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    # Positions to actually write; ignored targets keep their zero grad
    # (inp_grad is zero-initialized on the host).
    ignore_mask = (tgt != ignore_index)[None, :]

    # Pass 1: streaming softmax denominator (running max starts at 0; see
    # celoss_indices_kernel).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]

    # Pass 2: emit (softmax - onehot) * w[tgt] * out_grad * mean_num.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        minus_one = offset_c[:, None] == tgt[None, :]  # one-hot indicator
        inp_grad = (
            (tl.exp(inp - final_max) / final_sum - minus_one)
            * w_tgt
            * out_grad
            * mean_num
        )
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)

286 

287 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,  # upstream grad, (N, D)
    inp_ptr,       # saved logits, (N, C, D)
    tgt_ptr,       # target probabilities, (N, C, D)
    w_ptr,         # optional per-class weights, or None
    inp_grad_ptr,  # output: grad wrt logits, (N, C, D)
    label_smoothing,
    mean_num,      # 1/divisor for mean reduction (1 otherwise)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward for probability (soft) targets with label smoothing:
    # d loss / d inp_c = (sum_k w_k*tgt_k) * softmax_c - w_c * tgt_c,
    # scaled by out_grad * mean_num.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[
        None, :
    ]

    # Pass 1: streaming softmax denominator and the weighted-target total
    # sum_k w_k * smoothed_tgt_k, accumulated in one sweep over C.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)

        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Apply label smoothing to the target distribution.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask  # unweighted: 0/1 validity mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        w_tgt_sum += tgt * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]

    # Pass 2: emit the gradient per class chunk.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_ptrs = inp_ptr + offset
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[:, None]
        inp_grad = grad * out_grad * mean_num

        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)

371 

372 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,  # upstream grad, (N, D)
    inp_ptr,       # saved logits, (N, C, D)
    tgt_ptr,       # class-index targets, (N, D)
    w_ptr,         # optional per-class weights, or None
    inp_grad_ptr,  # output: grad wrt logits, (N, C, D)
    ignore_index,
    label_smoothing,
    mean_num,      # 1/divisor for mean reduction (1 otherwise)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward for index targets WITH label smoothing:
    # d loss / d inp_c = (sum_k w_k*smooth_k) * softmax_c - w_c * smooth_c,
    # scaled by out_grad * mean_num; ignored targets keep zero grad.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    # inp_grad is zero-initialized host-side; masked stores skip ignored
    # positions so their grad stays 0.
    ignore_mask = (tgt != ignore_index)[None, :]

    # Pass 1: streaming softmax denominator plus sum_k w_k * smooth_k
    # (running max starts at 0; see celoss_indices_kernel).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask  # unweighted: 0/1 validity mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot target value per class.
        smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )

        w_sum += smooth * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_sum = tl.sum(w_sum, axis=0)[None, :]

    # Pass 2: emit the gradient per class chunk.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )

        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)

465 

466 

@libentry()
@triton.jit
def sum_and_scale(
    inp_ptr,   # values to reduce, length N
    out_ptr,   # scalar output: sum(inp) / divisor
    N,
    scalebyw: tl.constexpr,  # True: divisor = sum of weights read via `scale`
    BLOCK_N: tl.constexpr = 128,
    scale=1.0,      # scalar divisor, OR a weight pointer when scalebyw is True
    mean_num=None,  # output pointer for the weight sum (scalebyw=True only)
):
    # Single-program reduction used for "mean"/"sum" loss reductions.
    # NOTE(review): `scale` is overloaded — a plain number when
    # scalebyw=False, a pointer to per-position weights when scalebyw=True.
    mid_sum = tl.zeros(
        [
            BLOCK_N,
        ],
        dtype=tl.float32,
    )
    if scalebyw:
        # Accumulate both the values and their weights in lockstep.
        mid_wgt = tl.zeros(
            [
                BLOCK_N,
            ],
            dtype=tl.float32,
        )
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
            wgt_ptrs = scale + offset  # `scale` used as a pointer here
            wgt_vals = tl.load(wgt_ptrs, mask=mask, other=0.0)
            mid_wgt += wgt_vals
        out_val = tl.sum(mid_sum)
        scale_val = tl.sum(mid_wgt)
        # Expose the weight total so the host can reuse it in backward.
        tl.store(mean_num, scale_val)
    else:
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
        out_val = tl.sum(mid_sum)
        scale_val = scale
    out_val /= scale_val
    tl.store(out_ptr, out_val)

514 

515 

class CrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper dispatching to the Triton cross-entropy kernels.

    The input is interpreted as (N, C, D) — batch, classes, trailing dims —
    and the kernel is chosen by target kind (probabilities vs indices) and
    whether label smoothing is active.
    """

    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        """Compute the loss.

        reduction is the torch reduction enum (0=none, 1=mean, 2=sum per the
        branches below).  Returns a tensor shaped like target (for indices)
        or a scalar under mean/sum reduction, cast back to inp.dtype.
        """
        logger.debug("GEMS CrossEntropyLoss")

        # Factor the input shape into (N, C, D); a 1-D input is a single
        # sample over C classes.
        shape = list(inp.shape)
        dim = inp.ndim
        N = 1 if dim == 1 else shape[0]
        C = shape[0] if dim == 1 else shape[1]
        D = inp.numel() // N // C
        axis = 0 if dim == 1 else 1
        del shape[axis]  # per-element output shape: input shape minus class axis

        inp = inp.contiguous()
        tgt = target.contiguous()
        weight = weight.contiguous() if weight is not None else None
        out = torch.empty(shape, dtype=torch.float32, device=inp.device)
        # 2-D launch grid: (ceil(D / BLOCK_D), N).
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)

        if tgt.ndim == dim:
            # target probabilities
            with torch_device_fn.device(inp.device):
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
        elif label_smoothing == 0:
            # target indices
            # w_tgt receives the per-position target weight (mean divisor).
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    C,
                    D,
                )
        else:
            # target indices with label smoothing
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_smooth_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    label_smoothing,
                    C,
                    D,
                )

        if reduction == 1:  # MEAN
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            if tgt.ndim == dim:
                # Probability targets: divide by the element count.
                sum_and_scale[(1,)](out, out_reduce, N * D, False, scale=N * D)
            else:
                # Index targets: divide by the summed target weights, which
                # sum_and_scale also writes into wgt_sum for backward.
                wgt_sum = torch.empty([], dtype=torch.float32, device=inp.device)
                sum_and_scale[(1,)](
                    out, out_reduce, N * D, True, scale=w_tgt, mean_num=wgt_sum
                )
            out = out_reduce
        elif reduction == 2:  # SUM
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            sum_and_scale[(1,)](out, out_reduce, N * D, False)
            out = out_reduce

        if inp.requires_grad:
            ctx.save_for_backward(inp, tgt, weight)
            ctx.N = N
            ctx.C = C
            ctx.D = D
            ctx.ignore_index = ignore_index
            ctx.label_smoothing = label_smoothing
            ctx.shape = shape
            # Divisor applied in backward; 1 unless mean reduction was used.
            ctx.mean_num = 1
            if reduction == 1:
                ctx.mean_num = N * D if tgt.ndim == dim else wgt_sum

        return out.to(inp.dtype)

    @staticmethod
    def backward(ctx, out_grad):
        """Gradient wrt inp only; all other forward args get None."""
        logger.debug("GEMS CrossEntropyLoss VJP")

        inp, tgt, weight = ctx.saved_tensors
        N = ctx.N
        C = ctx.C
        D = ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        # Kernels multiply by mean_num, so pass the reciprocal of the saved
        # divisor (tensor when it came from sum_and_scale, int otherwise).
        mean_num = (
            1 / ctx.mean_num.item()
            if isinstance(ctx.mean_num, torch.Tensor)
            else 1 / ctx.mean_num
        )
        shape = ctx.shape

        # A scalar upstream grad (mean/sum reduction) broadcasts back to the
        # per-element shape.
        out_grad = out_grad.broadcast_to(shape).contiguous()

        # Zero-init so masked-out (ignored) positions keep zero gradient.
        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
        if tgt.ndim == inp.ndim:
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
        elif label_smoothing == 0:
            celoss_indices_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, ignore_index, mean_num, C, D
            )
        else:
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
        return inp_grad, None, None, None, None, None

648 

649 

def cross_entropy_loss(
    inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
):
    """Cross entropy loss over logits `inp` and `target` (class indices or
    probabilities), with optional per-class `weight`, `ignore_index` and
    `label_smoothing`.  `reduction` is a torch reduction name (e.g. "none",
    "mean", "sum"), converted to its enum before dispatch.
    """
    reduction_enum = _Reduction.get_enum(reduction)
    args = (inp, target, weight, reduction_enum, ignore_index, label_smoothing)
    return CrossEntropyLoss.apply(*args)