Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/cross_entropy_loss.py: 0%

431 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14 

15 

def heur_block_c(args):
    """Heuristic for BLOCK_C: split the class dimension C into ~12 chunks,
    round the chunk size up to a power of two, and never go below 64."""
    chunk = triton.cdiv(args["C"], 12)
    block = triton.next_power_of_2(chunk)
    return max(block, 64)

20 

21 

def heur_block_d(args):
    """Heuristic for BLOCK_D: split the inner (D) dimension into ~12 chunks."""
    d = args["D"]
    return triton.cdiv(d, 12)

25 

26 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["ignore_index"])
def celoss_indices_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    w_tgt_ptr,
    ignore_index,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy loss with class-index targets (no smoothing).

    Logits are addressed as (N, D, C) row-major (offset = n*C*D + d*C + c),
    so the class dimension is innermost.  For each (n, d) element this stores

        out[n, d]   = w_tgt[n, d] * (logsumexp(inp[n, d, :]) - inp[n, d, tgt])
        w_tgt[n, d] = class weight of tgt (or 1 if w_ptr is None),
                      forced to 0 where tgt == ignore_index

    w_tgt is consumed later by the host for the 'mean' reduction.
    """
    pid_d = tle.program_id(0)  # block index over the D dimension
    pid_n = tle.program_id(1)  # batch index
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # In-range AND not-ignored.  NOTE: `not`/`and` act elementwise on
    # Triton tensors in this dialect.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    # Online (streaming) softmax accumulators over the class dimension.
    # They start at 0, so the shift used is max(0, max logit) -- still a
    # valid logsumexp shift, merely less tight for all-negative logits.
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        # Out-of-range lanes read -inf so exp() contributes exactly 0.
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        # Rescale the running sum to the new running max before adding.
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[:, None])
    final_sum = tl.log(tl.sum(tmp_sum, axis=1))

    # Gather the target-class logit for each (n, d); tgt is the class offset.
    inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt + offset_d * C
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)

    # loss = logsumexp(logits) - logit[target]
    out = final_sum + final_max - inp_tgt
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d

    if w_ptr is None:
        # Unweighted: weight 1 for kept elements, 0 for ignored ones.
        w_tgt = ignore_mask
    else:
        # masked by ignore_mask so ignored elements get weight 0.
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
    out *= w_tgt  # zeroes the loss at ignored positions
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)

92 

93 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy loss with probability (soft) targets.

    Both logits and targets are addressed as (N, D, C) row-major.  With the
    smoothed target  t'[c] = tgt[c] * (1 - label_smoothing) + label_smoothing/C
    this stores

        out[n, d] = sum_c w[c] * t'[c] * (logsumexp(inp[n, d, :]) - inp[n, d, c])

    where w[c] is 1 for every in-range class when w_ptr is None.
    """
    pid_d = tle.program_id(0)  # block index over the D dimension
    pid_n = tle.program_id(1)  # batch index
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    # Pass 1: online softmax statistics (running max + rescaled exp-sum).
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[:, None])
    final_sum = tl.log(tl.sum(tmp_sum, axis=1))

    # Pass 2: accumulate the weighted negative log-likelihood per class.
    _sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Label smoothing mixes the target with a uniform distribution.
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        # -log softmax = logsumexp - logit
        log = final_sum[:, None] + final_max[:, None] - inp
        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[None, :]

    out = tl.sum(_sum, axis=1)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)

157 

158 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    w_tgt_ptr,
    ignore_index,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy loss with index targets AND label smoothing.

    The index target is expanded on the fly into the smoothed distribution

        smooth[c] = 1 - label_smoothing + label_smoothing/C   if c == tgt
                    label_smoothing/C                          otherwise

    and the loss is  out[n, d] = sum_c w[c] * smooth[c] * (logsumexp - inp[c]),
    forced to 0 where tgt == ignore_index.  Also stores the per-element
    target-class weight w_tgt (0 at ignored positions) for 'mean' reduction.
    """
    pid_d = tle.program_id(0)  # block index over the D dimension
    pid_n = tle.program_id(1)  # batch index
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # In-range AND not-ignored (elementwise `not`/`and` on Triton tensors).
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    if w_ptr is None:
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0)
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)

    # Pass 1: online softmax statistics over the class dimension.
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)[:, None]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=1))[:, None]
    # logsumexp, kept as a column for broadcasting in pass 2.
    final_sum_max = final_sum + final_max

    # Pass 2: weighted sum of -log softmax against the smoothed one-hot.
    _sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)

        smooth = tl.where(
            offset_c[None, :] == tgt[:, None],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)

        log = final_sum_max - inp
        _sum += log * smooth * w[None, :]

    out = tl.sum(_sum, axis=1)
    # Zero the loss at ignored positions.
    out = tl.where(ignore_mask, out, 0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)

244 

245 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indices_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    ignore_index,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward pass for index-target cross-entropy (no smoothing).

    Recomputes softmax statistics and stores, for each in-range non-ignored
    element,

        inp_grad[n, d, c] = (softmax(inp)[c] - 1[c == tgt])
                            * w[tgt] * out_grad[n, d] * mean_num

    mean_num is 1/sum(weights) for 'mean' reduction, 1 otherwise (scalar).
    Ignored positions are skipped by the store mask, leaving the buffer's
    prior contents (the host pre-zeroes it) in place.
    """
    pid_d = tle.program_id(0)  # block index over the D dimension
    pid_n = tle.program_id(1)  # batch index
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)

    # In-range AND not-ignored (elementwise `not`/`and` on Triton tensors).
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    if w_ptr is None:
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    # Pass 1: online softmax statistics over the class dimension.
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[:, None])
    # Denominator of softmax (linear scale, relative to final_max).
    final_sum = tl.sum(tmp_sum, axis=1)

    # Pass 2: write the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        # One-hot indicator of the target class.
        minus_one = (offset_c[None, :] == tgt[:, None]).to(tl.float32)
        inp_grad = (
            (tl.exp(inp - final_max[:, None]) / final_sum[:, None] - minus_one)
            * w_tgt[:, None]
            * out_grad[:, None]
            * mean_num
        )
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        )
        # Masked store also suppresses writes for ignored elements.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask[:, None])

320 

321 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    label_smoothing,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward pass for probability-target cross-entropy.

    With smoothed target  t'[c] = tgt[c]*(1 - label_smoothing) + label_smoothing/C
    and  S = sum_c w[c] * t'[c],  stores

        inp_grad[n, d, c] = (S * softmax(inp)[c] - t'[c] * w[c])
                            * out_grad[n, d] * mean_num
    """
    pid_d = tle.program_id(0)  # block index over the D dimension
    pid_n = tle.program_id(1)  # batch index
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)

    # Pass 1: online softmax statistics + the weighted target mass S.
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)

        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Label smoothing mixes the target with a uniform distribution.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        w_tgt_sum += tgt * w[None, :]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)[:, None]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    # Softmax denominator (linear scale) and S, both as columns.
    final_sum = tl.sum(tmp_sum, axis=1)[:, None]
    w_tgt_sum = tl.sum(w_tgt_sum, axis=1)[:, None]

    # Pass 2: write the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_ptrs = inp_ptr + offset
        mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[None, :]
        inp_grad = grad * out_grad[:, None] * mean_num

        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)

409 

410 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    ignore_index,
    label_smoothing,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward pass for index-target cross-entropy WITH label smoothing.

    The index target expands to  smooth[c] (see the forward kernel); with
    S = sum_c w[c] * smooth[c]  this stores

        inp_grad[n, d, c] = (S * softmax(inp)[c] - smooth[c] * w[c])
                            * out_grad[n, d] * mean_num

    Writes are suppressed (store mask) where tgt == ignore_index.
    """
    pid_d = tle.program_id(0)  # block index over the D dimension
    pid_n = tle.program_id(1)  # batch index
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)

    # Column mask: True where the element is not ignored.
    ignore_mask = (tgt != ignore_index)[:, None]

    # Pass 1: online softmax statistics + weighted smoothed-target mass S.
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        smooth = tl.full([BLOCK_D, BLOCK_C], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[None, :] == tgt[:, None],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )

        w_sum += smooth * w[None, :]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)[:, None]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    # Softmax denominator (linear scale) and S, both as columns.
    final_sum = tl.sum(tmp_sum, axis=1)[:, None]
    w_sum = tl.sum(w_sum, axis=1)[:, None]

    # Pass 2: write the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        smooth = tl.where(
            offset_c[None, :] == tgt[:, None],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )

        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[None, :]
        inp_grad = grad * out_grad[:, None] * mean_num
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        )
        # Masked store also suppresses writes for ignored elements.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)

509 

510 

@libentry()
@triton.jit
def sum_and_scale(
    inp_ptr,
    out_ptr,
    N,
    scalebyw: tl.constexpr,
    BLOCK_N: tl.constexpr = 128,
    scale=1.0,
    mean_num=None,
):
    """Single-program reduction: out = sum(inp[0:N]) / divisor.

    Two modes, selected at compile time by `scalebyw`:
      * scalebyw == True:  `scale` is a POINTER to per-element weights;
        the divisor is sum(weights), which is also written to `mean_num`
        (a pointer to a scalar) for reuse in the backward pass.
      * scalebyw == False: `scale` is a plain scalar divisor (default 1.0,
        i.e. a pure sum).

    Launched on a (1,) grid — one program walks the whole array in
    BLOCK_N-sized chunks.
    """
    mid_sum = tl.zeros(
        [
            BLOCK_N,
        ],
        dtype=tl.float32,
    )
    if scalebyw:
        mid_wgt = tl.zeros(
            [
                BLOCK_N,
            ],
            dtype=tl.float32,
        )
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
            # `scale` is a weight pointer in this branch.
            wgt_ptrs = scale + offset
            wgt_vals = tl.load(wgt_ptrs, mask=mask, other=0.0)
            mid_wgt += wgt_vals
        out_val = tl.sum(mid_sum)
        scale_val = tl.sum(mid_wgt)
        # Expose the weight sum so the host can form 1/sum(w) for backward.
        tl.store(mean_num, scale_val)
    else:
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
        out_val = tl.sum(mid_sum)
        scale_val = scale
    out_val /= scale_val
    tl.store(out_ptr, out_val)

558 

559 

class CrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper dispatching to the Triton cross-entropy kernels.

    Handles three forward cases: probability (soft) targets, index targets
    without smoothing, and index targets with label smoothing; plus the
    matching backward kernels.  The TRITONXPU_* environment variables are
    set around some launches — presumably kunlunxin Triton compiler
    workarounds; confirm against the backend toolchain docs.
    """

    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        logger.debug("GEMS CrossEntropyLoss")

        shape = list(inp.shape)
        dim = inp.ndim
        N = 1 if dim == 1 else shape[0]

        # Class dimension: dim 0 for 1-D input, dim 1 otherwise.
        C = shape[0] if dim == 1 else shape[1]
        D = inp.numel() // N // C
        axis = 0 if dim == 1 else 1
        # `shape` now holds the per-element output shape (class dim removed).
        del shape[axis]

        grad = inp.requires_grad
        # Kernels index logits as (N, D, C); move the class dim innermost.
        if dim == 3:
            inp = inp.transpose(1, -1)
            D_new = inp.shape[1]
        else:
            D_new = D

        inp = inp.contiguous()
        if dim == 3:
            # Probability targets share the logits' layout; for 2-D index
            # targets transpose(1, -1) swaps dim 1 with itself (no-op).
            target = target.transpose(1, -1).contiguous()
        tgt = target.contiguous()
        weight = weight.contiguous() if weight is not None else None
        out = torch.empty(shape, dtype=torch.float32, device=inp.device)
        grid = lambda meta: (triton.cdiv(D_new, meta["BLOCK_D"]), N)

        if tgt.ndim == dim:
            # target probabilities
            with torch_device_fn.device(inp.device):
                if shape != [1]:
                    os.environ["TRITONXPU_OTHER_SIM"] = "1"
                    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
                    os.environ["TRITONXPU_CLOSE_OPTIMIZE"] = "1"
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
                if shape != [1]:
                    if "TRITONXPU_OTHER_SIM" in os.environ:
                        del os.environ["TRITONXPU_OTHER_SIM"]
                    if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                        del os.environ["TRITONXPU_STORE_MASK_SIM"]
                    if "TRITONXPU_CLOSE_OPTIMIZE" in os.environ:
                        del os.environ["TRITONXPU_CLOSE_OPTIMIZE"]
        elif label_smoothing == 0:
            # target indices (per-element weights buffer for 'mean')
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    C,
                    D,
                )
            if dim > 1:
                # NOTE(review): `shape` already has the class dim removed,
                # so this slicing drops ANOTHER dim; for dim == 2 it is a
                # no-op ([N] -> [N]) — confirm intent for dim == 3.
                out = out.view(shape[:axis] + shape[axis + 1 :])
        else:
            # target indices + label smoothing
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                os.environ["TRITONXPU_OTHER_SIM"] = "1"
                os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
                os.environ["TRITONXPU_CLOSE_OPTIMIZE"] = "1"
                celoss_indices_smooth_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    label_smoothing,
                    C,
                    D,
                )
                if "TRITONXPU_OTHER_SIM" in os.environ:
                    del os.environ["TRITONXPU_OTHER_SIM"]
                if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                    del os.environ["TRITONXPU_STORE_MASK_SIM"]
                if "TRITONXPU_CLOSE_OPTIMIZE" in os.environ:
                    del os.environ["TRITONXPU_CLOSE_OPTIMIZE"]
        # NOTE(review): comparisons are against the STRINGS "mean"/"sum";
        # an int reduction code (e.g. the wrapper's default 1) matches
        # neither branch and yields the unreduced per-element loss.
        if reduction == "mean":  # MEAN
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            if tgt.ndim == dim:
                # Probability targets: plain average over N * D elements.
                sum_and_scale[(1,)](out, out_reduce, N * D, False, scale=N * D)
            else:
                # Index targets: weighted mean — divide by sum of w_tgt,
                # which sum_and_scale also writes into wgt_sum for backward.
                wgt_sum = torch.empty([], dtype=torch.float32, device=inp.device)
                sum_and_scale[(1,)](
                    out, out_reduce, N * D, True, scale=w_tgt, mean_num=wgt_sum
                )
            out = out_reduce
        elif reduction == "sum":  # SUM
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            sum_and_scale[(1,)](out, out_reduce, N * D, False)
            out = out_reduce

        if grad:
            # Save the TRANSPOSED/contiguous tensors — backward kernels
            # assume the same (N, D, C) layout as forward.
            ctx.save_for_backward(inp, tgt, weight)
            ctx.N = N
            ctx.C = C
            ctx.D = D
            ctx.ignore_index = ignore_index
            ctx.label_smoothing = label_smoothing
            ctx.shape = shape
            ctx.mean_num = 1
            if reduction == "mean":
                ctx.mean_num = N * D if tgt.ndim == dim else wgt_sum

        return out.to(inp.dtype)

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS CrossEntropyLoss VJP")

        inp, tgt, weight = ctx.saved_tensors
        N = ctx.N
        C = ctx.C
        D = ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        # Kernels multiply by mean_num, so pass the reciprocal of the
        # divisor recorded in forward (tensor for weighted mean, int else).
        mean_num = (
            1 / ctx.mean_num.item()
            if isinstance(ctx.mean_num, torch.Tensor)
            else 1 / ctx.mean_num
        )

        shape = ctx.shape
        # Expand a scalar (reduced) upstream grad to per-element shape.
        out_grad = out_grad.broadcast_to(shape).contiguous()
        dim = inp.ndim
        # Pre-zeroed so positions skipped by masked stores stay 0.
        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)

        if tgt.ndim == inp.ndim:
            # Probability targets.
            if shape != [1]:
                os.environ["TRITONXPU_OTHER_SIM"] = "1"
                os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
                os.environ["TRITONXPU_CLOSE_OPTIMIZE"] = "1"
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
            if shape != [1]:
                if "TRITONXPU_OTHER_SIM" in os.environ:
                    del os.environ["TRITONXPU_OTHER_SIM"]
                if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                    del os.environ["TRITONXPU_STORE_MASK_SIM"]
                if "TRITONXPU_CLOSE_OPTIMIZE" in os.environ:
                    del os.environ["TRITONXPU_CLOSE_OPTIMIZE"]
        elif label_smoothing == 0:
            # Index targets, no smoothing.
            celoss_indices_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, ignore_index, mean_num, C, D
            )
        else:
            # Index targets with label smoothing.
            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
            os.environ["TRITONXPU_CLOSE_OPTIMIZE"] = "1"
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
            if "TRITONXPU_OTHER_SIM" in os.environ:
                del os.environ["TRITONXPU_OTHER_SIM"]
            if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                del os.environ["TRITONXPU_STORE_MASK_SIM"]
            if "TRITONXPU_CLOSE_OPTIMIZE" in os.environ:
                del os.environ["TRITONXPU_CLOSE_OPTIMIZE"]
        if dim == 3:
            # Undo the forward transpose: back to (N, C, D) layout.
            inp_grad = inp_grad.transpose(1, -1).contiguous()
        # One grad per forward arg; only `inp` is differentiable.
        return inp_grad, None, None, None, None, None

746 

747 

def cross_entropy_loss(
    inp, target, weight=None, reduction=1, ignore_index=-100, label_smoothing=0.0
):
    """Functional entry point; forwards all arguments to the autograd Function.

    NOTE(review): the default ``reduction=1`` is an int while ``forward``
    compares against the strings "mean"/"sum" — confirm what callers pass.
    """
    args = (inp, target, weight, reduction, ignore_index, label_smoothing)
    return CrossEntropyLoss.apply(*args)