Coverage for src/flag_gems/runtime/backend/_mthreads/fused/cross_entropy_loss.py: 0%
376 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch.nn import _reduction as _Reduction
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger(__name__)
# Forward kernel: cross entropy with *class-index* targets, no label smoothing.
# Grid is (cdiv(D, BLOCK_D), N): each program owns one batch row `pid_n` and a
# BLOCK_D-wide tile of the trailing dimension, reducing over all C classes in
# BLOCK_C-sized chunks via an online (streaming) softmax.
#
# Writes two [N, D] fp32 outputs:
#   out_ptr   -- per-element loss  w[tgt] * (logsumexp(inp) - inp[tgt])
#   w_tgt_ptr -- per-element target weight (0 where tgt == ignore_index);
#                consumed later by sum_and_scale for the "mean" reduction.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index"])
def celoss_indices_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    w_tgt_ptr,
    ignore_index,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    # 64-bit offsets so the address arithmetic cannot wrap on large tensors.
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    # Load the target class index for each element of this D-tile.
    tgt_ptrs = tgt_ptr + (pid_n * D).to(tl.int64) + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # Elementwise: true for in-bounds lanes whose target is not ignored.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    # Online-softmax accumulators. The running max is seeded at 0 rather than
    # -inf; this is still exact because log(sum exp(x - m)) + m is independent
    # of m, it only loosens the numerical bound.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    # Streaming pass over the class dimension: maintain per-lane running max
    # and rescaled running sum of exponentials.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        # Out-of-range classes load as -inf so exp() contributes exactly 0.
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        # Rescale the old partial sum to the new max before accumulating.
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    # Collapse the BLOCK_C partial accumulators: logsumexp = final_sum + final_max.
    final_max = tl.max(tmp_max, axis=0)
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))

    # Gather the logit of the target class for each element.
    inp_tgt_ptrs = inp_ptr + (pid_n * C * D).to(tl.int64) + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)

    # Unweighted NLL: logsumexp(inp) - inp[tgt].
    out = final_sum + final_max - inp_tgt
    w_tgt_ptrs = w_tgt_ptr + (pid_n * D).to(tl.int64) + offset_d

    if w_ptr is None:
        # No class weights: use the 0/1 ignore mask itself as the weight.
        w_tgt = ignore_mask
    else:
        # Weight of the target class; masked to 0 for ignored elements.
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
    # Ignored elements end up with loss 0 because w_tgt is 0 there.
    out *= w_tgt
    out_ptrs = out_ptr + (pid_n * D).to(tl.int64) + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)
# Forward kernel: cross entropy with *probability* targets (tgt has the same
# shape as inp). Per element d the loss is
#   sum_c w_c * ((1 - s) * tgt_c + s / C) * (logsumexp(inp) - inp_c)
# with s = label_smoothing. Two passes over the class dimension: pass 1 runs
# an online softmax for the logsumexp; pass 2 accumulates the weighted sum.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    # 64-bit offsets guard against wrap-around on large tensors.
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    # Pass 1: online softmax accumulators (running max seeded at 0 — exact,
    # since logsumexp correction is independent of the chosen max).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        # Out-of-range lanes read -inf so exp() contributes 0.
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    # logsumexp(inp) = final_sum + final_max, broadcast over the class axis.
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]

    # Pass 2: accumulate  (logsumexp - inp_c) * smoothed_tgt_c * w_c.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        tgt_ptrs = tgt_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Label smoothing blends the target distribution with uniform 1/C.
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        log = final_sum + final_max - inp
        w_mask = offset_c < C
        if w_ptr is None:
            # Unweighted: the 0/1 class-validity mask doubles as the weight,
            # zeroing contributions from out-of-range class lanes.
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[:, None]

    # Reduce over classes and store one loss value per element of the D-tile.
    out = tl.sum(_sum, axis=0)
    out_ptrs = out_ptr + (pid_n * D).to(tl.int64) + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)
# Forward kernel: cross entropy with *class-index* targets AND label smoothing.
# The smoothed target distribution per element is
#   smooth_c = 1 - s + s/C   if c == tgt   else   s/C        (s = label_smoothing)
# and the loss is  sum_c w_c * smooth_c * (logsumexp(inp) - inp_c), zeroed for
# elements whose target equals ignore_index. Also writes the per-element
# target weight (w_tgt_ptr) used by the "mean" reduction.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    w_tgt_ptr,
    ignore_index,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    # 64-bit offsets guard against wrap-around on large tensors.
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    # Target class index per element of this D-tile.
    tgt_ptrs = tgt_ptr + (pid_n * D).to(tl.int64) + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)

    # Elementwise: in-bounds and not the ignored class.
    ignore_mask = not (tgt == ignore_index) and tgt_mask

    if w_ptr is None:
        # Unweighted: the 0/1 ignore mask itself serves as the weight.
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0)
    w_tgt_ptrs = w_tgt_ptr + (pid_n * D).to(tl.int64) + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)

    # Pass 1: online softmax over the class dimension (running max seeded at
    # 0 — still exact, the logsumexp correction is independent of the max).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        # -inf padding makes masked lanes contribute 0 to the exp-sum.
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    # logsumexp(inp), broadcast over the class axis.
    final_sum_max = final_sum + final_max

    # Pass 2: accumulate the smoothed, weighted negative log-likelihood.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            # 0/1 validity mask doubles as the weight, zeroing padded lanes.
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot: the target class gets 1 - s + s/C, others s/C.
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)

        log = final_sum_max - inp
        _sum += log * smooth * w[:, None]

    out = tl.sum(_sum, axis=0)
    # Zero the loss for elements whose target is ignore_index.
    out = tl.where(ignore_mask, out, 0)
    out_ptrs = out_ptr + (pid_n * D).to(tl.int64) + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)
# Backward kernel for index targets without label smoothing. Per class c:
#   d loss / d inp_c = (softmax(inp)_c - [c == tgt]) * w[tgt] * out_grad * mean_num
# The store is masked by (tgt != ignore_index); inp_grad was zero-initialized
# by the host, so ignored elements keep a zero gradient.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indices_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    ignore_index,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    # 64-bit offsets guard against wrap-around on large tensors.
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    # Target index and upstream gradient for each element of the D-tile.
    tgt_ptrs = tgt_ptr + (pid_n * D).to(tl.int64) + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + (pid_n * D).to(tl.int64) + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    if w_ptr is None:
        # Unweighted: 0/1 validity mask acts as the target weight.
        w_tgt = tgt_mask
    else:
        w_ptrs = w_ptr + tgt
        w_tgt = tl.load(w_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    ignore_mask = (tgt != ignore_index)[None, :]

    # Pass 1: online softmax statistics (running max seeded at 0 — exact, the
    # normalization below is independent of the chosen max).
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    # Note: final_sum is the plain exp-sum here (no log), i.e. the softmax
    # denominator at shift final_max.
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]

    # Pass 2: emit (softmax - onehot) * w[tgt] * out_grad * mean_num.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        # One-hot indicator of the target class.
        minus_one = offset_c[:, None] == tgt[None, :]
        inp_grad = (
            (tl.exp(inp - final_max) / final_sum - minus_one)
            * w_tgt
            * out_grad
            * mean_num
        )
        inp_grad_ptrs = inp_grad_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        # Ignored elements are skipped; inp_grad stays at its host-side zero.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)
# Backward kernel for *probability* targets. With the smoothed target
# distribution t_c = (1 - s) * tgt_c + s/C and weights w_c, the gradient is
#   d loss / d inp_c = ((sum_k w_k t_k) * softmax(inp)_c - t_c * w_c)
#                      * out_grad * mean_num
# Pass 1 accumulates both the softmax statistics and sum_k w_k t_k; pass 2
# re-reads inp/tgt and writes the gradient.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    label_smoothing,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    # 64-bit offsets guard against wrap-around on large tensors.
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    # Upstream gradient for each element of this D-tile.
    out_grad_ptrs = out_grad_ptr + (pid_n * D).to(tl.int64) + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[
        None, :
    ]

    # Accumulators: online softmax (max seeded at 0 — exact) plus the
    # weighted-target sum  sum_k w_k * t_k.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)

        tgt_ptrs = tgt_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Apply label smoothing to the target distribution.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            # Unweighted: 0/1 validity mask zeroes padded class lanes.
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        w_tgt_sum += tgt * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    # final_sum is the softmax denominator at shift final_max (no log).
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]

    # Pass 2: recompute per-class terms and store the gradient.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_ptrs = inp_ptr + offset
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)

        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # (sum_k w_k t_k) * softmax_c  -  t_c * w_c
        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[:, None]
        inp_grad = grad * out_grad * mean_num

        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)
# Backward kernel for index targets WITH label smoothing. With the smoothed
# one-hot  smooth_c = 1 - s + s/C  (c == tgt)  else  s/C, the gradient is
#   d loss / d inp_c = ((sum_k w_k smooth_k) * softmax(inp)_c - smooth_c * w_c)
#                      * out_grad * mean_num
# stored only where tgt != ignore_index (inp_grad is zero-initialized on host).
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    ignore_index,
    label_smoothing,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_d = tle.program_id(0)
    pid_n = tle.program_id(1)
    # 64-bit offsets guard against wrap-around on large tensors.
    offset_d = (pid_d * BLOCK_D).to(tl.int64) + tl.arange(0, BLOCK_D)

    # Target index and upstream gradient per element of the D-tile.
    tgt_ptrs = tgt_ptr + (pid_n * D).to(tl.int64) + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + (pid_n * D).to(tl.int64) + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]

    ignore_mask = (tgt != ignore_index)[None, :]

    # Pass 1: online softmax statistics (max seeded at 0 — exact) plus the
    # weighted smoothed-target sum  sum_k w_k * smooth_k.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)

    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            # Unweighted: 0/1 validity mask zeroes padded class lanes.
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        # Smoothed one-hot: target class 1 - s + s/C, other classes s/C.
        smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )

        w_sum += smooth * w[:, None]

        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    # final_sum is the softmax denominator at shift final_max (no log).
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_sum = tl.sum(w_sum, axis=0)[None, :]

    # Pass 2: recompute per-class terms and store the gradient.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)

        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)

        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )

        # (sum_k w_k smooth_k) * softmax_c  -  smooth_c * w_c
        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad_ptrs = inp_grad_ptr + (
            (pid_n * C * D).to(tl.int64)
            + (offset_c[:, None] * D).to(tl.int64)
            + offset_d[None, :]
        ).to(tl.int64)
        # Ignored elements keep their host-side zero gradient.
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)
# Single-program reduction kernel (launched with grid (1,)): sums the N
# elements at inp_ptr and divides by a scale before storing one scalar.
#
# The `scale` argument is dual-purpose, selected by the constexpr `scalebyw`:
#   scalebyw == False: `scale` is a plain scalar divisor (default 1.0), so
#       the kernel computes sum(inp) / scale  (SUM reduction uses scale=1,
#       MEAN over probability targets uses scale=N*D).
#   scalebyw == True:  `scale` is a *pointer* to N per-element weights (the
#       w_tgt buffer), the divisor is their sum, and that weight sum is also
#       written to `mean_num` for reuse in the backward pass. If all targets
#       are ignored the weight sum is 0 and the division yields inf/nan.
@libentry()
@triton.jit
def sum_and_scale(
    inp_ptr,
    out_ptr,
    N,
    scalebyw: tl.constexpr,
    BLOCK_N: tl.constexpr = 128,
    scale=1.0,
    mean_num=None,
):
    mid_sum = tl.zeros(
        [
            BLOCK_N,
        ],
        dtype=tl.float32,
    )
    if scalebyw:
        mid_wgt = tl.zeros(
            [
                BLOCK_N,
            ],
            dtype=tl.float32,
        )
        # Accumulate losses and weights side by side in BLOCK_N chunks.
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
            # `scale` is the weight pointer in this branch.
            wgt_ptrs = scale + offset
            wgt_vals = tl.load(wgt_ptrs, mask=mask, other=0.0)
            mid_wgt += wgt_vals
        out_val = tl.sum(mid_sum)
        scale_val = tl.sum(mid_wgt)
        # Publish the weight sum for the backward pass's mean divisor.
        tl.store(mean_num, scale_val)
    else:
        for off in range(0, N, BLOCK_N):
            offset = off + tl.arange(0, BLOCK_N)
            inp_ptrs = inp_ptr + offset
            mask = offset < N
            inp_vals = tl.load(inp_ptrs, mask=mask, other=0.0)
            mid_sum += inp_vals
        out_val = tl.sum(mid_sum)
        # `scale` is a scalar divisor in this branch.
        scale_val = scale
    out_val /= scale_val
    tl.store(out_ptr, out_val)
class CrossEntropyLoss(torch.autograd.Function):
    """Cross-entropy loss backed by the fused Triton kernels above.

    ``forward`` dispatches on the target kind (class indices vs. a
    probability distribution, detected by ``target.ndim == inp.ndim``) and
    on ``label_smoothing``, then optionally performs the mean/sum reduction
    on device. ``backward`` replays the matching gradient kernel.
    """

    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        logger.debug("GEMS CrossEntropyLoss")

        # Layout convention: N batch rows, C classes, D trailing elements.
        # A 1-D input is a single sample whose only axis is the class axis.
        ndim = inp.ndim
        sizes = list(inp.shape)
        if ndim == 1:
            N, C, class_axis = 1, sizes[0], 0
        else:
            N, C, class_axis = sizes[0], sizes[1], 1
        D = inp.numel() // N // C
        # Output shape = input shape with the class axis removed.
        out_shape = sizes[:class_axis] + sizes[class_axis + 1 :]

        inp = inp.contiguous()
        tgt = target.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        # Per-element losses accumulate in fp32 regardless of input dtype.
        out = torch.empty(out_shape, dtype=torch.float32, device=inp.device)

        def grid(meta):
            return (triton.cdiv(D, meta["BLOCK_D"]), N)

        soft_targets = tgt.ndim == ndim
        if soft_targets:
            # Targets are class probabilities, same shape as the input.
            with torch_device_fn.device(inp.device):
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
        else:
            # Targets are class indices; the kernels also emit the
            # per-element target weights needed by the mean reduction.
            w_tgt = torch.empty(out_shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                if label_smoothing == 0:
                    celoss_indices_kernel[grid](
                        inp,
                        tgt,
                        weight,
                        out,
                        w_tgt,
                        ignore_index,
                        C,
                        D,
                    )
                else:
                    celoss_indices_smooth_kernel[grid](
                        inp,
                        tgt,
                        weight,
                        out,
                        w_tgt,
                        ignore_index,
                        label_smoothing,
                        C,
                        D,
                    )

        if reduction == 1:  # MEAN
            reduced = torch.empty([], dtype=inp.dtype, device=inp.device)
            if soft_targets:
                # Probability targets: divide by the element count.
                sum_and_scale[(1,)](out, reduced, N * D, False, scale=N * D)
            else:
                # Index targets: divide by the sum of target weights and
                # keep that sum around for the backward pass.
                wgt_sum = torch.empty([], dtype=torch.float32, device=inp.device)
                sum_and_scale[(1,)](
                    out, reduced, N * D, True, scale=w_tgt, mean_num=wgt_sum
                )
            out = reduced
        elif reduction == 2:  # SUM
            reduced = torch.empty([], dtype=inp.dtype, device=inp.device)
            sum_and_scale[(1,)](out, reduced, N * D, False)
            out = reduced

        if inp.requires_grad:
            ctx.save_for_backward(inp, tgt, weight)
            ctx.N = N
            ctx.C = C
            ctx.D = D
            ctx.ignore_index = ignore_index
            ctx.label_smoothing = label_smoothing
            ctx.shape = out_shape
            # Divisor applied in backward: 1 for none/sum reductions,
            # element count or weight sum for mean.
            ctx.mean_num = 1
            if reduction == 1:
                ctx.mean_num = N * D if soft_targets else wgt_sum

        return out.to(inp.dtype)

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS CrossEntropyLoss VJP")

        inp, tgt, weight = ctx.saved_tensors
        N, C, D = ctx.N, ctx.C, ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        # Fold the mean divisor into a scalar multiplier for the kernels.
        if isinstance(ctx.mean_num, torch.Tensor):
            mean_num = 1 / ctx.mean_num.item()
        else:
            mean_num = 1 / ctx.mean_num

        # A reduced (scalar) upstream gradient is broadcast back to [N, D].
        out_grad = out_grad.broadcast_to(ctx.shape).contiguous()
        # Zero-init so the backward kernels may skip ignored elements.
        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)

        def grid(meta):
            return (triton.cdiv(D, meta["BLOCK_D"]), N)

        if tgt.ndim == inp.ndim:
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
        elif label_smoothing == 0:
            celoss_indices_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, ignore_index, mean_num, C, D
            )
        else:
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
        return inp_grad, None, None, None, None, None
def cross_entropy_loss(
    inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
):
    """Functional entry point mirroring ``torch.nn.functional.cross_entropy``.

    Translates the string ``reduction`` into torch's integer enum and
    delegates to the autograd-aware :class:`CrossEntropyLoss` implementation.
    """
    reduction_enum = _Reduction.get_enum(reduction)
    return CrossEntropyLoss.apply(
        inp,
        target,
        weight,
        reduction_enum,
        ignore_index,
        label_smoothing,
    )