Coverage for src/flag_gems/runtime/backend/_cambricon/fused/cross_entropy_loss.py: 0%
507 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch.nn import _reduction as _Reduction
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
12from ..ops import sum
13from ..utils import TOTAL_CORE_NUM
15logger = logging.getLogger(__name__)
18@libentry()
19@triton.autotune(
20 configs=[
21 triton.Config({"BLOCK_C": 2**n}, num_warps=1, num_stages=3)
22 for n in range(10, 17, 2)
23 ],
24 key=["C"],
25)
26@triton.jit
27def softmax_forward_kernel(
28 inp_ptr,
29 final_max_ptr,
30 final_sum_ptr,
31 N,
32 C: tl.constexpr,
33 D: tl.constexpr,
34 BLOCK_C: tl.constexpr,
35):
36 job_id = tl.program_id(0)
37 job_num = tl.num_programs(0)
39 batch_per_job = N // job_num
40 job_remain_batch = N - batch_per_job * job_num
41 batch_per_job += 1
42 batch_begin = job_id * batch_per_job
43 if job_id >= job_remain_batch:
44 batch_per_job -= 1
45 batch_begin = job_id * batch_per_job + job_remain_batch
46 batch_end = batch_begin + batch_per_job
48 for batch_idx in range(batch_begin, batch_end):
49 pid_n = batch_idx
51 if C <= BLOCK_C:
52 offset_d = tl.arange(0, D)
53 offset_c = tl.arange(0, C)
55 inp_ptrs = (
56 inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
57 )
58 inp = tl.load(inp_ptrs).to(tl.float32)
59 final_max = tl.max(inp, axis=0)
60 final_sum = tl.sum(tl.exp(inp - final_max[None, :]), axis=0)
62 final_max_ptrs = final_max_ptr + pid_n * D + offset_d
63 final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
65 tl.store(final_max_ptrs, final_max)
66 tl.store(final_sum_ptrs, final_sum)
67 else:
68 tmp_max = tl.zeros([BLOCK_C, D], dtype=tl.float32)
69 tmp_sum = tl.zeros([BLOCK_C, D], dtype=tl.float32)
70 offset_d = tl.arange(0, D)
72 for off in range(0, C, BLOCK_C):
73 offset_c = off + tl.arange(0, BLOCK_C)
74 inp_ptrs = (
75 inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
76 )
77 inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
78 inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(
79 tl.float32
80 )
81 cur_max = tl.maximum(tmp_max, inp)
82 cur_exp = tl.exp(inp - cur_max)
83 tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
84 tmp_max = cur_max
86 final_max = tl.max(tmp_max, axis=0)
87 tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])
88 final_sum = tl.sum(tmp_sum, axis=0)
90 final_max_ptrs = final_max_ptr + pid_n * D + offset_d
91 final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
93 tl.store(final_max_ptrs, final_max)
94 tl.store(final_sum_ptrs, final_sum)
97@libentry()
98@triton.autotune(
99 configs=[
100 triton.Config({"C_TILE_NUM": num}, num_warps=1, num_stages=s)
101 for num in [4, 8, 16, 48]
102 for s in [0, 3]
103 ],
104 key=["C"],
105 restore_value=["final_max_ptr"],
106)
107@triton.jit
108def max_kernel(
109 inp_ptr,
110 final_max_ptr,
111 N,
112 C: tl.constexpr,
113 D: tl.constexpr,
114 C_TILE_NUM: tl.constexpr,
115):
116 job_id = tl.program_id(0)
117 job_num = tl.num_programs(0)
119 batch_per_job = N // job_num
120 job_remain_batch = N - batch_per_job * job_num
121 batch_per_job += 1
122 batch_begin = job_id * batch_per_job
123 if job_id >= job_remain_batch:
124 batch_per_job -= 1
125 batch_begin = job_id * batch_per_job + job_remain_batch
126 batch_end = batch_begin + batch_per_job
128 core_id = tl.program_id(1)
129 offset_d = tl.arange(0, D)
130 BLOCK_C: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM
132 for batch_idx in range(batch_begin, batch_end):
133 pid_n = batch_idx
134 offset_c = core_id * BLOCK_C + tl.arange(0, BLOCK_C)
136 inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
137 inp_mask = offset_c[:, None] < C
138 inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)
140 final_max = tl.max(inp, axis=0)
141 final_max_ptrs = final_max_ptr + pid_n * D + offset_d
142 tl.atomic_max(final_max_ptrs, final_max)
145@libentry()
146@triton.autotune(
147 configs=[
148 triton.Config({"C_TILE_NUM": num}, num_warps=1, num_stages=s)
149 for num in [4, 8, 16, 48]
150 for s in [0, 3]
151 ],
152 key=["C"],
153 reset_to_zero=["final_sum_ptr"],
154)
155@triton.jit
156def softmax_forward_with_max_kernel(
157 inp_ptr,
158 final_max_ptr,
159 final_sum_ptr,
160 N,
161 C: tl.constexpr,
162 D: tl.constexpr,
163 C_TILE_NUM: tl.constexpr,
164):
165 job_id = tl.program_id(0)
166 job_num = tl.num_programs(0)
168 batch_per_job = N // job_num
169 job_remain_batch = N - batch_per_job * job_num
170 batch_per_job += 1
171 batch_begin = job_id * batch_per_job
172 if job_id >= job_remain_batch:
173 batch_per_job -= 1
174 batch_begin = job_id * batch_per_job + job_remain_batch
175 batch_end = batch_begin + batch_per_job
177 core_id = tl.program_id(1)
178 offset_d = tl.arange(0, D)
179 BLOCK_C: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM
181 for batch_idx in range(batch_begin, batch_end):
182 pid_n = batch_idx
183 offset_c = core_id * BLOCK_C + tl.arange(0, BLOCK_C)
185 final_max_ptrs = final_max_ptr + pid_n * D + offset_d
186 final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
187 final_max = tl.load(final_max_ptrs)
189 inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
190 inp_mask = offset_c[:, None] < C
191 inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)
193 final_sum = tl.sum(tl.exp(inp - final_max[None, :]), axis=0)
194 tl.atomic_add(final_sum_ptrs, final_sum)
197@libentry()
198@triton.autotune(
199 configs=[
200 triton.Config({"BLOCK_N": 2**n}, num_warps=4, num_stages=0)
201 for n in range(4, 11, 2)
202 ],
203 key=["N"],
204)
205@triton.jit(do_not_specialize=["ignore_index"])
206def nllloss_without_weight_kernel(
207 inp_ptr,
208 tgt_ptr,
209 final_max_ptr,
210 final_sum_ptr,
211 out_ptr,
212 ignore_index,
213 N,
214 C,
215 D: tl.constexpr,
216 BLOCK_N: tl.constexpr,
217):
218 core_id = tl.program_id(0)
219 offset_n = core_id * BLOCK_N + tl.arange(0, BLOCK_N)
220 offset_d = tl.arange(0, D)
222 tgt_ptrs = tgt_ptr + offset_n * D + offset_d
223 tgt_mask = offset_n < N
224 tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
226 ignore_mask = not (tgt == ignore_index)
228 final_max_ptrs = final_max_ptr + offset_n * D + offset_d
229 final_sum_ptrs = final_sum_ptr + offset_n * D + offset_d
230 final_max = tl.load(final_max_ptrs, mask=tgt_mask, other=0)
231 final_sum = tl.load(final_sum_ptrs, mask=tgt_mask, other=1)
233 inp_tgt_ptrs = inp_ptr + offset_n * C * D + tgt * D + offset_d
234 inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)
236 loge2 = 0.693147
237 out = tl.log2(final_sum) * loge2 + final_max - inp_tgt
239 out_ptrs = out_ptr + offset_n * D + offset_d
240 tl.store(out_ptrs, out, mask=tgt_mask and ignore_mask)
243@libentry()
244@triton.heuristics(
245 values={
246 "num_warps": lambda args: 1,
247 "num_stages": lambda args: 0,
248 },
249)
250@triton.jit(do_not_specialize=["ignore_index"])
251def nllloss_with_weight_kernel(
252 inp_ptr,
253 tgt_ptr,
254 w_ptr,
255 w_tgt_ptr,
256 final_max_ptr,
257 final_sum_ptr,
258 out_ptr,
259 ignore_index,
260 N,
261 C,
262 D: tl.constexpr,
263):
264 pid_n = tl.program_id(0)
265 offset_d = tl.arange(0, D)
267 tgt_ptrs = tgt_ptr + pid_n * D + offset_d
268 tgt = tl.load(tgt_ptrs)
270 ignore_mask = not (tgt == ignore_index)
272 if w_ptr is None:
273 w_tgt = ignore_mask
274 else:
275 w_ptrs = w_ptr + tgt
276 w_tgt = tl.load(w_ptrs).to(tl.float32)
277 w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
278 tl.store(w_tgt_ptrs, w_tgt, mask=ignore_mask)
280 final_max_ptrs = final_max_ptr + pid_n * D + offset_d
281 final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
282 final_max = tl.load(final_max_ptrs)
283 final_sum = tl.load(final_sum_ptrs)
285 inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d
286 inp_tgt = tl.load(inp_tgt_ptrs).to(tl.float32)
288 loge2 = 0.693147
289 out = (tl.log2(final_sum) * loge2 + final_max - inp_tgt) * w_tgt
291 out_ptrs = out_ptr + pid_n * D + offset_d
292 tl.store(out_ptrs, out, mask=ignore_mask)
295@libentry()
296@triton.autotune(
297 configs=runtime.get_tuned_config("cross_entropy_loss"),
298 key=["C", "D"],
299)
300@triton.jit(do_not_specialize=["label_smoothing"])
301def celoss_probability_kernel(
302 inp_ptr,
303 tgt_ptr,
304 w_ptr,
305 out_ptr,
306 label_smoothing,
307 C,
308 D,
309 BLOCK_C: tl.constexpr,
310 BLOCK_D: tl.constexpr,
311):
312 pid_d = tl.program_id(0)
313 pid_n = tl.program_id(1)
314 offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
316 tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
317 tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
319 for off in range(0, C, BLOCK_C):
320 offset_c = off + tl.arange(0, BLOCK_C)
321 inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
322 inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
323 inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
324 cur_max = tl.maximum(tmp_max, inp)
325 cur_exp = tl.exp(inp - cur_max)
326 tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
327 tmp_max = cur_max
328 final_max = tl.max(tmp_max, axis=0)[None, :]
329 tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
330 final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
332 _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
333 for off in range(0, C, BLOCK_C):
334 offset_c = off + tl.arange(0, BLOCK_C)
335 inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
336 tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
337 mask = offset_c[:, None] < C and offset_d[None, :] < D
338 inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
339 tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
340 tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
341 log = final_sum + final_max - inp
342 w_mask = offset_c < C
343 if w_ptr is None:
344 w = w_mask
345 else:
346 w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
347 _sum += log * tgt * w[:, None]
349 out = tl.sum(_sum, axis=0)
350 out_ptrs = out_ptr + pid_n * D + offset_d
351 tl.store(out_ptrs, out, mask=offset_d < D)
354@libentry()
355@triton.autotune(
356 configs=runtime.get_tuned_config("cross_entropy_loss"),
357 key=["C", "D"],
358)
359@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
360def celoss_indices_smooth_kernel(
361 inp_ptr,
362 tgt_ptr,
363 w_ptr,
364 out_ptr,
365 w_tgt_ptr,
366 ignore_index,
367 label_smoothing,
368 C,
369 D,
370 BLOCK_C: tl.constexpr,
371 BLOCK_D: tl.constexpr,
372):
373 pid_d = tl.program_id(0)
374 pid_n = tl.program_id(1)
375 offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
377 tgt_ptrs = tgt_ptr + pid_n * D + offset_d
378 tgt_mask = offset_d < D
379 tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
381 ignore_mask = not (tgt == ignore_index) and tgt_mask
383 if w_ptr is None:
384 w_tgt = ignore_mask
385 else:
386 w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0)
387 w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
388 tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
390 tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
391 tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
393 for off in range(0, C, BLOCK_C):
394 offset_c = off + tl.arange(0, BLOCK_C)
395 inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
396 mask = offset_c[:, None] < C and offset_d[None, :] < D
397 inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
398 cur_max = tl.maximum(tmp_max, inp)
399 cur_exp = tl.exp(inp - cur_max)
400 tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
401 tmp_max = cur_max
402 final_max = tl.max(tmp_max, axis=0)[None, :]
403 tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
404 final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
405 final_sum_max = final_sum + final_max
407 _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
408 for off in range(0, C, BLOCK_C):
409 offset_c = off + tl.arange(0, BLOCK_C)
410 inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
411 mask = offset_c[:, None] < C and offset_d[None, :] < D
412 inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
414 w_mask = offset_c < C
415 if w_ptr is None:
416 w = w_mask
417 else:
418 w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)
420 smooth = tl.where(
421 offset_c[:, None] == tgt[None, :],
422 1 - label_smoothing + label_smoothing / C,
423 label_smoothing / C,
424 ).to(tl.float32)
426 log = final_sum_max - inp
427 _sum += log * smooth * w[:, None]
429 out = tl.sum(_sum, axis=0)
430 out = tl.where(ignore_mask, out, 0)
431 out_ptrs = out_ptr + pid_n * D + offset_d
432 tl.store(out_ptrs, out, mask=tgt_mask)
435@triton.jit
436def single_celoss_indice_bwd(
437 pid_n,
438 offset_c,
439 offset_d,
440 final_max,
441 final_sum,
442 tgt,
443 w_tgt,
444 out_grad,
445 mean_num,
446 inp_ptr,
447 inp_grad_ptr,
448 ignore_mask,
449 C,
450 D,
451):
452 inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
453 inp_mask = offset_c[:, None] < C
454 inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)
456 minus_one = offset_c[:, None] == tgt[None, :]
457 inp_grad = (
458 (tl.exp(inp - final_max[None, :]) / final_sum[None, :] - minus_one)
459 * w_tgt
460 * out_grad
461 * mean_num
462 )
463 inp_grad_ptrs = (
464 inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
465 )
466 tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)
469def config_prune(configs, named_args, **kwargs):
470 pruned_configs = []
472 for config in configs:
473 kw = config.kwargs
474 mode, num, BLOCK_C = (kw["TILE_MODE"], kw["C_TILE_NUM"], kw["BLOCK_C"])
475 if (mode == 0 and num == 1) or (mode == 1 and num >= 4 and BLOCK_C <= 1024):
476 pruned_configs.append(config)
477 return pruned_configs
480@libentry()
481@triton.autotune(
482 configs=[
483 triton.Config(
484 {
485 "TILE_MODE": mode,
486 "C_TILE_NUM": num,
487 "BLOCK_C": 2**n,
488 },
489 num_warps=1,
490 num_stages=s,
491 )
492 for mode in [0, 1]
493 for num in [1, 4, 8, 16, 48]
494 for n in range(10, 17, 2)
495 for s in [0, 3]
496 ],
497 key=["C"],
498 prune_configs_by={
499 "early_config_prune": config_prune,
500 },
501)
502@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
503def celoss_indice_bwd_with_saved_sum_kernel(
504 out_grad_ptr,
505 inp_ptr,
506 tgt_ptr,
507 w_ptr,
508 inp_grad_ptr,
509 final_max_ptr,
510 final_sum_ptr,
511 ignore_index,
512 mean_num,
513 N,
514 C: tl.constexpr,
515 D: tl.constexpr,
516 is_has_weight: tl.constexpr,
517 is_has_ignore_index: tl.constexpr,
518 is_tgt_in_i32: tl.constexpr,
519 TILE_MODE: tl.constexpr,
520 C_TILE_NUM: tl.constexpr,
521 BLOCK_C: tl.constexpr,
522):
523 job_id = tl.program_id(0)
524 job_num = tl.num_programs(0)
526 batch_per_job = N // job_num
527 job_remain_batch = N - batch_per_job * job_num
528 batch_per_job += 1
529 batch_begin = job_id * batch_per_job
530 if job_id >= job_remain_batch:
531 batch_per_job -= 1
532 batch_begin = job_id * batch_per_job + job_remain_batch
533 batch_end = batch_begin + batch_per_job
535 for batch_idx in range(batch_begin, batch_end):
536 pid_n = batch_idx
537 offset_d = tl.arange(0, D)
539 tgt_ptrs = tgt_ptr + pid_n * D + offset_d
540 if is_tgt_in_i32:
541 tgt = tl.load(tgt_ptrs).to(tl.int32)
542 else:
543 tgt = tl.load(tgt_ptrs)
545 out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
546 out_grad = tl.load(out_grad_ptrs).to(tl.float32)[None, :]
548 if is_has_weight:
549 w_ptrs = w_ptr + tgt
550 w_tgt = tl.load(w_ptrs).to(tl.float32)[None, :]
551 else:
552 w_tgt = 1
554 if is_has_ignore_index:
555 ignore_mask = (tgt != ignore_index)[None, :]
556 else:
557 ignore_mask = True
559 final_max_ptrs = final_max_ptr + pid_n * D + offset_d
560 final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
561 final_max = tl.load(final_max_ptrs)
562 final_sum = tl.load(final_sum_ptrs)
564 if TILE_MODE == 0:
565 if C <= BLOCK_C:
566 offset_c = tl.arange(0, C)
567 single_celoss_indice_bwd(
568 pid_n,
569 offset_c,
570 offset_d,
571 final_max,
572 final_sum,
573 tgt,
574 w_tgt,
575 out_grad,
576 mean_num,
577 inp_ptr,
578 inp_grad_ptr,
579 ignore_mask,
580 C,
581 D,
582 )
583 else:
584 for off in range(0, C, BLOCK_C):
585 offset_c = off + tl.arange(0, BLOCK_C)
586 single_celoss_indice_bwd(
587 pid_n,
588 offset_c,
589 offset_d,
590 final_max,
591 final_sum,
592 tgt,
593 w_tgt,
594 out_grad,
595 mean_num,
596 inp_ptr,
597 inp_grad_ptr,
598 ignore_mask,
599 C,
600 D,
601 )
602 else:
603 core_id = tl.program_id(1)
604 C_TILE_SIZE: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM
605 offset_c = core_id * C_TILE_SIZE + tl.arange(0, C_TILE_SIZE)
607 single_celoss_indice_bwd(
608 pid_n,
609 offset_c,
610 offset_d,
611 final_max,
612 final_sum,
613 tgt,
614 w_tgt,
615 out_grad,
616 mean_num,
617 inp_ptr,
618 inp_grad_ptr,
619 ignore_mask,
620 C,
621 D,
622 )
625@libentry()
626@triton.autotune(
627 configs=runtime.get_tuned_config("cross_entropy_loss"),
628 key=["C", "D"],
629)
630@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
631def celoss_probability_bwd(
632 out_grad_ptr,
633 inp_ptr,
634 tgt_ptr,
635 w_ptr,
636 inp_grad_ptr,
637 label_smoothing,
638 mean_num,
639 C,
640 D,
641 BLOCK_C: tl.constexpr,
642 BLOCK_D: tl.constexpr,
643):
644 pid_d = tl.program_id(0)
645 pid_n = tl.program_id(1)
646 offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
648 out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
649 out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[
650 None, :
651 ]
653 tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
654 tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
655 w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
657 for off in range(0, C, BLOCK_C):
658 offset_c = off + tl.arange(0, BLOCK_C)
659 mask = offset_c[:, None] < C and offset_d[None, :] < D
660 inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
661 inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
663 tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
664 tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
665 tgt = tgt * (1 - label_smoothing) + label_smoothing / C
667 w_mask = offset_c < C
668 if w_ptr is None:
669 w = w_mask
670 else:
671 w_ptrs = w_ptr + offset_c
672 w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
674 w_tgt_sum += tgt * w[:, None]
676 cur_max = tl.maximum(tmp_max, inp)
677 cur_exp = tl.exp(inp - cur_max)
678 tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
679 tmp_max = cur_max
680 final_max = tl.max(tmp_max, axis=0)[None, :]
681 tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
682 final_sum = tl.sum(tmp_sum, axis=0)[None, :]
683 w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]
685 for off in range(0, C, BLOCK_C):
686 offset_c = off + tl.arange(0, BLOCK_C)
687 offset = pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
688 inp_ptrs = inp_ptr + offset
689 mask = offset_c[:, None] < C and offset_d[None, :] < D
690 inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
692 tgt_ptrs = tgt_ptr + offset
693 tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
694 tgt = tgt * (1 - label_smoothing) + label_smoothing / C
696 w_mask = offset_c < C
697 if w_ptr is None:
698 w = w_mask
699 else:
700 w_ptrs = w_ptr + offset_c
701 w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
703 grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[:, None]
704 inp_grad = grad * out_grad * mean_num
706 inp_grad_ptrs = inp_grad_ptr + offset
707 tl.store(inp_grad_ptrs, inp_grad, mask)
710@libentry()
711@triton.autotune(
712 configs=runtime.get_tuned_config("cross_entropy_loss"),
713 key=["C", "D"],
714)
715@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
716def celoss_indices_smooth_bwd(
717 out_grad_ptr,
718 inp_ptr,
719 tgt_ptr,
720 w_ptr,
721 inp_grad_ptr,
722 ignore_index,
723 label_smoothing,
724 mean_num,
725 C,
726 D,
727 BLOCK_C: tl.constexpr,
728 BLOCK_D: tl.constexpr,
729):
730 pid_d = tl.program_id(0)
731 pid_n = tl.program_id(1)
732 offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
734 tgt_ptrs = tgt_ptr + pid_n * D + offset_d
735 tgt_mask = offset_d < D
736 tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
737 out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
738 out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]
740 ignore_mask = (tgt != ignore_index)[None, :]
742 tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
743 tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
744 w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
746 for off in range(0, C, BLOCK_C):
747 offset_c = off + tl.arange(0, BLOCK_C)
748 inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
749 inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
750 inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
752 w_mask = offset_c < C
753 if w_ptr is None:
754 w = w_mask
755 else:
756 w_ptrs = w_ptr + offset_c
757 w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
759 smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)
760 smooth = tl.where(
761 offset_c[:, None] == tgt[None, :],
762 1 - label_smoothing + label_smoothing / C,
763 smooth,
764 )
766 w_sum += smooth * w[:, None]
768 cur_max = tl.maximum(tmp_max, inp)
769 cur_exp = tl.exp(inp - cur_max)
770 tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
771 tmp_max = cur_max
772 final_max = tl.max(tmp_max, axis=0)[None, :]
773 tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
774 final_sum = tl.sum(tmp_sum, axis=0)[None, :]
775 w_sum = tl.sum(w_sum, axis=0)[None, :]
777 for off in range(0, C, BLOCK_C):
778 offset_c = off + tl.arange(0, BLOCK_C)
779 inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
780 inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
781 inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
783 w_mask = offset_c < C
784 if w_ptr is None:
785 w = w_mask
786 else:
787 w_ptrs = w_ptr + offset_c
788 w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
790 smooth = tl.where(
791 offset_c[:, None] == tgt[None, :],
792 1 - label_smoothing + label_smoothing / C,
793 label_smoothing / C,
794 )
796 grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]
797 inp_grad = grad * out_grad * mean_num
798 inp_grad_ptrs = (
799 inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
800 )
801 tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)
804class CrossEntropyLoss(torch.autograd.Function):
805 @staticmethod
806 def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
807 logger.debug("GEMS_CAMBRICON CrossEntropyLoss")
809 shape = list(inp.shape)
810 dim = inp.ndim
811 N = 1 if dim == 1 else shape[0]
812 C = shape[0] if dim == 1 else shape[1]
813 D = inp.numel() // N // C
814 axis = 0 if dim == 1 else 1
815 del shape[axis]
817 inp = inp.contiguous()
818 tgt = target.contiguous()
820 ctx.N = N
821 ctx.C = C
822 ctx.D = D
823 ctx.ignore_index = ignore_index
824 ctx.label_smoothing = label_smoothing
825 ctx.shape = shape
827 final_max = None
828 final_sum = None
830 mean_num = 1
831 if reduction == 1 and tgt.ndim == dim:
832 mean_num = 1 / (N * D)
833 out = torch.empty(shape, dtype=torch.float32, device=inp.device)
835 def get_result(inp, tgt, out, reduction, mean_num):
836 if reduction == 0: # NONE
837 return out.to(inp.dtype)
838 elif reduction == 1: # MEAN
839 return (sum(out) * mean_num).to(inp.dtype)
840 else: # SUM
841 return sum(out).to(inp.dtype)
843 if weight is None and tgt.ndim != dim and label_smoothing == 0:
844 final_max = torch.full(
845 shape,
846 torch.finfo(torch.float32).min,
847 dtype=torch.float32,
848 device=inp.device,
849 )
850 final_sum = torch.zeros(shape, dtype=torch.float32, device=inp.device)
851 with torch.mlu.device(inp.device):
852 if C <= (32 * 1000) or C > (2048 * 1000):
853 softmax_forward_kernel[(TOTAL_CORE_NUM,)](
854 inp, final_max, final_sum, N, C, D
855 )
856 else:
857 grid = lambda meta: (
858 triton.cdiv(TOTAL_CORE_NUM, meta["C_TILE_NUM"]),
859 meta["C_TILE_NUM"],
860 )
861 max_kernel[grid](inp, final_max, N, C, D)
862 softmax_forward_with_max_kernel[grid](
863 inp, final_max, final_sum, N, C, D
864 )
866 grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
867 nllloss_without_weight_kernel[grid](
868 inp, tgt, final_max, final_sum, out, ignore_index, N, C, D
869 )
870 if reduction == 1:
871 if ignore_index < 0 or ignore_index >= C:
872 mean_num = 1 / C
873 else:
874 mean_num = 1 / (C - 1)
875 ctx.mean_num = mean_num
877 ctx.save_for_backward(inp, tgt, weight, final_max, final_sum)
878 return get_result(inp, tgt, out, reduction, mean_num)
880 weight = weight.contiguous() if weight is not None else None
881 grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
883 if tgt.ndim == dim:
884 # target probabilities
885 with torch_device_fn.device(inp.device):
886 celoss_probability_kernel[grid](
887 inp,
888 tgt,
889 weight,
890 out,
891 label_smoothing,
892 C,
893 D,
894 )
895 elif label_smoothing == 0:
896 # target indices
897 w_tgt = torch.zeros(shape, dtype=torch.float32, device=inp.device)
898 final_max = torch.empty(shape, dtype=torch.float32, device=inp.device)
899 final_sum = torch.empty(shape, dtype=torch.float32, device=inp.device)
900 with torch_device_fn.device(inp.device):
901 softmax_forward_kernel[(TOTAL_CORE_NUM,)](
902 inp, final_max, final_sum, N, C, D
903 )
904 nllloss_with_weight_kernel[(N,)](
905 inp,
906 tgt,
907 weight,
908 w_tgt,
909 final_max,
910 final_sum,
911 out,
912 ignore_index,
913 N,
914 C,
915 D,
916 )
917 else:
918 w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
919 with torch_device_fn.device(inp.device):
920 celoss_indices_smooth_kernel[grid](
921 inp,
922 tgt,
923 weight,
924 out,
925 w_tgt,
926 ignore_index,
927 label_smoothing,
928 C,
929 D,
930 )
931 ctx.save_for_backward(inp, tgt, weight, final_max, final_sum)
932 ctx.mean_num = 1
934 if reduction == 1 and tgt.ndim != dim:
935 mean_num = 1 / sum(w_tgt).item()
936 ctx.mean_num = mean_num
937 return get_result(inp, tgt, out, reduction, mean_num)
939 @staticmethod
940 def backward(ctx, out_grad):
941 logger.debug("GEMS_CAMBRICON CrossEntropyLoss VJP")
943 inp, tgt, weight, final_max, final_sum = ctx.saved_tensors
944 N = ctx.N
945 C = ctx.C
946 D = ctx.D
947 ignore_index = ctx.ignore_index
948 label_smoothing = ctx.label_smoothing
949 mean_num = ctx.mean_num
950 shape = ctx.shape
952 out_grad = out_grad.broadcast_to(shape).contiguous()
954 inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
955 grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
956 if tgt.ndim == inp.ndim:
957 celoss_probability_bwd[grid](
958 out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
959 )
960 elif label_smoothing == 0:
961 if final_sum is not None:
962 is_has_weight = weight is not None
963 is_has_ignore_index = ignore_index >= 0 and ignore_index < C
964 is_tgt_in_i32 = C < (1 << 31)
965 grid = lambda meta: (
966 triton.cdiv(TOTAL_CORE_NUM, meta["C_TILE_NUM"]),
967 meta["C_TILE_NUM"],
968 )
969 celoss_indice_bwd_with_saved_sum_kernel[grid](
970 out_grad,
971 inp,
972 tgt,
973 weight,
974 inp_grad,
975 final_max,
976 final_sum,
977 ignore_index,
978 mean_num,
979 N,
980 C,
981 D,
982 is_has_weight,
983 is_has_ignore_index,
984 is_tgt_in_i32,
985 )
986 else:
987 celoss_indices_smooth_bwd[grid](
988 out_grad,
989 inp,
990 tgt,
991 weight,
992 inp_grad,
993 ignore_index,
994 label_smoothing,
995 mean_num,
996 C,
997 D,
998 )
999 return inp_grad, None, None, None, None, None
1002def cross_entropy_loss(
1003 inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
1004):
1005 return CrossEntropyLoss.apply(
1006 inp,
1007 target,
1008 weight,
1009 _Reduction.get_enum(reduction),
1010 ignore_index,
1011 label_smoothing,
1012 )