Coverage for src/flag_gems/fused/cross_entropy_loss.py: 24% (376 statements)
coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch.nn import _reduction as _Reduction
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index"])
def celoss_indices_kernel(
    inp_ptr,  # logits, contiguous, addressed as (N, C, D)
    tgt_ptr,  # integer class targets, addressed as (N, D)
    w_ptr,  # optional per-class weights of length C (None -> unweighted)
    out_ptr,  # per-position loss, (N, D), float32
    w_tgt_ptr,  # weight of each position's target class, (N, D); 0 where ignored
    ignore_index,  # target value whose positions contribute no loss
    C,  # number of classes
    D,  # product of trailing dims after the class dim
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy for class-index targets, no label smoothing.

    One program handles BLOCK_D trailing positions of one batch row; classes
    are streamed in BLOCK_C chunks with an online (streaming) log-sum-exp so
    arbitrary C fits in registers.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # Elementwise in Triton's frontend: valid positions whose target is kept.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    # Online log-sum-exp accumulators. NOTE(review): tmp_max starts at 0, so
    # for all-negative logits the shift stays 0 — still correct (logsumexp is
    # shift-invariant), just not the tightest shift numerically.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        # Rescale the running sum to the new max before accumulating.
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))  # logsumexp shifted by final_max

    # Gather the logit of each position's target class.
    inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)

    # loss = logsumexp(logits) - logits[target]
    out = final_sum + final_max - inp_tgt
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d

    if w_ptr is None:
        # Unweighted: weight is 1 for kept targets, 0 for ignored ones.
        w_tgt = ignore_mask
    else:
        # Masked by ignore_mask so ignored positions load weight 0.
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    # Stash per-position weights; the host reuses them as the mean divisor.
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
    out *= w_tgt  # also zeroes the loss at ignored positions
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,  # logits, contiguous, addressed as (N, C, D)
    tgt_ptr,  # target class probabilities, same (N, C, D) layout as logits
    w_ptr,  # optional per-class weights of length C (None -> unweighted)
    out_ptr,  # per-position loss, (N, D), float32
    label_smoothing,  # blended as tgt*(1-s) + s/C
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy where the target is a probability distribution.

    Pass 1 streams the class dimension to build a stable log-sum-exp;
    pass 2 accumulates sum_c w[c] * smoothed_tgt[c] * (logsumexp - inp[c]).
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    # Online log-sum-exp accumulators (see celoss_indices_kernel for the
    # note on the zero-initialized running max).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]

    # Second pass: weighted dot product of smoothed targets with -log-softmax.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Label smoothing blends the target with the uniform distribution.
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        log = final_sum + final_max - inp  # -log softmax(inp)
        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask  # boolean acts as a 0/1 weight
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[:, None]

    out = tl.sum(_sum, axis=0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,  # logits, contiguous, addressed as (N, C, D)
    tgt_ptr,  # integer class targets, addressed as (N, D)
    w_ptr,  # optional per-class weights of length C (None -> unweighted)
    out_ptr,  # per-position loss, (N, D), float32
    w_tgt_ptr,  # weight of each position's target class, (N, D); 0 where ignored
    ignore_index,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward cross-entropy for class-index targets WITH label smoothing.

    The smoothed target distribution is (1 - s + s/C) at the target class and
    s/C elsewhere; the loss is its weighted dot product with -log-softmax.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # Elementwise in Triton's frontend: valid positions whose target is kept.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    if w_ptr is None:
        w_tgt = ignore_mask  # 0/1 weight
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0)
    # Stash per-position weights; the host reuses them as the mean divisor.
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)

    # Pass 1: streaming log-sum-exp over the class dimension.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    final_sum_max = final_sum + final_max  # logsumexp(logits)

    # Pass 2: accumulate smoothed-target * class-weight * (-log softmax).
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask  # 0/1 weight
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot: bulk mass at the target class, s/C elsewhere.
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)

        log = final_sum_max - inp  # -log softmax(inp)
        _sum += log * smooth * w[:, None]

    out = tl.sum(_sum, axis=0)
    # Ignored positions contribute zero loss.
    out = tl.where(ignore_mask, out, 0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indices_bwd(
    out_grad_ptr,  # upstream gradient, (N, D)
    inp_ptr,  # saved logits, (N, C, D)
    tgt_ptr,  # integer class targets, (N, D)
    w_ptr,  # optional per-class weights (None -> unweighted)
    inp_grad_ptr,  # output gradient buffer, (N, C, D); pre-zeroed by the host
    ignore_index,
    mean_num,  # host-precomputed 1/denominator for 'mean' reduction (1.0 otherwise)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward for class-index targets without label smoothing.

    d loss / d logits = (softmax(logits) - onehot(target)) * w[target],
    scaled by the upstream gradient and mean_num. Positions whose target is
    ignore_index are skipped by the store mask (buffer stays zero).
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    if w_ptr is None:
        w_tgt = tgt_mask  # 0/1 weight
    else:
        w_ptrs = w_ptr + tgt
        w_tgt = tl.load(w_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    ignore_mask = (tgt != ignore_index)[None, :]

    # Pass 1: streaming softmax denominator (same scheme as the forward).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]  # softmax denominator

    # Pass 2: write (softmax - onehot) * w[target] * out_grad * mean_num.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        minus_one = offset_c[:, None] == tgt[None, :]  # onehot(target)
        inp_grad = (
            (tl.exp(inp - final_max) / final_sum - minus_one)
            * w_tgt
            * out_grad
            * mean_num
        )
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        # Masked by ignore_mask: ignored positions keep their zero gradient.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,  # upstream gradient, (N, D)
    inp_ptr,  # saved logits, (N, C, D)
    tgt_ptr,  # target probabilities, (N, C, D)
    w_ptr,  # optional per-class weights (None -> unweighted)
    inp_grad_ptr,  # output gradient buffer, (N, C, D)
    label_smoothing,
    mean_num,  # host-precomputed 1/denominator for 'mean' reduction (1.0 otherwise)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward for probability targets.

    With T = weighted mass of the smoothed target distribution,
    d loss / d inp[c] = T * softmax(inp)[c] - smoothed_tgt[c] * w[c],
    scaled by the upstream gradient and mean_num.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[
        None, :
    ]

    # Pass 1: streaming softmax denominator, plus the total weighted target
    # mass sum_c w[c] * smoothed_tgt[c] accumulated in the same sweep.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)

        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Same smoothing blend as the forward kernel.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask  # 0/1 weight
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        w_tgt_sum += tgt * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]  # softmax denominator
    w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]  # total weighted target mass

    # Pass 2: write the gradient chunk by chunk.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_ptrs = inp_ptr + offset
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[:, None]
        inp_grad = grad * out_grad * mean_num

        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,  # upstream gradient, (N, D)
    inp_ptr,  # saved logits, (N, C, D)
    tgt_ptr,  # integer class targets, (N, D)
    w_ptr,  # optional per-class weights (None -> unweighted)
    inp_grad_ptr,  # output gradient buffer, (N, C, D); pre-zeroed by the host
    ignore_index,
    label_smoothing,
    mean_num,  # host-precomputed 1/denominator for 'mean' reduction (1.0 otherwise)
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward for class-index targets WITH label smoothing.

    With W = sum_c w[c] * smoothed_onehot[c],
    d loss / d inp[c] = W * softmax(inp)[c] - smoothed_onehot[c] * w[c],
    scaled by the upstream gradient and mean_num. Ignored positions keep
    their zero gradient via the store mask.
    """
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    ignore_mask = (tgt != ignore_index)[None, :]

    # Pass 1: streaming softmax denominator plus the weighted mass W of the
    # smoothed one-hot distribution, accumulated in the same sweep.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask  # 0/1 weight
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot: bulk mass at the target class, s/C elsewhere.
        smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )

        w_sum += smooth * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]  # softmax denominator
    w_sum = tl.sum(w_sum, axis=0)[None, :]  # W: weighted smoothed-target mass

    # Pass 2: write the gradient chunk by chunk.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )

        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        # Masked by ignore_mask: ignored positions keep their zero gradient.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)
@libentry()
@triton.jit
def sum_and_scale(
    inp_ptr,  # per-element losses to reduce, length N
    out_ptr,  # scalar output: sum(inp) / divisor
    N,
    scalebyw: tl.constexpr,  # True -> divisor is the sum of per-element weights
    BLOCK_N: tl.constexpr = 128,
    # When scalebyw is True, `scale` is a POINTER to per-element weights
    # (same length N) and `mean_num` is a pointer that receives their sum;
    # when False, `scale` is a scalar divisor and `mean_num` is unused.
    scale=1.0,
    mean_num=None,
):
    """Single-program reduction used for 'mean'/'sum' loss reduction."""
    mid_sum = tl.zeros(
        [
            BLOCK_N,
        ],
        dtype=tl.float32,
    )
    if scalebyw:
        mid_wgt = tl.zeros(
            [
                BLOCK_N,
            ],
            dtype=tl.float32,
        )
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
            # `scale` is a weight pointer in this branch.
            wgt_ptrs = scale + offset
            wgt_vals = tl.load(wgt_ptrs, mask=mask, other=0.0)
            mid_wgt += wgt_vals
        out_val = tl.sum(mid_sum)
        scale_val = tl.sum(mid_wgt)
        # Export the weight sum so backward can reuse it as the mean divisor.
        tl.store(mean_num, scale_val)
    else:
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
        out_val = tl.sum(mid_sum)
        scale_val = scale
    out_val /= scale_val
    tl.store(out_ptr, out_val)
class CrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper dispatching to the Triton cross-entropy kernels.

    Handles both class-index targets (with optional ignore_index and label
    smoothing) and probability targets, following the semantics of
    ``torch.nn.functional.cross_entropy``.
    """

    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        """Compute the loss.

        ``inp`` is addressed as (N, C, D): N batch rows, C classes (dim 0 for
        1-D input, dim 1 otherwise), D the remaining dims folded together.
        ``reduction`` is torch's enum: 0 = none, 1 = mean, 2 = sum.
        """
        logger.debug("GEMS CrossEntropyLoss")

        shape = list(inp.shape)
        dim = inp.ndim
        N = 1 if dim == 1 else shape[0]
        C = shape[0] if dim == 1 else shape[1]
        D = inp.numel() // N // C
        axis = 0 if dim == 1 else 1
        del shape[axis]  # shape of the unreduced loss (class dim removed)

        # Kernels assume dense row-major layout.
        inp = inp.contiguous()
        tgt = target.contiguous()
        weight = weight.contiguous() if weight is not None else None
        out = torch.empty(shape, dtype=torch.float32, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)

        # A target with the same rank as the input holds class probabilities;
        # otherwise it holds integer class indices.
        if tgt.ndim == dim:
            # target probabilities
            with torch_device_fn.device(inp.device):
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
        elif label_smoothing == 0:
            # target indices
            # w_tgt collects each position's target-class weight; needed to
            # compute the divisor for 'mean' reduction.
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    C,
                    D,
                )
        else:
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_smooth_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    label_smoothing,
                    C,
                    D,
                )

        if reduction == 1:  # MEAN
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            if tgt.ndim == dim:
                # Probability targets: plain average over all N*D positions.
                sum_and_scale[(1,)](out, out_reduce, N * D, False, scale=N * D)
            else:
                # Index targets: divide by the sum of target-class weights
                # (which excludes ignored positions); the sum lands in wgt_sum.
                wgt_sum = torch.empty([], dtype=torch.float32, device=inp.device)
                sum_and_scale[(1,)](
                    out, out_reduce, N * D, True, scale=w_tgt, mean_num=wgt_sum
                )
            out = out_reduce
        elif reduction == 2:  # SUM
            out_reduce = torch.empty([], dtype=inp.dtype, device=inp.device)
            sum_and_scale[(1,)](out, out_reduce, N * D, False)
            out = out_reduce

        if inp.requires_grad:
            ctx.save_for_backward(inp, tgt, weight)
            ctx.N = N
            ctx.C = C
            ctx.D = D
            ctx.ignore_index = ignore_index
            ctx.label_smoothing = label_smoothing
            ctx.shape = shape
            # Divisor applied in backward; 1 means no mean scaling.
            ctx.mean_num = 1
            if reduction == 1:
                # Tensor (device-side weight sum) for index targets, plain
                # int for probability targets.
                ctx.mean_num = N * D if tgt.ndim == dim else wgt_sum

        return out.to(inp.dtype)

    @staticmethod
    def backward(ctx, out_grad):
        """Propagate the gradient to the logits via the matching bwd kernel."""
        logger.debug("GEMS CrossEntropyLoss VJP")

        inp, tgt, weight = ctx.saved_tensors
        N = ctx.N
        C = ctx.C
        D = ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        # Kernels expect the reciprocal of the mean divisor. NOTE(review):
        # .item() on a device tensor forces a host-device synchronization.
        mean_num = (
            1 / ctx.mean_num.item()
            if isinstance(ctx.mean_num, torch.Tensor)
            else 1 / ctx.mean_num
        )
        shape = ctx.shape

        # For mean/sum reductions out_grad is a scalar; expand it to the
        # unreduced (N, D) loss shape the kernels read.
        out_grad = out_grad.broadcast_to(shape).contiguous()

        # zeros (not empty): ignored positions are never written by the
        # index-target kernels and must stay zero.
        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
        if tgt.ndim == inp.ndim:
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
        elif label_smoothing == 0:
            celoss_indices_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, ignore_index, mean_num, C, D
            )
        else:
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
        # One gradient per forward arg; only inp is differentiable.
        return inp_grad, None, None, None, None, None
def cross_entropy_loss(
    inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
):
    """Functional entry point mirroring ``torch.nn.functional.cross_entropy``.

    Converts the string ``reduction`` into torch's integer enum and defers
    all computation to the ``CrossEntropyLoss`` autograd function.
    """
    reduction_enum = _Reduction.get_enum(reduction)
    args = (inp, target, weight, reduction_enum, ignore_index, label_smoothing)
    return CrossEntropyLoss.apply(*args)