Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/cross_entropy_loss.py: 0%
431 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_block_c(args):
    """Heuristic for BLOCK_C: next power of two of ceil(C / 12), floored at 64."""
    candidate = triton.next_power_of_2(triton.cdiv(args["C"], 12))
    return max(candidate, 64)
def heur_block_d(args):
    """Heuristic for BLOCK_D: ceil(D / 12)."""
    inner_size = args["D"]
    return triton.cdiv(inner_size, 12)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["ignore_index"])
def celoss_indices_kernel(
    inp_ptr,  # logits laid out as [N, D, C]: index n * C * D + d * C + c
    tgt_ptr,  # class-index targets laid out as [N, D]
    w_ptr,  # optional per-class weights [C]; may be None
    out_ptr,  # per-element loss output [N, D] (float32)
    w_tgt_ptr,  # output: weight of the target class per element [N, D]
    ignore_index,  # target value whose positions contribute zero loss
    C,  # number of classes
    D,  # inner spatial size per sample
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Per-element cross entropy with integer class targets (no label smoothing):
    # out[n, d] = -w[tgt] * log_softmax(inp[n, d, :])[tgt].
    # Grid: axis 0 tiles D, axis 1 iterates over samples n.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # Elementwise (Triton tensor) boolean: in-range rows whose target is kept.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    # Online (streaming) softmax accumulators over the class dimension.
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        # Rescale the running exp-sum whenever a new maximum appears.
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[:, None])
    final_sum = tl.log(tl.sum(tmp_sum, axis=1))  # log(sum(exp(x - max)))

    # Gather the logit of the target class at each (n, d) position.
    inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt + offset_d * C
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)

    # -log_softmax at the target class: logsumexp - logit[target].
    out = final_sum + final_max - inp_tgt
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d

    if w_ptr is None:
        # Without class weights the effective weight is 1 (kept) or 0 (ignored).
        w_tgt = ignore_mask
    else:
        # Ignored rows load weight 0 via the mask, zeroing their loss below.
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
    out *= w_tgt
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,  # logits laid out as [N, D, C]
    tgt_ptr,  # target probabilities, same layout as the logits
    w_ptr,  # optional per-class weights [C]; may be None
    out_ptr,  # per-element loss output [N, D]
    label_smoothing,  # smoothing factor blended into the target distribution
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Cross entropy with a full probability distribution as target:
    # out[n, d] = sum_c w[c] * smoothed_tgt[c] * (-log_softmax(inp[n, d, :])[c]).
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    # Pass 1: online softmax statistics (running max and rescaled exp-sum).
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[:, None])
    final_sum = tl.log(tl.sum(tmp_sum, axis=1))  # log(sum(exp(x - max)))

    # Pass 2: accumulate the weighted negative log-likelihood tile by tile.
    _sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Blend the target with the uniform distribution (label smoothing).
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        # -log_softmax(inp) = logsumexp - inp.
        log = final_sum[:, None] + final_max[:, None] - inp
        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask  # unit weight on in-range classes
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[None, :]

    out = tl.sum(_sum, axis=1)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,  # logits laid out as [N, D, C]
    tgt_ptr,  # class-index targets [N, D]
    w_ptr,  # optional per-class weights [C]; may be None
    out_ptr,  # per-element loss output [N, D]
    w_tgt_ptr,  # output: weight of the target class per element [N, D]
    ignore_index,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Cross entropy with integer targets AND label smoothing: the implicit
    # one-hot target becomes (1 - s + s/C) at the target class, s/C elsewhere.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # Elementwise (Triton tensor) boolean: in-range rows whose target is kept.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    # Record the target-class weight (0 for ignored rows) for the mean reduction.
    if w_ptr is None:
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0)
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)

    # Pass 1: online softmax statistics over the class dimension.
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)[:, None]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=1))[:, None]
    final_sum_max = final_sum + final_max  # per-row logsumexp

    # Pass 2: weighted sum of -log_softmax against the smoothed one-hot target.
    _sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)

        # Smoothed target distribution for this class tile.
        smooth = tl.where(
            offset_c[None, :] == tgt[:, None],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)

        log = final_sum_max - inp
        _sum += log * smooth * w[None, :]

    out = tl.sum(_sum, axis=1)
    out = tl.where(ignore_mask, out, 0)  # ignored rows contribute zero loss
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indices_bwd(
    out_grad_ptr,  # upstream gradient [N, D]
    inp_ptr,  # forward logits laid out as [N, D, C]
    tgt_ptr,  # class-index targets [N, D]
    w_ptr,  # optional per-class weights [C]; may be None
    inp_grad_ptr,  # gradient output, same layout as inp
    ignore_index,
    mean_num,  # 1 / normalizer for the "mean" reduction (1 otherwise)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward for integer targets without smoothing:
    # d inp = (softmax(inp) - one_hot(tgt)) * w[tgt] * out_grad * mean_num.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)

    # Elementwise (Triton tensor) boolean: in-range rows whose target is kept.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    if w_ptr is None:
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    # Pass 1: online softmax statistics.
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[:, None])
    final_sum = tl.sum(tmp_sum, axis=1)  # sum(exp(x - max)); note: not its log here

    # Pass 2: emit (softmax - one_hot) scaled by weight and upstream grad.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        minus_one = (offset_c[None, :] == tgt[:, None]).to(tl.float32)  # one-hot
        inp_grad = (
            (tl.exp(inp - final_max[:, None]) / final_sum[:, None] - minus_one)
            * w_tgt[:, None]
            * out_grad[:, None]
            * mean_num
        )
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        )
        # Ignored rows are masked out; they keep the caller's zero-initialized grad.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask[:, None])
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,  # upstream gradient [N, D]
    inp_ptr,  # forward logits laid out as [N, D, C]
    tgt_ptr,  # target probabilities, same layout as the logits
    w_ptr,  # optional per-class weights [C]; may be None
    inp_grad_ptr,  # gradient output, same layout as inp
    label_smoothing,
    mean_num,  # 1 / normalizer for the "mean" reduction (1 otherwise)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward for probability targets:
    # d inp = (sum_c(w * tgt) * softmax(inp) - w * tgt) * out_grad * mean_num.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)

    # Pass 1: online softmax statistics plus the weighted-target row sum.
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)

        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Smoothing blend must match the forward kernel exactly.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        w_tgt_sum += tgt * w[None, :]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)[:, None]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=1)[:, None]
    w_tgt_sum = tl.sum(w_tgt_sum, axis=1)[:, None]

    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_ptrs = inp_ptr + offset
        mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[None, :]
        inp_grad = grad * out_grad[:, None] * mean_num

        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("cross_entropy_loss"),
#     key=["C", "D"],
# )
@triton.heuristics(
    values={
        "BLOCK_C": heur_block_c,
        "BLOCK_D": heur_block_d,
    },
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,  # upstream gradient [N, D]
    inp_ptr,  # forward logits laid out as [N, D, C]
    tgt_ptr,  # class-index targets [N, D]
    w_ptr,  # optional per-class weights [C]; may be None
    inp_grad_ptr,  # gradient output, same layout as inp
    ignore_index,
    label_smoothing,
    mean_num,  # 1 / normalizer for the "mean" reduction (1 otherwise)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward for integer targets WITH label smoothing:
    # d inp = (sum_c(w * smooth) * softmax(inp) - smooth * w) * out_grad * mean_num.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)

    # Column-vector mask of rows whose target is not ignored.
    ignore_mask = (tgt != ignore_index)[:, None]

    # Pass 1: softmax statistics plus the weighted smoothed-target row sum.
    tmp_max = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_D, BLOCK_C], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot target for this class tile.
        smooth = tl.full([BLOCK_D, BLOCK_C], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[None, :] == tgt[:, None],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )

        w_sum += smooth * w[None, :]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=1)[:, None]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=1)[:, None]
    w_sum = tl.sum(w_sum, axis=1)[:, None]

    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        inp_mask = offset_c[None, :] < C and offset_d[:, None] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        smooth = tl.where(
            offset_c[None, :] == tgt[:, None],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )

        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[None, :]
        inp_grad = grad * out_grad[:, None] * mean_num
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_d[:, None] * C + offset_c[None, :]
        )
        # Ignored rows keep the caller's zero-initialized gradient.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)
@libentry()
@triton.jit
def sum_and_scale(
    inp_ptr,  # flat float buffer of per-element losses, length N
    out_ptr,  # scalar output: sum(inp) / scale
    N,
    scalebyw: tl.constexpr,  # True -> divide by the sum of the weight buffer passed via `scale`
    BLOCK_N: tl.constexpr = 128,
    scale=1.0,  # scalar divisor, or (when scalebyw) a POINTER to per-element weights
    mean_num=None,  # when scalebyw: scalar output receiving the weight sum
):
    # Single-program reduction used for the "mean"/"sum" loss reductions.
    mid_sum = tl.zeros(
        [
            BLOCK_N,
        ],
        dtype=tl.float32,
    )
    if scalebyw:
        # Weighted mean: accumulate losses and weights together, divide by the
        # total weight, and report that total through mean_num for backward.
        mid_wgt = tl.zeros(
            [
                BLOCK_N,
            ],
            dtype=tl.float32,
        )
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
            wgt_ptrs = scale + offset  # `scale` is a pointer in this branch
            wgt_vals = tl.load(wgt_ptrs, mask=mask, other=0.0)
            mid_wgt += wgt_vals
        out_val = tl.sum(mid_sum)
        scale_val = tl.sum(mid_wgt)
        tl.store(mean_num, scale_val)  # saved normalizer for the backward pass
    else:
        # Plain sum, divided by a scalar (N*D for the unweighted mean, 1.0 for sum).
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
        out_val = tl.sum(mid_sum)
        scale_val = scale
    out_val /= scale_val
    tl.store(out_ptr, out_val)
class CrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper around the Triton cross-entropy kernels.

    Dispatches on the target kind (class indices vs. full probability
    distributions) and on label smoothing, and performs the optional
    "mean"/"sum" reduction with the sum_and_scale kernel.
    """

    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        # inp: logits with the class axis at dim 0 (1-D input) or dim 1.
        # target: class indices (one dim fewer than inp) or probabilities
        #         (same ndim as inp) — distinguished below via tgt.ndim == dim.
        # NOTE(review): reduction is compared against the strings "mean"/"sum"
        # below, while the module-level wrapper defaults it to the int 1 —
        # confirm which convention callers actually use.
        logger.debug("GEMS CrossEntropyLoss")

        shape = list(inp.shape)
        dim = inp.ndim
        N = 1 if dim == 1 else shape[0]
        C = shape[0] if dim == 1 else shape[1]
        D = inp.numel() // N // C  # product of the trailing (spatial) dims
        axis = 0 if dim == 1 else 1  # class axis
        del shape[axis]  # shape now excludes the class axis: per-element loss shape

        grad = inp.requires_grad
        if dim == 3:
            # Kernels expect the class axis last: (N, C, D) -> (N, D, C).
            inp = inp.transpose(1, -1)
            D_new = inp.shape[1]
        else:
            D_new = D

        inp = inp.contiguous()
        if dim == 3:
            target = target.transpose(1, -1).contiguous()
        tgt = target.contiguous()
        weight = weight.contiguous() if weight is not None else None
        out = torch.empty(shape, dtype=torch.float32, device=inp.device)
        # Grid: tiles of D on axis 0, one program per sample on axis 1.
        grid = lambda meta: (triton.cdiv(D_new, meta["BLOCK_D"]), N)

        if tgt.ndim == dim:
            # Target is a probability distribution over classes.
            with torch_device_fn.device(inp.device):
                if shape != [1]:
                    # KunlunXin-specific Triton compiler workarounds for this kernel.
                    os.environ["TRITONXPU_OTHER_SIM"] = "1"
                    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
                    os.environ["TRITONXPU_CLOSE_OPTIMIZE"] = "1"
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
                if shape != [1]:
                    if "TRITONXPU_OTHER_SIM" in os.environ:
                        del os.environ["TRITONXPU_OTHER_SIM"]
                    if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                        del os.environ["TRITONXPU_STORE_MASK_SIM"]
                    if "TRITONXPU_CLOSE_OPTIMIZE" in os.environ:
                        del os.environ["TRITONXPU_CLOSE_OPTIMIZE"]
        elif label_smoothing == 0:
            # Target is class indices, no label smoothing. w_tgt collects the
            # per-element target-class weight for the weighted-mean reduction.
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    C,
                    D,
                )
            if dim > 1:
                # NOTE(review): `shape` already had the class axis removed above,
                # so this view drops another axis — confirm this is intended for
                # the input ranks this backend supports.
                out = out.view(shape[:axis] + shape[axis + 1 :])
        else:
            # Target is class indices with label smoothing.
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                # KunlunXin-specific Triton compiler workarounds for this kernel.
                os.environ["TRITONXPU_OTHER_SIM"] = "1"
                os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
                os.environ["TRITONXPU_CLOSE_OPTIMIZE"] = "1"
                celoss_indices_smooth_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    label_smoothing,
                    C,
                    D,
                )
                if "TRITONXPU_OTHER_SIM" in os.environ:
                    del os.environ["TRITONXPU_OTHER_SIM"]
                if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                    del os.environ["TRITONXPU_STORE_MASK_SIM"]
                if "TRITONXPU_CLOSE_OPTIMIZE" in os.environ:
                    del os.environ["TRITONXPU_CLOSE_OPTIMIZE"]
        if reduction == "mean":  # MEAN
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            if tgt.ndim == dim:
                # Probability targets: plain average over all N*D elements.
                sum_and_scale[(1,)](out, out_reduce, N * D, False, scale=N * D)
            else:
                # Index targets: weighted mean; wgt_sum receives the weight total.
                wgt_sum = torch.empty([], dtype=torch.float32, device=inp.device)
                sum_and_scale[(1,)](
                    out, out_reduce, N * D, True, scale=w_tgt, mean_num=wgt_sum
                )
            out = out_reduce
        elif reduction == "sum":  # SUM
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            sum_and_scale[(1,)](out, out_reduce, N * D, False)
            out = out_reduce

        if grad:
            # Stash everything backward() needs.
            ctx.save_for_backward(inp, tgt, weight)
            ctx.N = N
            ctx.C = C
            ctx.D = D
            ctx.ignore_index = ignore_index
            ctx.label_smoothing = label_smoothing
            ctx.shape = shape
            ctx.mean_num = 1
            if reduction == "mean":
                # Normalizer: element count for probability targets, summed
                # target weights (scalar tensor) for index targets.
                ctx.mean_num = N * D if tgt.ndim == dim else wgt_sum

        return out.to(inp.dtype)

    @staticmethod
    def backward(ctx, out_grad):
        # Gradient is produced only w.r.t. inp; all other forward args get None.
        logger.debug("GEMS CrossEntropyLoss VJP")

        inp, tgt, weight = ctx.saved_tensors
        N = ctx.N
        C = ctx.C
        D = ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        # Invert the forward normalizer once on the host; kernels multiply by it.
        mean_num = (
            1 / ctx.mean_num.item()
            if isinstance(ctx.mean_num, torch.Tensor)
            else 1 / ctx.mean_num
        )

        shape = ctx.shape
        out_grad = out_grad.broadcast_to(shape).contiguous()
        dim = inp.ndim
        # Zero-filled so positions masked out inside the kernels stay zero.
        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)

        if tgt.ndim == inp.ndim:
            # Probability targets.
            if shape != [1]:
                # KunlunXin-specific compiler workarounds (mirrors forward).
                os.environ["TRITONXPU_OTHER_SIM"] = "1"
                os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
                os.environ["TRITONXPU_CLOSE_OPTIMIZE"] = "1"
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
            if shape != [1]:
                if "TRITONXPU_OTHER_SIM" in os.environ:
                    del os.environ["TRITONXPU_OTHER_SIM"]
                if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                    del os.environ["TRITONXPU_STORE_MASK_SIM"]
                if "TRITONXPU_CLOSE_OPTIMIZE" in os.environ:
                    del os.environ["TRITONXPU_CLOSE_OPTIMIZE"]
        elif label_smoothing == 0:
            # Index targets, no smoothing.
            celoss_indices_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, ignore_index, mean_num, C, D
            )
        else:
            # Index targets with smoothing.
            os.environ["TRITONXPU_OTHER_SIM"] = "1"
            os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
            os.environ["TRITONXPU_CLOSE_OPTIMIZE"] = "1"
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
            if "TRITONXPU_OTHER_SIM" in os.environ:
                del os.environ["TRITONXPU_OTHER_SIM"]
            if "TRITONXPU_STORE_MASK_SIM" in os.environ:
                del os.environ["TRITONXPU_STORE_MASK_SIM"]
            if "TRITONXPU_CLOSE_OPTIMIZE" in os.environ:
                del os.environ["TRITONXPU_CLOSE_OPTIMIZE"]
        if dim == 3:
            # Undo the forward transpose: gradient back to (N, C, D) layout.
            inp_grad = inp_grad.transpose(1, -1).contiguous()
        return inp_grad, None, None, None, None, None
def cross_entropy_loss(
    inp, target, weight=None, reduction=1, ignore_index=-100, label_smoothing=0.0
):
    """Functional entry point; delegates to the CrossEntropyLoss autograd Function."""
    args = (inp, target, weight, reduction, ignore_index, label_smoothing)
    return CrossEntropyLoss.apply(*args)