Coverage for src/flag_gems/runtime/backend/_ascend/fused/cross_entropy_loss.py: 0%

380 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch.nn import _reduction as _Reduction 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger(__name__) 

14 

15 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index"])
def celoss_indices_kernel(
    inp_ptr,  # logits, laid out as (N, C, D) contiguous
    tgt_ptr,  # class-index targets, laid out as (N, D)
    w_ptr,  # optional per-class weights of length C (may be None)
    out_ptr,  # per-element loss output, (N, D)
    w_tgt_ptr,  # per-element weight output (weight of each target class), (N, D)
    ignore_index,  # target value whose elements contribute zero loss/weight
    C,  # number of classes
    D,  # trailing spatial extent (product of dims after the class dim)
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy for integer (index) targets, no label smoothing.

    Grid: axis 0 tiles D, axis 1 iterates the batch N.  Each program computes
    log-sum-exp over the class dimension with a streaming (online-softmax)
    pass, then loss = logsumexp - logit[target], scaled by the target's
    weight and zeroed where target == ignore_index.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # True for in-bounds elements whose target is NOT the ignore_index.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    # Online softmax accumulators: running max and running scaled exp-sum.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        # Rescale the running sum whenever the running max grows, so that
        # exp() never overflows (standard online-softmax recurrence).
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    # Collapse the per-row partial accumulators into one max/sum per column.
    final_max = tl.max(tmp_max, axis=0)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))

    # Gather the logit at the target class for each element.
    inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)

    # loss = logsumexp(logits) - logit[target]
    out = final_sum + final_max - inp_tgt
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d

    if w_ptr is None:
        # Unweighted case: weight is 1 where valid, 0 where ignored.
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=tgt_mask, other=0).to(tl.float32)
        w_tgt = tl.where(ignore_mask, w_tgt, 0)

    # Store per-element weights so the host can normalize a mean reduction.
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
    out *= w_tgt
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)

76 

77 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,  # logits, (N, C, D) contiguous
    tgt_ptr,  # target class probabilities, same (N, C, D) layout as logits
    w_ptr,  # optional per-class weights of length C (may be None)
    out_ptr,  # per-element loss output, (N, D)
    label_smoothing,  # smoothing factor in [0, 1)
    C,  # number of classes
    D,  # trailing spatial extent
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy for probability (soft) targets.

    Two passes over the class dimension: pass 1 computes logsumexp with the
    online-softmax recurrence; pass 2 accumulates
    sum_c (logsumexp - logit_c) * smoothed_target_c * weight_c.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    # Online softmax accumulators (pass 1).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]

    # Pass 2: weighted negative log-likelihood against smoothed targets.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Blend target distribution toward uniform per label smoothing.
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        # -log softmax = logsumexp - logit
        log = final_sum + final_max - inp
        w_mask = offset_c < C
        if w_ptr is None:
            # Unweighted: a boolean mask acts as a 0/1 weight vector.
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[:, None]

    out = tl.sum(_sum, axis=0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)

135 

136 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,  # logits, (N, C, D) contiguous
    tgt_ptr,  # class-index targets, (N, D)
    w_ptr,  # optional per-class weights of length C (may be None)
    out_ptr,  # per-element loss output, (N, D)
    w_tgt_ptr,  # per-element target-class weight output, (N, D)
    ignore_index,  # target value whose elements contribute zero loss/weight
    label_smoothing,  # smoothing factor in (0, 1)
    C,  # number of classes
    D,  # trailing spatial extent
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy for index targets WITH label smoothing.

    The index target is expanded on the fly into the smoothed distribution
    (1 - s + s/C on the target class, s/C elsewhere) and the loss is the
    weighted sum of -log softmax over all classes, zeroed at ignore_index.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # True for in-bounds elements whose target is NOT the ignore_index.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    if w_ptr is None:
        # Unweighted case: weight is 1 where valid, 0 where ignored.
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=tgt_mask, other=0)
        w_tgt = tl.where(ignore_mask, w_tgt, 0)
    # Store per-element weights so the host can normalize a mean reduction.
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)

    # Online softmax accumulators (pass 1: logsumexp over classes).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    final_sum_max = final_sum + final_max

    # Pass 2: accumulate weighted -log softmax against the smoothed one-hot.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            # Unweighted: a boolean mask acts as a 0/1 weight vector.
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot: peak on the target class, s/C everywhere else.
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)

        log = final_sum_max - inp
        _sum += log * smooth * w[:, None]

    out = tl.sum(_sum, axis=0)
    # Zero out elements whose target is the ignore_index.
    out = tl.where(ignore_mask, out, 0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)

217 

218 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indices_bwd(
    out_grad_ptr,  # upstream gradient, (N, D)
    inp_ptr,  # forward logits, (N, C, D) contiguous
    tgt_ptr,  # class-index targets, (N, D)
    w_ptr,  # optional per-class weights of length C (may be None)
    inp_grad_ptr,  # gradient output, same layout as logits
    ignore_index,  # target value whose elements get zero gradient
    mean_num,  # reciprocal of the mean normalizer (1.0 when not mean-reduced)
    C,  # number of classes
    D,  # trailing spatial extent
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward cross-entropy for index targets, no label smoothing.

    Recomputes softmax from the saved logits (online-softmax pass) and
    emits grad = (softmax - onehot(target)) * weight[target] * out_grad
    * mean_num, zeroed where target == ignore_index.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    if w_ptr is None:
        # Unweighted: the bounds mask serves as a 0/1 weight.
        w_tgt = tgt_mask
    else:
        w_ptrs = w_ptr + tgt
        w_tgt = tl.load(w_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    ignore_mask = (tgt != ignore_index)[None, :]

    # Pass 1: online softmax to recover max and exp-sum of the logits.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]

    # Pass 2: emit (softmax - onehot) scaled by weight, upstream grad and
    # the mean normalizer.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        # 1 exactly at the target class (the "-1" of softmax-minus-onehot).
        minus_one = offset_c[:, None] == tgt[None, :]
        inp_grad = (
            (tl.exp(inp - final_max) / final_sum - minus_one)
            * w_tgt
            * out_grad
            * mean_num
        )
        inp_grad = tl.where(ignore_mask, inp_grad, 0.0)
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask)

289 

290 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,  # upstream gradient, (N, D)
    inp_ptr,  # forward logits, (N, C, D) contiguous
    tgt_ptr,  # target class probabilities, (N, C, D)
    w_ptr,  # optional per-class weights of length C (may be None)
    inp_grad_ptr,  # gradient output, same layout as logits
    label_smoothing,  # smoothing factor in [0, 1)
    mean_num,  # reciprocal of the mean normalizer (1.0 when not mean-reduced)
    C,  # number of classes
    D,  # trailing spatial extent
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward cross-entropy for probability (soft) targets.

    Pass 1 recomputes logsumexp (online softmax) and accumulates the total
    smoothed-target weight per element; pass 2 emits
    grad = (sum_w_tgt * softmax - smoothed_target * weight) * out_grad * mean_num.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[
        None, :
    ]

    # Online softmax accumulators plus the weighted-target running sum.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)

        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Blend target distribution toward uniform per label smoothing.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            # Unweighted: the bounds mask serves as a 0/1 weight vector.
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        w_tgt_sum += tgt * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]

    # Pass 2: emit the gradient block by block.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_ptrs = inp_ptr + offset
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # d/d_logit of sum_c w_c * t_c * (logsumexp - logit_c)
        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[:, None]
        inp_grad = grad * out_grad * mean_num

        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)

374 

375 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,  # upstream gradient, (N, D)
    inp_ptr,  # forward logits, (N, C, D) contiguous
    tgt_ptr,  # class-index targets, (N, D)
    w_ptr,  # optional per-class weights of length C (may be None)
    inp_grad_ptr,  # gradient output, same layout as logits
    ignore_index,  # target value whose elements get zero gradient
    label_smoothing,  # smoothing factor in (0, 1)
    mean_num,  # reciprocal of the mean normalizer (1.0 when not mean-reduced)
    C,  # number of classes
    D,  # trailing spatial extent
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward cross-entropy for index targets WITH label smoothing.

    Expands the index target into the smoothed one-hot distribution and
    emits grad = (w_sum * softmax - smoothed * weight) * out_grad * mean_num,
    zeroed where target == ignore_index.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    ignore_mask = (tgt != ignore_index)[None, :]

    # Online softmax accumulators plus the smoothed-weight running sum.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            # Unweighted: the bounds mask serves as a 0/1 weight vector.
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot: peak on the target class, s/C everywhere else.
        smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )

        w_sum += smooth * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_sum = tl.sum(w_sum, axis=0)[None, :]

    # Pass 2: emit the gradient block by block.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )

        # d/d_logit of sum_c w_c * smooth_c * (logsumexp - logit_c)
        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad = tl.where(ignore_mask, inp_grad, 0.0)
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask)

469 

470 

@libentry()
@triton.jit
def sum_and_scale(
    inp_ptr,  # per-element losses to reduce, length N
    out_ptr,  # scalar output: sum(inp) / scale
    N,  # number of elements
    scalebyw: tl.constexpr,  # if True, `scale` is a POINTER to per-element weights
    BLOCK_N: tl.constexpr = 128,
    scale=1.0,  # divisor (scalar) OR weight pointer, depending on scalebyw
    mean_num=None,  # when scalebyw: pointer receiving the summed weight
):
    """Single-program reduction: out = sum(inp) / scale.

    Two modes selected at compile time by `scalebyw`:
      * scalebyw=True  — `scale` is a pointer to a weight tensor of length N;
        the kernel sums both losses and weights, divides by the weight sum,
        and stores the weight sum to `mean_num` (needed later by backward).
      * scalebyw=False — `scale` is a plain scalar divisor (e.g. N for mean,
        1.0 for sum reduction).

    NOTE(review): in the weighted path a weight sum of 0 (all targets
    ignored) would divide by zero — presumably callers guarantee at least
    one valid target; confirm against the host-side usage.
    """
    # Block-wide partial sum accumulator.
    mid_sum = tl.zeros(
        [
            BLOCK_N,
        ],
        dtype=tl.float32,
    )
    if scalebyw:
        mid_wgt = tl.zeros(
            [
                BLOCK_N,
            ],
            dtype=tl.float32,
        )
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
            # `scale` is a pointer here: accumulate the matching weights.
            wgt_ptrs = scale + offset
            wgt_vals = tl.load(wgt_ptrs, mask=mask, other=0.0)
            mid_wgt += wgt_vals
        out_val = tl.sum(mid_sum)
        scale_val = tl.sum(mid_wgt)
        # Persist the weight sum so backward can normalize by it.
        tl.store(mean_num, scale_val)
    else:
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
        out_val = tl.sum(mid_sum)
        scale_val = scale
    out_val /= scale_val
    tl.store(out_ptr, out_val)

518 

519 

class CrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper dispatching to the Triton cross-entropy kernels.

    Supports both integer (index) targets and probability (soft) targets,
    optional per-class weights, ignore_index, label smoothing, and the
    none/mean/sum reductions (passed as torch's integer reduction enum).
    """

    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        logger.debug("GEMS_ASCEND CrossEntropyLoss")

        # Canonicalize input as (N, C, D): batch, classes, flattened trailing
        # dims.  A 1-D input is a single sample (N == 1, class dim is dim 0).
        shape = list(inp.shape)
        dim = inp.ndim
        N = 1 if dim == 1 else shape[0]
        C = shape[0] if dim == 1 else shape[1]
        D = inp.numel() // N // C
        axis = 0 if dim == 1 else 1
        # `shape` becomes the per-element output shape (class dim removed).
        del shape[axis]

        inp = inp.contiguous()
        tgt = target.contiguous()
        weight = weight.contiguous() if weight is not None else None
        # Per-element losses in fp32; cast back to inp.dtype on return.
        out = torch.empty(shape, dtype=torch.float32, device=inp.device)
        # Grid: axis 0 tiles D (BLOCK_D chosen by autotune), axis 1 is N.
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)

        if tgt.ndim == dim:
            # target probabilities
            with torch_device_fn.device(inp.device):
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
        elif label_smoothing == 0:
            # target indices
            # w_tgt captures each element's target-class weight; needed to
            # normalize the mean reduction below.
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    C,
                    D,
                )
        else:
            # target indices with label smoothing
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_smooth_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    label_smoothing,
                    C,
                    D,
                )

        if reduction == 1:  # MEAN
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            if tgt.ndim == dim:
                # Probability targets: mean over all N*D elements.
                sum_and_scale[(1,)](out, out_reduce, N * D, False, scale=N * D)
            else:
                # Index targets: mean is normalized by the summed target
                # weights (wgt_sum), which also feeds backward's mean_num.
                wgt_sum = torch.empty([], dtype=torch.float32, device=inp.device)
                sum_and_scale[(1,)](
                    out, out_reduce, N * D, True, scale=w_tgt, mean_num=wgt_sum
                )
            out = out_reduce
        elif reduction == 2:  # SUM
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            sum_and_scale[(1,)](out, out_reduce, N * D, False)
            out = out_reduce

        if inp.requires_grad:
            ctx.save_for_backward(inp, tgt, weight)
            ctx.N = N
            ctx.C = C
            ctx.D = D
            ctx.ignore_index = ignore_index
            ctx.label_smoothing = label_smoothing
            ctx.shape = shape
            # mean_num is the normalizer backward divides by: 1 for
            # none/sum; N*D or the device-side weight sum for mean.
            ctx.mean_num = 1
            if reduction == 1:
                ctx.mean_num = N * D if tgt.ndim == dim else wgt_sum

        return out.to(inp.dtype)

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS_ASCEND CrossEntropyLoss VJP")

        inp, tgt, weight = ctx.saved_tensors
        N = ctx.N
        C = ctx.C
        D = ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        # Kernels multiply by mean_num, so pass the reciprocal of the
        # normalizer (wgt_sum is a 0-d tensor -> .item() syncs the device).
        mean_num = (
            1 / ctx.mean_num.item()
            if isinstance(ctx.mean_num, torch.Tensor)
            else 1 / ctx.mean_num
        )
        shape = ctx.shape

        # For mean/sum reductions out_grad is a scalar; expand it to the
        # per-element shape the kernels index by.
        out_grad = out_grad.broadcast_to(shape).contiguous()

        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
        if tgt.ndim == inp.ndim:
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
        elif label_smoothing == 0:
            celoss_indices_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, ignore_index, mean_num, C, D
            )
        else:
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
        # One gradient per forward arg; only `inp` is differentiable.
        return inp_grad, None, None, None, None, None

652 

653 

def cross_entropy_loss(
    inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
):
    """Functional entry point mirroring ``torch.nn.functional.cross_entropy``.

    Converts the string ``reduction`` ("none" | "mean" | "sum") into torch's
    integer reduction enum, then defers to the custom autograd function.
    """
    reduction_enum = _Reduction.get_enum(reduction)
    return CrossEntropyLoss.apply(
        inp, target, weight, reduction_enum, ignore_index, label_smoothing
    )