# Coverage for src/flag_gems/runtime/backend/_ascend/fused/cross_entropy_loss.py: 0%
# 380 statements
# coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch.nn import _reduction as _Reduction
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
# Module-level logger named after this module (shared flag_gems convention).
logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index"])
def celoss_indices_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    w_tgt_ptr,
    ignore_index,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Forward cross-entropy for integer class-index targets, no label
    # smoothing.  Layout inferred from the pointer arithmetic below:
    #   inp_ptr:   logits, contiguous (N, C, D), reduced over the C axis
    #   tgt_ptr:   class indices, (N, D)
    #   w_ptr:     optional per-class weights (C,); may be None
    #   out_ptr:   per-element weighted loss, (N, D), fp32
    #   w_tgt_ptr: weight of each element's target class, (N, D) -- later
    #              consumed by the mean reduction and the backward pass
    # Grid: axis 0 tiles D, axis 1 is the batch index n.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    # Elementwise (Triton semantics): lane participates in the loss iff its
    # target differs from ignore_index and it is in-bounds.
    ignore_mask = not (tgt == ignore_index) and tgt_mask
    # Online (streaming) logsumexp over C: keep a running max and a sum of
    # exponentials, rescaling the sum whenever the max grows.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        # Out-of-range lanes load -inf so exp() contributes exactly 0.
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        # Rescale the previous partial sum to the new max before adding.
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))
    # Gather the logit at the target class: loss = logsumexp - logit[target].
    inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)
    out = final_sum + final_max - inp_tgt
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    if w_ptr is None:
        # Unweighted case: weight 1 for kept elements, 0 for ignored ones.
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=tgt_mask, other=0).to(tl.float32)
        # Zero the weight of ignored elements so they drop out of the sum.
        w_tgt = tl.where(ignore_mask, w_tgt, 0)
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
    # Multiplying by the (possibly zeroed) weight also zeroes ignored losses.
    out *= w_tgt
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Forward cross-entropy when the targets are class probabilities with
    # the same (N, C, D) layout as the logits:
    #   out[n, d] = sum_c (logsumexp - inp[c]) * smoothed_tgt[c] * w[c]
    # where smoothed_tgt blends tgt with a uniform distribution.
    # Grid: axis 0 tiles D, axis 1 is the batch index n.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    # Pass 1: streaming logsumexp over the class axis (running max + sum
    # of exponentials, rescaled whenever the max grows).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    # Pass 2: accumulate the weighted negative log-likelihood over C.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Blend the target distribution with uniform label smoothing.
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        # -log softmax(inp)[c] = logsumexp + max - inp[c]
        log = final_sum + final_max - inp
        w_mask = offset_c < C
        if w_ptr is None:
            # No class weights: use the bounds mask (1 in-range, 0 padding)
            # so padded classes drop out of the accumulation.
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[:, None]
    out = tl.sum(_sum, axis=0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    w_tgt_ptr,
    ignore_index,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Forward cross-entropy for integer class-index targets WITH label
    # smoothing: the one-hot target is replaced by
    #   smooth[c] = 1 - eps + eps/C  (c == target)   /   eps/C  (otherwise)
    # and the loss is sum_c (logsumexp - inp[c]) * smooth[c] * w[c].
    # Also writes per-element target weights to w_tgt_ptr for the mean
    # reduction.  Grid: axis 0 tiles D, axis 1 is the batch index n.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    # Elementwise: lane participates iff target != ignore_index and in-bounds.
    ignore_mask = not (tgt == ignore_index) and tgt_mask
    if w_ptr is None:
        # Unweighted: weight 1 for kept elements, 0 for ignored ones.
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=tgt_mask, other=0)
        w_tgt = tl.where(ignore_mask, w_tgt, 0)
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
    # Pass 1: streaming logsumexp over the class axis.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        # Rescale the previous partial sum to the new max before adding.
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    final_sum_max = final_sum + final_max
    # Pass 2: accumulate the smoothed, weighted negative log-likelihood.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        w_mask = offset_c < C
        if w_ptr is None:
            # Bounds mask doubles as unit weight (padding classes get 0).
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)
        # Smoothed target distribution: peaked at the target class.
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)
        # -log softmax(inp)[c]
        log = final_sum_max - inp
        _sum += log * smooth * w[:, None]
    out = tl.sum(_sum, axis=0)
    # Zero the loss of ignored elements.
    out = tl.where(ignore_mask, out, 0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indices_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    ignore_index,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward for index targets without label smoothing:
    #   d loss / d inp[c] = (softmax(inp)[c] - onehot(tgt)[c])
    #                       * w[tgt] * out_grad * mean_num
    # `mean_num` is the reciprocal of the mean denominator (1 for
    # sum/none reduction), precomputed by the caller.
    # Grid: axis 0 tiles D, axis 1 is the batch index n.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]
    if w_ptr is None:
        # Unweighted: in-bounds lanes get weight 1.
        w_tgt = tgt_mask
    else:
        w_ptrs = w_ptr + tgt
        w_tgt = tl.load(w_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]
    # Gradient is zeroed where the target equals ignore_index.
    ignore_mask = (tgt != ignore_index)[None, :]
    # Pass 1: streaming softmax denominator (max + sum of exponentials).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        # 1 exactly at the target class (the one-hot subtraction).
        minus_one = offset_c[:, None] == tgt[None, :]
        inp_grad = (
            (tl.exp(inp - final_max) / final_sum - minus_one)
            * w_tgt
            * out_grad
            * mean_num
        )
        inp_grad = tl.where(ignore_mask, inp_grad, 0.0)
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    label_smoothing,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward for probability targets:
    #   d loss / d inp[c] = (sum_k smoothed_tgt[k]*w[k]) * softmax(inp)[c]
    #                       - smoothed_tgt[c] * w[c]
    # scaled by out_grad * mean_num (mean_num is the reciprocal mean
    # denominator supplied by the caller).
    # Grid: axis 0 tiles D, axis 1 is the batch index n.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[
        None, :
    ]
    # Pass 1: streaming softmax denominator AND the weighted-target sum
    # (sum over C of smoothed_tgt * w), fused in one sweep.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Blend targets with uniform label smoothing.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C
        w_mask = offset_c < C
        if w_ptr is None:
            # Bounds mask doubles as unit weight (padding classes get 0).
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
        w_tgt_sum += tgt * w[:, None]
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]
    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_ptrs = inp_ptr + offset
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C
        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    ignore_index,
    label_smoothing,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Backward for index targets WITH label smoothing.  The smoothed
    # target distribution is
    #   smooth[c] = 1 - eps + eps/C  (c == target)   /   eps/C  (otherwise)
    # and the gradient is
    #   (sum_k smooth[k]*w[k]) * softmax(inp)[c] - smooth[c]*w[c],
    # scaled by out_grad * mean_num and zeroed for ignored elements.
    # Grid: axis 0 tiles D, axis 1 is the batch index n.
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]
    # Gradient is zeroed where the target equals ignore_index.
    ignore_mask = (tgt != ignore_index)[None, :]
    # Pass 1: streaming softmax denominator AND the smoothed-weight sum
    # (sum over C of smooth * w), fused in one sweep.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        w_mask = offset_c < C
        if w_ptr is None:
            # Bounds mask doubles as unit weight (padding classes get 0).
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
        smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )
        w_sum += smooth * w[:, None]
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_sum = tl.sum(w_sum, axis=0)[None, :]
    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )
        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad = tl.where(ignore_mask, inp_grad, 0.0)
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask)
@libentry()
@triton.jit
def sum_and_scale(
    inp_ptr,
    out_ptr,
    N,
    scalebyw: tl.constexpr,
    BLOCK_N: tl.constexpr = 128,
    scale=1.0,
    mean_num=None,
):
    # Single-program reduction: *out_ptr = sum(inp[0:N]) / divisor.
    # Launched on a (1,) grid, so one program walks all N elements.
    #
    # Two modes, selected at compile time by `scalebyw`:
    #   * scalebyw == True:  `scale` is a POINTER to a length-N weight
    #     vector; the divisor is sum(weights), which is also stored to the
    #     scalar `mean_num` (reused by the caller as the backward mean
    #     factor).
    #   * scalebyw == False: `scale` is a SCALAR divisor (1.0 -> plain sum).
    mid_sum = tl.zeros(
        [
            BLOCK_N,
        ],
        dtype=tl.float32,
    )
    if scalebyw:
        mid_wgt = tl.zeros(
            [
                BLOCK_N,
            ],
            dtype=tl.float32,
        )
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
            # In this branch `scale` is a pointer to per-element weights.
            wgt_ptrs = scale + offset
            wgt_vals = tl.load(wgt_ptrs, mask=mask, other=0.0)
            mid_wgt += wgt_vals
        out_val = tl.sum(mid_sum)
        scale_val = tl.sum(mid_wgt)
        # Persist the weight sum for the backward pass.
        tl.store(mean_num, scale_val)
    else:
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
        out_val = tl.sum(mid_sum)
        scale_val = scale
    out_val /= scale_val
    tl.store(out_ptr, out_val)
class CrossEntropyLoss(torch.autograd.Function):
    # Autograd Function dispatching to one of three kernel pairs:
    #   - probability targets (tgt.ndim == inp.ndim)
    #   - index targets, no label smoothing
    #   - index targets with label smoothing
    # plus an optional on-device sum/mean reduction.
    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        """Compute the (optionally reduced) cross-entropy loss.

        Args:
            inp: logits; treated as (N, C, D) where C is the class axis
                (axis 0 for 1-D input, axis 1 otherwise).
            target: class indices (shape without C) or class probabilities
                (same shape as ``inp``).
            weight: optional per-class weights, or None.
            reduction: torch reduction enum (0=none, 1=mean, 2=sum).
            ignore_index: index-target value excluded from the loss.
            label_smoothing: uniform smoothing factor in [0, 1).
        """
        logger.debug("GEMS_ASCEND CrossEntropyLoss")
        shape = list(inp.shape)
        dim = inp.ndim
        N = 1 if dim == 1 else shape[0]
        C = shape[0] if dim == 1 else shape[1]
        # D collapses any trailing spatial dims into one axis.
        D = inp.numel() // N // C
        axis = 0 if dim == 1 else 1
        # Drop the class axis: `shape` becomes the per-element loss shape.
        del shape[axis]
        inp = inp.contiguous()
        tgt = target.contiguous()
        weight = weight.contiguous() if weight is not None else None
        # Per-element losses are accumulated in fp32 for accuracy.
        out = torch.empty(shape, dtype=torch.float32, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
        if tgt.ndim == dim:
            # target probabilities
            with torch_device_fn.device(inp.device):
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
        elif label_smoothing == 0:
            # target indices
            # w_tgt holds each element's target-class weight for the
            # weighted-mean denominator.
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    C,
                    D,
                )
        else:
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_smooth_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    label_smoothing,
                    C,
                    D,
                )
        if reduction == 1:  # MEAN
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            if tgt.ndim == dim:
                # Probability targets: plain mean over all N*D elements.
                sum_and_scale[(1,)](out, out_reduce, N * D, False, scale=N * D)
            else:
                # Index targets: divide by the sum of target weights;
                # wgt_sum receives that denominator for the backward pass.
                wgt_sum = torch.empty([], dtype=torch.float32, device=inp.device)
                sum_and_scale[(1,)](
                    out, out_reduce, N * D, True, scale=w_tgt, mean_num=wgt_sum
                )
            out = out_reduce
        elif reduction == 2:  # SUM
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            sum_and_scale[(1,)](out, out_reduce, N * D, False)
            out = out_reduce
        if inp.requires_grad:
            ctx.save_for_backward(inp, tgt, weight)
            ctx.N = N
            ctx.C = C
            ctx.D = D
            ctx.ignore_index = ignore_index
            ctx.label_smoothing = label_smoothing
            ctx.shape = shape
            # mean_num is the mean denominator (1 => no mean scaling).
            ctx.mean_num = 1
            if reduction == 1:
                ctx.mean_num = N * D if tgt.ndim == dim else wgt_sum
        return out.to(inp.dtype)

    @staticmethod
    def backward(ctx, out_grad):
        """Compute d loss / d inp by dispatching to the matching bwd kernel."""
        logger.debug("GEMS_ASCEND CrossEntropyLoss VJP")
        inp, tgt, weight = ctx.saved_tensors
        N = ctx.N
        C = ctx.C
        D = ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        # Kernels multiply by mean_num, so pass the reciprocal of the
        # denominator saved by forward (tensor when it was a weight sum).
        mean_num = (
            1 / ctx.mean_num.item()
            if isinstance(ctx.mean_num, torch.Tensor)
            else 1 / ctx.mean_num
        )
        shape = ctx.shape
        # Expand a scalar (reduced) grad to the per-element loss shape.
        out_grad = out_grad.broadcast_to(shape).contiguous()
        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
        if tgt.ndim == inp.ndim:
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
        elif label_smoothing == 0:
            celoss_indices_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, ignore_index, mean_num, C, D
            )
        else:
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
        # Gradients only flow to the logits; other inputs get None.
        return inp_grad, None, None, None, None, None
def cross_entropy_loss(
    inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
):
    """Cross-entropy loss mirroring ``torch.nn.functional.cross_entropy``.

    Translates the string ``reduction`` into torch's internal enum and
    defers all computation to the autograd-aware ``CrossEntropyLoss``
    Function.
    """
    reduction_enum = _Reduction.get_enum(reduction)
    return CrossEntropyLoss.apply(
        inp, target, weight, reduction_enum, ignore_index, label_smoothing
    )