Coverage for src/flag_gems/runtime/backend/_mthreads/fused/cross_entropy_loss.py: 0%

376 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch.nn import _reduction as _Reduction 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger(__name__) 

14 

15 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index"])
def celoss_indices_kernel(
    inp_ptr,  # logits, addressed as contiguous (N, C, D)
    tgt_ptr,  # integer class targets, addressed as (N, D)
    w_ptr,  # optional per-class weights (C,); may be None
    out_ptr,  # per-element loss output (N, D)
    w_tgt_ptr,  # per-element target-class weight (N, D), consumed by mean reduction
    ignore_index,  # target value whose elements contribute zero loss
    C,  # number of classes
    D,  # collapsed trailing (spatial) dimension
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy for integer targets (no label smoothing).

    One program handles BLOCK_D elements of batch row ``pid_n``; the class
    axis is reduced in chunks of BLOCK_C with a streaming (online)
    logsumexp, so the loss is logsumexp(logits) - logits[target], weighted.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + (pid_n * D).to(tl.int64) + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # In-bounds elements whose target is not ignore_index.
    # NOTE(review): relies on Triton's elementwise handling of `not`/`and`
    # on tensors — the same pattern is used throughout this file.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    # Online-softmax accumulators. NOTE(review): tmp_max starts at 0 rather
    # than -inf; logsumexp = final_sum + final_max is still mathematically
    # correct, but exp() may underflow for uniformly very negative logits.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        # Standard online update: rescale the running sum whenever the
        # running max grows.
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))

    # Logit of each element's own target class.
    inp_tgt_ptrs = inp_ptr + (pid_n * C * D).to(tl.int64) + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)

    # loss = logsumexp(logits) - logits[target]
    out = final_sum + final_max - inp_tgt
    w_tgt_ptrs = w_tgt_ptr + (pid_n * D).to(tl.int64) + offset_d

    # Target-class weight; ignored elements get weight 0, which zeroes
    # their loss via the multiply below.
    if w_ptr is None:
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
    out *= w_tgt
    out_ptrs = out_ptr + (pid_n * D).to(tl.int64) + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)

79 

80 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,  # logits, addressed as contiguous (N, C, D)
    tgt_ptr,  # target probabilities, same (N, C, D) layout as the logits
    w_ptr,  # optional per-class weights (C,); may be None
    out_ptr,  # per-element loss output (N, D)
    label_smoothing,  # mixed into the target distribution below
    C,  # number of classes
    D,  # collapsed trailing (spatial) dimension
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy for probability ("soft") targets.

    Pass 1 computes a streaming logsumexp over the class axis; pass 2
    accumulates sum_c (logsumexp - logit_c) * smoothed_target_c * w_c.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    # Pass 1: online logsumexp accumulators.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]

    # Pass 2: weighted negative log-likelihood against the smoothed targets.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        tgt_ptrs = tgt_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Smooth the target distribution toward uniform.
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        # log = -log p_c = logsumexp - logit_c
        log = final_sum + final_max - inp
        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[:, None]

    out = tl.sum(_sum, axis=0)
    out_ptrs = out_ptr + (pid_n * D).to(tl.int64) + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)

150 

151 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,  # logits, addressed as contiguous (N, C, D)
    tgt_ptr,  # integer class targets, addressed as (N, D)
    w_ptr,  # optional per-class weights (C,); may be None
    out_ptr,  # per-element loss output (N, D)
    w_tgt_ptr,  # per-element target-class weight (N, D), consumed by mean reduction
    ignore_index,  # target value whose elements contribute zero loss
    label_smoothing,  # smoothing factor applied to the implicit one-hot target
    C,  # number of classes
    D,  # collapsed trailing (spatial) dimension
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy for integer targets WITH label smoothing.

    The one-hot target is smoothed to (1 - s + s/C) on the target class and
    s/C elsewhere, then the weighted NLL is summed over classes like the
    probability kernel.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + (pid_n * D).to(tl.int64) + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # In-bounds elements whose target is not ignore_index (elementwise).
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    # Store the target-class weight (0 for ignored elements) for the
    # host-side mean reduction.
    if w_ptr is None:
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0)
    w_tgt_ptrs = w_tgt_ptr + (pid_n * D).to(tl.int64) + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)

    # Pass 1: online logsumexp over the class axis.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    final_sum_max = final_sum + final_max  # logsumexp

    # Pass 2: sum_c (logsumexp - logit_c) * smoothed_onehot_c * w_c.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot target distribution.
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)

        log = final_sum_max - inp
        _sum += log * smooth * w[:, None]

    out = tl.sum(_sum, axis=0)
    # Zero the loss of ignored elements explicitly (smoothing gives them a
    # nonzero contribution otherwise).
    out = tl.where(ignore_mask, out, 0)
    out_ptrs = out_ptr + (pid_n * D).to(tl.int64) + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)

239 

240 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indices_bwd(
    out_grad_ptr,  # upstream gradient, addressed as (N, D)
    inp_ptr,  # forward logits, addressed as contiguous (N, C, D)
    tgt_ptr,  # integer class targets, addressed as (N, D)
    w_ptr,  # optional per-class weights (C,); may be None
    inp_grad_ptr,  # gradient output (N, C, D); caller pre-zeros it
    ignore_index,  # target value whose elements receive no gradient
    mean_num,  # host-computed scale factor (1/denominator for mean reduction)
    C,  # number of classes
    D,  # collapsed trailing (spatial) dimension
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward pass for integer targets (no label smoothing).

    d(loss)/d(logit_c) = (softmax_c - onehot_c) * w[target] * out_grad
    * mean_num. Ignored elements are masked out of the store, leaving the
    caller's zero-initialized gradient in place.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + (pid_n * D).to(tl.int64) + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + (pid_n * D).to(tl.int64) + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    # Weight of each element's target class.
    if w_ptr is None:
        w_tgt = tgt_mask
    else:
        w_ptrs = w_ptr + tgt
        w_tgt = tl.load(w_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    ignore_mask = (tgt != ignore_index)[None, :]

    # Recompute the softmax denominator with an online logsumexp
    # (cheaper than saving it from the forward pass).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        # 1 exactly at the target class (softmax - onehot).
        minus_one = offset_c[:, None] == tgt[None, :]
        inp_grad = (
            (tl.exp(inp - final_max) / final_sum - minus_one)
            * w_tgt
            * out_grad
            * mean_num
        )
        inp_grad_ptrs = inp_grad_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        # Skip ignored elements: their gradient stays at the caller's zero.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)

320 

321 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,  # upstream gradient, addressed as (N, D)
    inp_ptr,  # forward logits, addressed as contiguous (N, C, D)
    tgt_ptr,  # target probabilities, same (N, C, D) layout as the logits
    w_ptr,  # optional per-class weights (C,); may be None
    inp_grad_ptr,  # gradient output (N, C, D)
    label_smoothing,  # same smoothing applied in the forward pass
    mean_num,  # host-computed scale factor (1/denominator for mean reduction)
    C,  # number of classes
    D,  # collapsed trailing (spatial) dimension
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward pass for probability ("soft") targets.

    d(loss)/d(logit_c) = (sum_k tgt_k*w_k) * softmax_c - tgt_c * w_c,
    scaled by out_grad * mean_num. Pass 1 recomputes the logsumexp and the
    weighted target sum; pass 2 writes the gradient.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    out_grad_ptrs = out_grad_ptr + (pid_n * D).to(tl.int64) + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[
        None, :
    ]

    # Pass 1 accumulators: online logsumexp plus sum_c(tgt_c * w_c).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)

        tgt_ptrs = tgt_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Same smoothing as the forward kernel.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        w_tgt_sum += tgt * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]

    # Pass 2: compute and store the gradient per class chunk.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_ptrs = inp_ptr + offset
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[:, None]
        inp_grad = grad * out_grad * mean_num

        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)

417 

418 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,  # upstream gradient, addressed as (N, D)
    inp_ptr,  # forward logits, addressed as contiguous (N, C, D)
    tgt_ptr,  # integer class targets, addressed as (N, D)
    w_ptr,  # optional per-class weights (C,); may be None
    inp_grad_ptr,  # gradient output (N, C, D); caller pre-zeros it
    ignore_index,  # target value whose elements receive no gradient
    label_smoothing,  # same smoothing applied in the forward pass
    mean_num,  # host-computed scale factor (1/denominator for mean reduction)
    C,  # number of classes
    D,  # collapsed trailing (spatial) dimension
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward pass for integer targets WITH label smoothing.

    Uses the smoothed one-hot as the target distribution:
    d(loss)/d(logit_c) = (sum_k smooth_k*w_k) * softmax_c - smooth_c * w_c,
    scaled by out_grad * mean_num; ignored elements are not stored.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + (pid_n * D).to(tl.int64) + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + (pid_n * D).to(tl.int64) + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    ignore_mask = (tgt != ignore_index)[None, :]

    # Pass 1 accumulators: online logsumexp plus sum_c(smooth_c * w_c).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot target distribution.
        smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )

        w_sum += smooth * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_sum = tl.sum(w_sum, axis=0)[None, :]

    # Pass 2: compute and store the gradient per class chunk.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )

        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad_ptrs = inp_grad_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        # Skip ignored elements: their gradient stays at the caller's zero.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)

521 

522 

@libentry()
@triton.jit
def sum_and_scale(
    inp_ptr,  # per-element losses to reduce (length N)
    out_ptr,  # scalar result: sum(inp) / divisor
    N,  # number of elements
    scalebyw: tl.constexpr,  # compile-time switch between the two divisor modes
    BLOCK_N: tl.constexpr = 128,
    # NOTE(review): `scale` is dual-typed — when scalebyw it is a POINTER to
    # a length-N weight tensor (see `scale + offset` below); otherwise it is
    # a plain scalar divisor.
    scale=1.0,
    mean_num=None,  # output pointer receiving the weight sum (scalebyw mode only)
):
    """Single-program reduction used for 'mean'/'sum' loss reduction.

    scalebyw=True:  out = sum(inp) / sum(scale[i]); the weight sum is also
                    written to *mean_num for reuse in the backward pass.
    scalebyw=False: out = sum(inp) / scale.
    """
    mid_sum = tl.zeros(
        [
            BLOCK_N,
        ],
        dtype=tl.float32,
    )
    if scalebyw:
        mid_wgt = tl.zeros(
            [
                BLOCK_N,
            ],
            dtype=tl.float32,
        )
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
            # `scale` acts as a pointer to the per-element weights here.
            wgt_ptrs = scale + offset
            wgt_vals = tl.load(wgt_ptrs, mask=mask, other=0.0)
            mid_wgt += wgt_vals
        out_val = tl.sum(mid_sum)
        scale_val = tl.sum(mid_wgt)
        tl.store(mean_num, scale_val)
    else:
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
        out_val = tl.sum(mid_sum)
        scale_val = scale
    out_val /= scale_val
    tl.store(out_ptr, out_val)

570 

571 

class CrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper dispatching to the Triton cross-entropy kernels.

    Chooses between the probability-target, plain index-target, and
    label-smoothed index-target kernels, and performs the optional
    mean/sum reduction on device.
    """

    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        """Compute the loss.

        ``reduction`` is torch's integer reduction enum (see the MEAN/SUM
        comments below: 1 = mean, 2 = sum; anything else = no reduction).
        """
        logger.debug("GEMS CrossEntropyLoss")

        # Normalize to a logical (N, C, D) view: the class axis is dim 0
        # for 1-D input and dim 1 otherwise; D collapses all trailing dims.
        shape = list(inp.shape)
        dim = inp.ndim
        N = 1 if dim == 1 else shape[0]
        C = shape[0] if dim == 1 else shape[1]
        D = inp.numel() // N // C
        axis = 0 if dim == 1 else 1
        del shape[axis]  # `shape` is now the per-element (unreduced) loss shape

        inp = inp.contiguous()
        tgt = target.contiguous()
        weight = weight.contiguous() if weight is not None else None
        out = torch.empty(shape, dtype=torch.float32, device=inp.device)
        # One program per BLOCK_D slice of D, one grid row per batch element.
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)

        if tgt.ndim == dim:
            # target probabilities
            with torch_device_fn.device(inp.device):
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
        elif label_smoothing == 0:
            # target indices
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    C,
                    D,
                )
        else:
            # target indices with label smoothing
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_smooth_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    label_smoothing,
                    C,
                    D,
                )

        if reduction == 1:  # MEAN
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            if tgt.ndim == dim:
                # Probability targets: mean over all N*D elements.
                sum_and_scale[(1,)](out, out_reduce, N * D, False, scale=N * D)
            else:
                # Index targets: mean is weighted by the per-element target
                # weights produced by the forward kernel; the weight sum is
                # also captured for the backward pass.
                wgt_sum = torch.empty([], dtype=torch.float32, device=inp.device)
                sum_and_scale[(1,)](
                    out, out_reduce, N * D, True, scale=w_tgt, mean_num=wgt_sum
                )
            out = out_reduce
        elif reduction == 2:  # SUM
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            sum_and_scale[(1,)](out, out_reduce, N * D, False)
            out = out_reduce

        # Stash everything backward() needs only when a gradient is wanted.
        if inp.requires_grad:
            ctx.save_for_backward(inp, tgt, weight)
            ctx.N = N
            ctx.C = C
            ctx.D = D
            ctx.ignore_index = ignore_index
            ctx.label_smoothing = label_smoothing
            ctx.shape = shape
            # Denominator of the mean reduction (1 when no mean was taken);
            # a device tensor in the weighted-index case.
            ctx.mean_num = 1
            if reduction == 1:
                ctx.mean_num = N * D if tgt.ndim == dim else wgt_sum

        return out.to(inp.dtype)

    @staticmethod
    def backward(ctx, out_grad):
        """Gradient w.r.t. ``inp`` only; all other inputs get None."""
        logger.debug("GEMS CrossEntropyLoss VJP")

        inp, tgt, weight = ctx.saved_tensors
        N = ctx.N
        C = ctx.C
        D = ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        # The kernels take the reciprocal of the mean denominator as a
        # multiplicative factor.
        mean_num = (
            1 / ctx.mean_num.item()
            if isinstance(ctx.mean_num, torch.Tensor)
            else 1 / ctx.mean_num
        )
        shape = ctx.shape

        # For reduced outputs out_grad is a scalar; broadcast it back to
        # the per-element loss shape.
        out_grad = out_grad.broadcast_to(shape).contiguous()

        # Zero-initialized: the index-target backward kernels skip stores
        # for ignored elements and rely on this zero fill.
        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
        if tgt.ndim == inp.ndim:
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
        elif label_smoothing == 0:
            celoss_indices_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, ignore_index, mean_num, C, D
            )
        else:
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
        return inp_grad, None, None, None, None, None

704 

705 

def cross_entropy_loss(
    inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
):
    """Functional entry point mirroring ``torch.nn.functional.cross_entropy``.

    Translates the string ``reduction`` into torch's integer reduction enum
    and defers all computation to the ``CrossEntropyLoss`` autograd Function.
    """
    reduction_enum = _Reduction.get_enum(reduction)
    return CrossEntropyLoss.apply(
        inp, target, weight, reduction_enum, ignore_index, label_smoothing
    )