Coverage for src/flag_gems/runtime/backend/_tsingmicro/fused/cross_entropy_loss.py: 0%
506 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch.nn import _reduction as _Reduction
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
# Module-level logger following the package convention (one logger per module).
logger = logging.getLogger(__name__)
# Total number of compute cores reported by the device; used to size 1-D launch
# grids so work is spread across all cores.
TOTAL_CORE_NUM = torch_device_fn.get_device_properties().multi_processor_count
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_C": 2**n}, num_warps=1, num_stages=3)
        for n in range(10, 17, 2)
    ],
    key=["C"],
)
@triton.jit
def softmax_forward_kernel(
    inp_ptr,
    final_max_ptr,
    final_sum_ptr,
    N,
    C: tl.constexpr,
    D: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    """Compute per-(n, d) softmax statistics over the class axis C.

    Viewing the input as (N, C, D), for every (n, d) this writes:
      final_max[n, d] = reduction max over c of inp[n, c, d]
      final_sum[n, d] = sum over c of exp(inp[n, c, d] - final_max[n, d])
    so that logsumexp(n, d) = log(final_sum) + final_max.

    NOTE(review): in the tiled branch the running max starts at zeros, so the
    stored "max" is effectively max(0, true max); the (max, sum) pair is still
    self-consistent for logsumexp, just less tight for all-negative inputs.
    """
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)
    # Evenly split the N batches over the launched programs; the first
    # `job_remain_batch` programs take one extra batch each.
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job
    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx
        # Fast path: the whole class dimension fits in a single tile.
        if C <= BLOCK_C:
            offset_d = tl.arange(0, D)
            offset_c = tl.arange(0, C)
            inp_ptrs = (
                inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
            )
            inp = tl.load(inp_ptrs).to(tl.float32)
            final_max = tl.max(inp, axis=0)
            final_sum = tl.sum(tl.exp(inp - final_max[None, :]), axis=0)
            final_max_ptrs = final_max_ptr + pid_n * D + offset_d
            final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
            tl.store(final_max_ptrs, final_max)
            tl.store(final_sum_ptrs, final_sum)
        else:
            # Online (streaming) softmax over C in BLOCK_C-sized tiles:
            # rescale the running sum whenever the running max increases.
            tmp_max = tl.zeros([BLOCK_C, D], dtype=tl.float32)
            tmp_sum = tl.zeros([BLOCK_C, D], dtype=tl.float32)
            offset_d = tl.arange(0, D)
            for off in range(0, C, BLOCK_C):
                offset_c = off + tl.arange(0, BLOCK_C)
                inp_ptrs = (
                    inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
                )
                # Elementwise mask; out-of-range classes load -inf so they
                # contribute exp(-inf) = 0 to the sum.
                inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
                inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(
                    tl.float32
                )
                cur_max = tl.maximum(tmp_max, inp)
                cur_exp = tl.exp(inp - cur_max)
                tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
                tmp_max = cur_max
            # Collapse the per-lane partial results down to a single row.
            final_max = tl.max(tmp_max, axis=0)
            tmp_sum = tmp_sum * tl.exp(tmp_max - final_max[None, :])
            final_sum = tl.sum(tmp_sum, axis=0)
            final_max_ptrs = final_max_ptr + pid_n * D + offset_d
            final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
            tl.store(final_max_ptrs, final_max)
            tl.store(final_sum_ptrs, final_sum)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"C_TILE_NUM": num}, num_warps=1, num_stages=s)
        for num in [4, 8, 16, 48]
        for s in [0, 3]
    ],
    key=["C"],
    # final_max is updated in place via atomics, so autotune must restore it
    # between timing runs of different configs.
    restore_value=["final_max_ptr"],
)
@triton.jit
def max_kernel(
    inp_ptr,
    final_max_ptr,
    N,
    C: tl.constexpr,
    D: tl.constexpr,
    C_TILE_NUM: tl.constexpr,
):
    """First pass of the split softmax for very large C: per-(n, d) max.

    Grid axis 0 splits the N batches across programs; grid axis 1 splits the
    class axis into C_TILE_NUM tiles. Each program reduces its tile and folds
    the partial max into final_max with tl.atomic_max, so the caller must
    pre-fill final_max with a very small value (the forward pass uses
    float32 finfo.min).
    """
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)
    # Evenly split the N batches over the programs on axis 0; the first
    # `job_remain_batch` programs take one extra batch each.
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job
    core_id = tl.program_id(1)
    offset_d = tl.arange(0, D)
    # Size of the class tile owned by each program on axis 1 (ceil division).
    BLOCK_C: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM
    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx
        offset_c = core_id * BLOCK_C + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C
        # Out-of-range classes read -inf so they never win the max.
        inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)
        final_max = tl.max(inp, axis=0)
        final_max_ptrs = final_max_ptr + pid_n * D + offset_d
        tl.atomic_max(final_max_ptrs, final_max)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"C_TILE_NUM": num}, num_warps=1, num_stages=s)
        for num in [4, 8, 16, 48]
        for s in [0, 3]
    ],
    key=["C"],
    # final_sum is accumulated with atomics, so autotune zeroes it before
    # each timed config run.
    reset_to_zero=["final_sum_ptr"],
)
@triton.jit
def softmax_forward_with_max_kernel(
    inp_ptr,
    final_max_ptr,
    final_sum_ptr,
    N,
    C: tl.constexpr,
    D: tl.constexpr,
    C_TILE_NUM: tl.constexpr,
):
    """Second pass of the split softmax for very large C: per-(n, d) exp-sum.

    Uses the per-(n, d) max already produced by `max_kernel`. Grid axis 0
    splits batches, grid axis 1 splits the class axis into C_TILE_NUM tiles;
    each program accumulates its partial sum of exp(x - max) into final_sum
    via tl.atomic_add, so the caller must zero-initialize final_sum.
    """
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)
    # Evenly split the N batches over the programs on axis 0; the first
    # `job_remain_batch` programs take one extra batch each.
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job
    core_id = tl.program_id(1)
    offset_d = tl.arange(0, D)
    # Size of the class tile owned by each program on axis 1 (ceil division).
    BLOCK_C: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM
    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx
        offset_c = core_id * BLOCK_C + tl.arange(0, BLOCK_C)
        final_max_ptrs = final_max_ptr + pid_n * D + offset_d
        final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
        final_max = tl.load(final_max_ptrs)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C
        # Out-of-range classes read -inf and contribute exp(-inf) = 0.
        inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)
        final_sum = tl.sum(tl.exp(inp - final_max[None, :]), axis=0)
        tl.atomic_add(final_sum_ptrs, final_sum)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 2**n}, num_warps=4, num_stages=0)
        for n in range(4, 11, 2)
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["ignore_index"])
def nllloss_without_weight_kernel(
    inp_ptr,
    tgt_ptr,
    final_max_ptr,
    final_sum_ptr,
    out_ptr,
    ignore_index,
    N,
    C,
    D: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Unweighted NLL loss from precomputed softmax statistics.

    Computes out = logsumexp - inp[target] = -log_softmax(inp)[target]
    per element, and only stores where target != ignore_index (ignored
    positions keep whatever the output buffer already holds).

    NOTE(review): `offset_n * D + offset_d` adds a [BLOCK_N] vector to a
    [D] vector, which only broadcasts when D == 1 — confirm this kernel is
    only launched for inputs with a trailing spatial size of 1.
    """
    core_id = tl.program_id(0)
    offset_n = core_id * BLOCK_N + tl.arange(0, BLOCK_N)
    offset_d = tl.arange(0, D)
    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt_mask = offset_n < N
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    # Elementwise "not ignored" mask (Triton elementwise `not`).
    ignore_mask = not (tgt == ignore_index)
    final_max_ptrs = final_max_ptr + offset_n * D + offset_d
    final_sum_ptrs = final_sum_ptr + offset_n * D + offset_d
    # other=0 / other=1 keep log2(final_sum) finite for masked-off lanes.
    final_max = tl.load(final_max_ptrs, mask=tgt_mask, other=0)
    final_sum = tl.load(final_sum_ptrs, mask=tgt_mask, other=1)
    # Gather the logit of the target class for each sample.
    inp_tgt_ptrs = inp_ptr + offset_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=tgt_mask, other=-float("inf")).to(tl.float32)
    # log(x) = log2(x) * ln(2); loge2 approximates ln(2).
    loge2 = 0.693147
    out = tl.log2(final_sum) * loge2 + final_max - inp_tgt
    out_ptrs = out_ptr + offset_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask and ignore_mask)
@libentry()
@triton.heuristics(
    values={
        "num_warps": lambda args: 1,
        "num_stages": lambda args: 0,
    },
)
@triton.jit(do_not_specialize=["ignore_index"])
def nllloss_with_weight_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    w_tgt_ptr,
    final_max_ptr,
    final_sum_ptr,
    out_ptr,
    ignore_index,
    N,
    C,
    D: tl.constexpr,
):
    """Weighted NLL loss from precomputed softmax statistics, one batch per program.

    Computes out = (logsumexp - inp[target]) * weight[target] and records the
    per-element weight into w_tgt (used later by the caller to normalize the
    "mean" reduction). Stores are masked so ignored targets are skipped.
    """
    pid_n = tl.program_id(0)
    offset_d = tl.arange(0, D)
    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt = tl.load(tgt_ptrs)
    # Elementwise "not ignored" mask (Triton elementwise `not`).
    ignore_mask = not (tgt == ignore_index)
    if w_ptr is None:
        # No class weights: use the 0/1 ignore mask as the effective weight.
        w_tgt = ignore_mask
    else:
        w_ptrs = w_ptr + tgt
        w_tgt = tl.load(w_ptrs).to(tl.float32)
    # Persist the per-element target weight for mean normalization.
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=ignore_mask)
    final_max_ptrs = final_max_ptr + pid_n * D + offset_d
    final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
    final_max = tl.load(final_max_ptrs)
    final_sum = tl.load(final_sum_ptrs)
    # Gather the logit of the target class.
    inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs).to(tl.float32)
    # log(x) = log2(x) * ln(2); loge2 approximates ln(2).
    loge2 = 0.693147
    out = (tl.log2(final_sum) * loge2 + final_max - inp_tgt) * w_tgt
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=ignore_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing"])
def celoss_probability_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Cross-entropy forward for probability (soft) targets.

    For each (n, d): out = sum_c (-log_softmax(inp))[c] * tgt'[c] * w[c],
    where tgt' = tgt * (1 - label_smoothing) + label_smoothing / C.
    Pass 1 computes online-softmax statistics over C; pass 2 accumulates the
    weighted sum. Grid: (ceil(D / BLOCK_D), N).
    """
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    # Pass 1: online softmax over the class axis in BLOCK_C tiles.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    # Note: final_sum here already holds log(sum exp(x - max)).
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    # Pass 2: accumulate -log_softmax * smoothed target * class weight.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Label smoothing mixes the soft target with a uniform distribution.
        tgt = tgt * (1.0 - label_smoothing) + label_smoothing / C
        # log = -log_softmax(inp) = log(sum) + max - inp.
        log = final_sum + final_max - inp
        w_mask = offset_c < C
        if w_ptr is None:
            # No class weights: the 0/1 validity mask acts as the weight.
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, mask=w_mask, other=0).to(tl.float32)
        _sum += log * tgt * w[:, None]
    out = tl.sum(_sum, axis=0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=offset_d < D)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
def celoss_indices_smooth_kernel(
    inp_ptr,
    tgt_ptr,
    w_ptr,
    out_ptr,
    w_tgt_ptr,
    ignore_index,
    label_smoothing,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Cross-entropy forward for index targets with label smoothing.

    For each (n, d): out = sum_c (-log_softmax(inp))[c] * smooth[c] * w[c],
    where smooth is (1 - ls + ls/C) at the target class and ls/C elsewhere.
    Also writes the per-element target weight into w_tgt (used by the caller
    to normalize the "mean" reduction). Positions whose target equals
    ignore_index produce 0. Grid: (ceil(D / BLOCK_D), N).
    """
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    # Valid and not ignored (Triton elementwise `not`/`and`).
    ignore_mask = not (tgt == ignore_index) and tgt_mask
    if w_ptr is None:
        # No class weights: the 0/1 ignore mask acts as the weight.
        w_tgt = ignore_mask
    else:
        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0)
    # Persist the per-element target weight for mean normalization.
    w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
    tl.store(w_tgt_ptrs, w_tgt, mask=tgt_mask)
    # Pass 1: online softmax statistics over the class axis.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    # final_sum holds log(sum exp(x - max)); final_sum_max is the logsumexp.
    final_sum = tl.log(tl.sum(tmp_sum, axis=0))[None, :]
    final_sum_max = final_sum + final_max
    # Pass 2: accumulate -log_softmax * smoothed one-hot * class weight.
    _sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w = tl.load(w_ptr + offset_c, w_mask, other=0).to(tl.float32)
        # Smoothed one-hot: peak at the target class, uniform ls/C elsewhere.
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        ).to(tl.float32)
        log = final_sum_max - inp
        _sum += log * smooth * w[:, None]
    out = tl.sum(_sum, axis=0)
    # Zero the loss for ignored targets before storing.
    out = tl.where(ignore_mask, out, 0)
    out_ptrs = out_ptr + pid_n * D + offset_d
    tl.store(out_ptrs, out, mask=tgt_mask)
@triton.jit
def single_celoss_indice_bwd(
    pid_n,
    offset_c,
    offset_d,
    final_max,
    final_sum,
    tgt,
    w_tgt,
    out_grad,
    mean_num,
    inp_ptr,
    inp_grad_ptr,
    ignore_mask,
    C,
    D,
):
    """Backward tile for index targets (no smoothing), using saved max/sum.

    Writes inp_grad = (softmax(inp) - onehot(tgt)) * w_tgt * out_grad * mean_num
    for the classes in `offset_c`; stores are masked by both the class-range
    mask and `ignore_mask` so ignored targets keep a zero gradient.
    """
    inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
    inp_mask = offset_c[:, None] < C
    inp = tl.load(inp_ptrs, mask=inp_mask, other=-float("inf")).to(tl.float32)
    # One-hot indicator of the target class (subtracted below).
    minus_one = offset_c[:, None] == tgt[None, :]
    inp_grad = (
        (tl.exp(inp - final_max[None, :]) / final_sum[None, :] - minus_one)
        * w_tgt
        * out_grad
        * mean_num
    )
    inp_grad_ptrs = (
        inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
    )
    tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)
def config_prune(configs, named_args, **kwargs):
    """Autotune pruning hook: drop meaningless tiling combinations.

    TILE_MODE 0 (one program covers all of C) only makes sense with
    C_TILE_NUM == 1; TILE_MODE 1 (C split across programs) is kept only
    when C_TILE_NUM >= 4 and BLOCK_C <= 1024.
    """

    def _keep(kw):
        # One-line rule per tiling mode; anything else is rejected.
        mode = kw["TILE_MODE"]
        tiles = kw["C_TILE_NUM"]
        block_c = kw["BLOCK_C"]
        if mode == 0:
            return tiles == 1
        return mode == 1 and tiles >= 4 and block_c <= 1024

    return [cfg for cfg in configs if _keep(cfg.kwargs)]
@libentry()
@triton.autotune(
    configs=[
        triton.Config(
            {
                "TILE_MODE": mode,
                "C_TILE_NUM": num,
                "BLOCK_C": 2**n,
            },
            num_warps=1,
            num_stages=s,
        )
        for mode in [0, 1]
        for num in [1, 4, 8, 16, 48]
        for n in range(10, 17, 2)
        for s in [0, 3]
    ],
    key=["C"],
    prune_configs_by={
        # config_prune removes (TILE_MODE, C_TILE_NUM, BLOCK_C) combos that
        # make no sense together; see its docstring.
        "early_config_prune": config_prune,
    },
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
def celoss_indice_bwd_with_saved_sum_kernel(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    final_max_ptr,
    final_sum_ptr,
    ignore_index,
    mean_num,
    N,
    C: tl.constexpr,
    D: tl.constexpr,
    is_has_weight: tl.constexpr,
    is_has_ignore_index: tl.constexpr,
    is_tgt_in_i32: tl.constexpr,
    TILE_MODE: tl.constexpr,
    C_TILE_NUM: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    """Cross-entropy backward for index targets, reusing saved softmax stats.

    Delegates the per-tile math to `single_celoss_indice_bwd`.
    TILE_MODE 0: each program covers the entire class axis (in one shot if
    C <= BLOCK_C, otherwise looping in BLOCK_C chunks). TILE_MODE 1: grid
    axis 1 splits the class axis into C_TILE_NUM tiles, one per program.
    """
    job_id = tl.program_id(0)
    job_num = tl.num_programs(0)
    # Evenly split the N batches over the programs on axis 0; the first
    # `job_remain_batch` programs take one extra batch each.
    batch_per_job = N // job_num
    job_remain_batch = N - batch_per_job * job_num
    batch_per_job += 1
    batch_begin = job_id * batch_per_job
    if job_id >= job_remain_batch:
        batch_per_job -= 1
        batch_begin = job_id * batch_per_job + job_remain_batch
    batch_end = batch_begin + batch_per_job
    for batch_idx in range(batch_begin, batch_end):
        pid_n = batch_idx
        offset_d = tl.arange(0, D)
        tgt_ptrs = tgt_ptr + pid_n * D + offset_d
        # Narrow targets to int32 when C fits, avoiding 64-bit index math.
        if is_tgt_in_i32:
            tgt = tl.load(tgt_ptrs).to(tl.int32)
        else:
            tgt = tl.load(tgt_ptrs)
        out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
        out_grad = tl.load(out_grad_ptrs).to(tl.float32)[None, :]
        if is_has_weight:
            w_ptrs = w_ptr + tgt
            w_tgt = tl.load(w_ptrs).to(tl.float32)[None, :]
        else:
            w_tgt = 1
        if is_has_ignore_index:
            ignore_mask = (tgt != ignore_index)[None, :]
        else:
            ignore_mask = True
        final_max_ptrs = final_max_ptr + pid_n * D + offset_d
        final_sum_ptrs = final_sum_ptr + pid_n * D + offset_d
        final_max = tl.load(final_max_ptrs)
        final_sum = tl.load(final_sum_ptrs)
        if TILE_MODE == 0:
            if C <= BLOCK_C:
                # Whole class axis in one tile.
                offset_c = tl.arange(0, C)
                single_celoss_indice_bwd(
                    pid_n,
                    offset_c,
                    offset_d,
                    final_max,
                    final_sum,
                    tgt,
                    w_tgt,
                    out_grad,
                    mean_num,
                    inp_ptr,
                    inp_grad_ptr,
                    ignore_mask,
                    C,
                    D,
                )
            else:
                # Sweep the class axis in BLOCK_C-sized chunks.
                for off in range(0, C, BLOCK_C):
                    offset_c = off + tl.arange(0, BLOCK_C)
                    single_celoss_indice_bwd(
                        pid_n,
                        offset_c,
                        offset_d,
                        final_max,
                        final_sum,
                        tgt,
                        w_tgt,
                        out_grad,
                        mean_num,
                        inp_ptr,
                        inp_grad_ptr,
                        ignore_mask,
                        C,
                        D,
                    )
        else:
            # TILE_MODE 1: this program owns one of C_TILE_NUM class tiles.
            core_id = tl.program_id(1)
            C_TILE_SIZE: tl.constexpr = (C + C_TILE_NUM - 1) // C_TILE_NUM
            offset_c = core_id * C_TILE_SIZE + tl.arange(0, C_TILE_SIZE)
            single_celoss_indice_bwd(
                pid_n,
                offset_c,
                offset_d,
                final_max,
                final_sum,
                tgt,
                w_tgt,
                out_grad,
                mean_num,
                inp_ptr,
                inp_grad_ptr,
                ignore_mask,
                C,
                D,
            )
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
def celoss_probability_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    label_smoothing,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Cross-entropy backward for probability (soft) targets.

    inp_grad = ((sum_c w*tgt') * softmax(inp) - tgt' * w) * out_grad * mean_num
    with tgt' = tgt * (1 - label_smoothing) + label_smoothing / C.
    Pass 1 recomputes the online-softmax statistics and the weighted target
    sum; pass 2 writes the gradient. Grid: (ceil(D / BLOCK_D), N).
    """
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=offset_d < D, other=0).to(tl.float32)[
        None, :
    ]
    # Pass 1: online softmax plus the running sum of weight * smoothed target.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_tgt_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp = tl.load(inp_ptrs, mask, other=-float("inf")).to(tl.float32)
        tgt_ptrs = tgt_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        # Label smoothing mixes the soft target with a uniform distribution.
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C
        w_mask = offset_c < C
        if w_ptr is None:
            # No class weights: the 0/1 validity mask acts as the weight.
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
        w_tgt_sum += tgt * w[:, None]
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_tgt_sum = tl.sum(w_tgt_sum, axis=0)[None, :]
    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        offset = pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_ptrs = inp_ptr + offset
        mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, mask, other=0).to(tl.float32)
        tgt_ptrs = tgt_ptr + offset
        tgt = tl.load(tgt_ptrs, mask, other=0).to(tl.float32)
        tgt = tgt * (1 - label_smoothing) + label_smoothing / C
        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
        grad = w_tgt_sum / final_sum * tl.exp(inp - final_max) - tgt * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad_ptrs = inp_grad_ptr + offset
        tl.store(inp_grad_ptrs, inp_grad, mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss"),
    key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
def celoss_indices_smooth_bwd(
    out_grad_ptr,
    inp_ptr,
    tgt_ptr,
    w_ptr,
    inp_grad_ptr,
    ignore_index,
    label_smoothing,
    mean_num,
    C,
    D,
    BLOCK_C: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Cross-entropy backward for index targets with label smoothing.

    inp_grad = ((sum_c w*smooth) * softmax(inp) - smooth * w) * out_grad * mean_num
    where smooth is (1 - ls + ls/C) at the target class and ls/C elsewhere.
    Stores are masked by the ignore mask so ignored targets keep a zero
    gradient (the caller zero-initializes inp_grad). Grid: (ceil(D / BLOCK_D), N).
    """
    pid_d = tl.program_id(0)
    pid_n = tl.program_id(1)
    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    tgt_ptrs = tgt_ptr + pid_n * D + offset_d
    tgt_mask = offset_d < D
    tgt = tl.load(tgt_ptrs, mask=tgt_mask, other=0)
    out_grad_ptrs = out_grad_ptr + pid_n * D + offset_d
    out_grad = tl.load(out_grad_ptrs, mask=tgt_mask, other=0).to(tl.float32)[None, :]
    ignore_mask = (tgt != ignore_index)[None, :]
    # Pass 1: online softmax plus the running sum of weight * smoothed one-hot.
    tmp_max = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    tmp_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    w_sum = tl.zeros([BLOCK_C, BLOCK_D], dtype=tl.float32)
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        w_mask = offset_c < C
        if w_ptr is None:
            # No class weights: the 0/1 validity mask acts as the weight.
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
        # Smoothed one-hot: peak at the target class, uniform ls/C elsewhere.
        smooth = tl.full([BLOCK_C, BLOCK_D], label_smoothing / C, dtype=tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            smooth,
        )
        w_sum += smooth * w[:, None]
        cur_max = tl.maximum(tmp_max, inp)
        cur_exp = tl.exp(inp - cur_max)
        tmp_sum = tmp_sum * tl.exp(tmp_max - cur_max) + cur_exp
        tmp_max = cur_max
    final_max = tl.max(tmp_max, axis=0)[None, :]
    tmp_sum = tmp_sum * tl.exp(tmp_max - final_max)
    final_sum = tl.sum(tmp_sum, axis=0)[None, :]
    w_sum = tl.sum(w_sum, axis=0)[None, :]
    # Pass 2: emit the gradient tile by tile.
    for off in range(0, C, BLOCK_C):
        offset_c = off + tl.arange(0, BLOCK_C)
        inp_ptrs = inp_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        inp_mask = offset_c[:, None] < C and offset_d[None, :] < D
        inp = tl.load(inp_ptrs, inp_mask, other=-float("inf")).to(tl.float32)
        w_mask = offset_c < C
        if w_ptr is None:
            w = w_mask
        else:
            w_ptrs = w_ptr + offset_c
            w = tl.load(w_ptrs, w_mask, other=0).to(tl.float32)
        smooth = tl.where(
            offset_c[:, None] == tgt[None, :],
            1 - label_smoothing + label_smoothing / C,
            label_smoothing / C,
        )
        grad = w_sum / final_sum * tl.exp(inp - final_max) - smooth * w[:, None]
        inp_grad = grad * out_grad * mean_num
        inp_grad_ptrs = (
            inp_grad_ptr + pid_n * C * D + offset_c[:, None] * D + offset_d[None, :]
        )
        tl.store(inp_grad_ptrs, inp_grad, mask=inp_mask and ignore_mask)
class CrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper dispatching to the fused Triton cross-entropy kernels.

    Input is viewed as (N, C, D): N batches, C classes, D trailing elements.
    Forward saves the per-(n, d) softmax statistics (max and exp-sum) where
    available so the backward pass can reuse them instead of recomputing.
    """

    @staticmethod
    def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
        """Compute the loss.

        reduction is the torch enum: 0 = none, 1 = mean, 2 = sum. target may
        be class indices (ndim < inp.ndim) or class probabilities (same ndim).
        """
        logger.debug("GEMS_TSINGMICRO CrossEntropyLoss")
        # Factor the input shape into (N, C, D); `shape` with the class axis
        # removed is the shape of the per-element loss.
        shape = list(inp.shape)
        dim = inp.ndim
        N = 1 if dim == 1 else shape[0]
        C = shape[0] if dim == 1 else shape[1]
        D = inp.numel() // N // C
        axis = 0 if dim == 1 else 1
        del shape[axis]
        inp = inp.contiguous()
        tgt = target.contiguous()
        # Stash the geometry and hyperparameters for backward.
        ctx.N = N
        ctx.C = C
        ctx.D = D
        ctx.ignore_index = ignore_index
        ctx.label_smoothing = label_smoothing
        ctx.shape = shape
        final_max = None
        final_sum = None
        mean_num = 1
        # Probability targets: "mean" averages over all N*D elements.
        if reduction == 1 and tgt.ndim == dim:
            mean_num = 1 / (N * D)
        out = torch.empty(shape, dtype=torch.float32, device=inp.device)

        # Apply the requested reduction to the per-element loss buffer.
        def get_result(inp, tgt, out, reduction, mean_num):
            if reduction == 0:  # NONE
                return out.to(inp.dtype)
            elif reduction == 1:  # MEAN
                return (torch.sum(out) * mean_num).to(inp.dtype)
            else:  # SUM
                return torch.sum(out).to(inp.dtype)

        # Fast path: index targets, no class weights, no label smoothing.
        if weight is None and tgt.ndim != dim and label_smoothing == 0:
            # final_max starts at float32 min because max_kernel accumulates
            # with tl.atomic_max; final_sum starts at zero for tl.atomic_add.
            final_max = torch.full(
                shape,
                torch.finfo(torch.float32).min,
                dtype=torch.float32,
                device=inp.device,
            )
            final_sum = torch.zeros(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                # For mid-sized C, split the class axis across cores with a
                # two-kernel (max then sum) pipeline; otherwise one kernel
                # computes both statistics.
                if C <= (32 * 1000) or C > (2048 * 1000):
                    softmax_forward_kernel[(TOTAL_CORE_NUM,)](
                        inp, final_max, final_sum, N, C, D
                    )
                else:
                    grid = lambda meta: (
                        triton.cdiv(TOTAL_CORE_NUM, meta["C_TILE_NUM"]),
                        meta["C_TILE_NUM"],
                    )
                    max_kernel[grid](inp, final_max, N, C, D)
                    softmax_forward_with_max_kernel[grid](
                        inp, final_max, final_sum, N, C, D
                    )
                grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
                nllloss_without_weight_kernel[grid](
                    inp, tgt, final_max, final_sum, out, ignore_index, N, C, D
                )
            # NOTE(review): "mean" here divides by C (or C-1 when
            # ignore_index is a valid class) rather than by the number of
            # non-ignored elements — confirm this matches the reference
            # semantics for this backend.
            if reduction == 1:
                if ignore_index < 0 or ignore_index >= C:
                    mean_num = 1 / C
                else:
                    mean_num = 1 / (C - 1)
            ctx.mean_num = mean_num
            ctx.save_for_backward(inp, tgt, weight, final_max, final_sum)
            return get_result(inp, tgt, out, reduction, mean_num)
        weight = weight.contiguous() if weight is not None else None
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
        if tgt.ndim == dim:
            # target probabilities
            with torch_device_fn.device(inp.device):
                celoss_probability_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    label_smoothing,
                    C,
                    D,
                )
        elif label_smoothing == 0:
            # target indices
            # w_tgt collects per-element target weights for mean
            # normalization; max/sum are saved so backward can reuse them.
            w_tgt = torch.zeros(shape, dtype=torch.float32, device=inp.device)
            final_max = torch.empty(shape, dtype=torch.float32, device=inp.device)
            final_sum = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                softmax_forward_kernel[(TOTAL_CORE_NUM,)](
                    inp, final_max, final_sum, N, C, D
                )
                nllloss_with_weight_kernel[(N,)](
                    inp,
                    tgt,
                    weight,
                    w_tgt,
                    final_max,
                    final_sum,
                    out,
                    ignore_index,
                    N,
                    C,
                    D,
                )
        else:
            # Index targets with label smoothing.
            w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device)
            with torch_device_fn.device(inp.device):
                celoss_indices_smooth_kernel[grid](
                    inp,
                    tgt,
                    weight,
                    out,
                    w_tgt,
                    ignore_index,
                    label_smoothing,
                    C,
                    D,
                )
        ctx.save_for_backward(inp, tgt, weight, final_max, final_sum)
        ctx.mean_num = 1
        # Index targets with "mean": normalize by the total target weight.
        if reduction == 1 and tgt.ndim != dim:
            mean_num = 1 / torch.sum(w_tgt).item()
            ctx.mean_num = mean_num
        return get_result(inp, tgt, out, reduction, mean_num)

    @staticmethod
    def backward(ctx, out_grad):
        """Compute d(loss)/d(inp); all other forward arguments get None."""
        logger.debug("GEMS_TSINGMICRO CrossEntropyLoss VJP")
        inp, tgt, weight, final_max, final_sum = ctx.saved_tensors
        N = ctx.N
        C = ctx.C
        D = ctx.D
        ignore_index = ctx.ignore_index
        label_smoothing = ctx.label_smoothing
        mean_num = ctx.mean_num
        shape = ctx.shape
        # Reduced losses arrive as scalars; expand to per-element shape.
        out_grad = out_grad.broadcast_to(shape).contiguous()
        # Zero-filled so kernels that skip ignored targets leave zeros behind.
        inp_grad = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
        grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
        if tgt.ndim == inp.ndim:
            celoss_probability_bwd[grid](
                out_grad, inp, tgt, weight, inp_grad, label_smoothing, mean_num, C, D
            )
        elif label_smoothing == 0:
            # Index targets without smoothing: reuse the saved softmax stats.
            if final_sum is not None:
                is_has_weight = weight is not None
                is_has_ignore_index = ignore_index >= 0 and ignore_index < C
                # Targets can index with int32 when C fits.
                is_tgt_in_i32 = C < (1 << 31)
                grid = lambda meta: (
                    triton.cdiv(TOTAL_CORE_NUM, meta["C_TILE_NUM"]),
                    meta["C_TILE_NUM"],
                )
                celoss_indice_bwd_with_saved_sum_kernel[grid](
                    out_grad,
                    inp,
                    tgt,
                    weight,
                    inp_grad,
                    final_max,
                    final_sum,
                    ignore_index,
                    mean_num,
                    N,
                    C,
                    D,
                    is_has_weight,
                    is_has_ignore_index,
                    is_tgt_in_i32,
                )
        else:
            celoss_indices_smooth_bwd[grid](
                out_grad,
                inp,
                tgt,
                weight,
                inp_grad,
                ignore_index,
                label_smoothing,
                mean_num,
                C,
                D,
            )
        return inp_grad, None, None, None, None, None
def cross_entropy_loss(
    inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
):
    """Functional entry point: map the string reduction to the torch enum and
    dispatch to the fused autograd implementation."""
    reduction_enum = _Reduction.get_enum(reduction)
    return CrossEntropyLoss.apply(
        inp, target, weight, reduction_enum, ignore_index, label_smoothing
    )
1011 )