Coverage for src/flag_gems/runtime/backend/_sunrise/ops/ctc_loss.py: 0%
446 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
10logger = logging.getLogger(__name__)
# Integer codes for the supported reduction modes.  The Triton kernels in this
# module receive the code as their REDUCTION constexpr argument and compare it
# against the literals 0 and 1.
_REDUCTION_NONE = 0
_REDUCTION_MEAN = 1
_REDUCTION_SUM = 2
# Memo used by _length_stats: maps a (device type, device index, data pointer,
# element count, version counter) key to a (tensor, (min, max, sum)) pair.
_LENGTH_STATS_CACHE = {}
# When the memo already holds this many entries it is emptied before the next
# entry is inserted.
_LENGTH_STATS_CACHE_LIMIT = 256
# Use Triton's own barrier when this Triton build exposes one; otherwise fall
# back to a jitted function that does nothing so call sites stay unchanged.
try:
    _debug_barrier = tl.debug_barrier
except AttributeError:

    @triton.jit
    def _debug_barrier():
        return
@triton.jit
def _logaddexp(a, b):
    """Return log(exp(a) + exp(b)) elementwise in a numerically stable way.

    The larger operand is factored out before exponentiating.  When both
    operands are -inf the result is forced to -inf, because the generic
    formula would otherwise compute (-inf) - (-inf).
    """
    hi = tl.maximum(a, b)
    lo = tl.minimum(a, b)
    stable_sum = hi + tl.log(1.0 + tl.exp(lo - hi))
    return tl.where(hi == -float("inf"), -float("inf"), stable_sum)
@triton.jit
def _logaddexp3(a, b, c, use_c):
    """Return log(exp(a) + exp(b) + [exp(c) where use_c]) elementwise.

    Where ``use_c`` is false the third operand is replaced by -inf and so
    contributes nothing.  The maximum is subtracted before exponentiating;
    when all operands are -inf the shift is 0 and the result is -inf.
    """
    third = tl.where(use_c, c, -float("inf"))
    peak = tl.maximum(tl.maximum(a, b), third)
    all_neg_inf = peak == -float("inf")
    shift = tl.where(all_neg_inf, 0.0, peak)
    total = tl.exp(a - shift) + tl.exp(b - shift) + tl.exp(third - shift)
    return tl.where(all_neg_inf, -float("inf"), peak + tl.log(total))
@libentry()
@triton.jit
def _ctc_loss_forward_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    log_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """CTC forward (alpha) recursion that keeps every time step for backward.

    One program handles one batch element (``program_id(0)``).  Its lanes are
    the states of the blank-interleaved target sequence: even states are
    blanks, odd state ``s`` carries target ``(s - 1) // 2``.

    Args:
        log_probs: addressed as ``t * N * C + batch * C + label``.
        targets: labels; read at ``target_offsets[batch] + i`` when
            ``TARGET_1D`` is set, else at ``batch * MAX_TARGET + i``.
        input_lengths, target_lengths: one entry per batch element.
        target_offsets: start of each batch element's labels inside a 1D
            ``targets``; only read when ``TARGET_1D`` is set.
        neg_log_likelihood: output, one loss value per batch element.
        log_alpha: output, addressed as
            ``batch * T * STATE_COUNT_MAX + t * STATE_COUNT_MAX + state``.
        T, N, C: time steps, batch size and class count used for addressing.
        STATE_COUNT_MAX: row width of ``log_alpha``.
        BLANK: label used for the even (blank) states.
        BLOCK_S: number of lanes; lanes at or beyond ``STATE_COUNT_MAX`` never
            touch ``log_alpha``.
    """
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    input_len = tl.load(input_lengths + batch)
    target_len = tl.load(target_lengths + batch)
    state_count = target_len * 2 + 1
    # valid_state: states that exist for this sample.
    # stored_state: lanes that have a slot in a log_alpha row.
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        # Use the offsets supplied by the caller, exactly like the other
        # kernels in this module, instead of re-summing target_lengths in
        # every program.
        target_origin = tl.load(target_offsets + batch)
        target_ptrs = targets + target_origin + target_safe_index
    else:
        target_origin = batch * MAX_TARGET
        target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # The s-2 -> s transition is only allowed for non-blank states whose label
    # differs from the previous target label.  None of this depends on the
    # time step, so it is computed once before the time loop.
    prev_target_index = tl.where(target_index > 0, target_index - 1, 0)
    prev_target_value = tl.load(
        targets + target_origin + prev_target_index,
        mask=target_mask & (target_index > 0),
        other=BLANK,
    )
    skip_allowed = (
        (~is_blank_state) & (target_index > 0) & (target_value != prev_target_value)
    )

    # Time step 0: only state 0 and, for a non-empty target, state 1 can be
    # occupied.
    t0_active = input_len > 0
    init_state = (states == 0) | ((states == 1) & (target_len > 0))
    init_logp = tl.load(
        log_probs + batch * C + labels,
        mask=init_state & stored_state & t0_active,
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(init_state & valid_state & t0_active, init_logp, -float("inf"))
    if T > 0:
        # With T == 0 there is no row 0 in log_alpha, so this store would be
        # out of bounds.
        tl.store(
            log_alpha + batch * T * STATE_COUNT_MAX + states,
            alpha,
            mask=stored_state,
        )
    _debug_barrier()

    for t in tl.range(1, T):
        prev_base = log_alpha + batch * T * STATE_COUNT_MAX + (t - 1) * STATE_COUNT_MAX
        prev0 = tl.load(prev_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )
        prev1 = tl.load(
            prev_base + tl.where(states > 0, states - 1, 0),
            mask=(states > 0) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        prev2 = tl.load(
            prev_base + tl.where(states > 1, states - 2, 0),
            mask=(states > 1) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)

        acc = _logaddexp3(prev0, prev1, prev2, skip_allowed)

        logp = tl.load(
            log_probs + t * N * C + batch * C + labels,
            mask=valid_state & (t < input_len),
            other=0.0,
        ).to(tl.float32)
        # Rows at or beyond input_len are filled with -inf.
        alpha = tl.where(valid_state & (t < input_len), acc + logp, -float("inf"))
        tl.store(
            log_alpha + batch * T * STATE_COUNT_MAX + t * STATE_COUNT_MAX + states,
            alpha,
            mask=stored_state,
        )
        # The next iteration reads neighbouring lanes of the row just written.
        _debug_barrier()

    if input_len <= 0:
        # No input frames: only an empty target has finite (zero) loss.
        loss = tl.where(target_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_base = (
            log_alpha + batch * T * STATE_COUNT_MAX + (input_len - 1) * STATE_COUNT_MAX
        )
        # A path may end in the last state or, for a non-empty target, in the
        # state before it.
        last = tl.load(final_base + state_count - 1).to(tl.float32)
        prev_last = tl.load(
            final_base + tl.where(target_len > 0, state_count - 2, 0),
            mask=target_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        log_likelihood = _logaddexp(last, prev_last)
        loss = -log_likelihood

    tl.store(neg_log_likelihood + batch, loss)
@libentry()
@triton.jit
def _ctc_loss_forward_no_grad_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    scratch_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    # CTC forward (alpha) recursion for the case where no gradient is needed.
    # One program handles one batch element.  Instead of a full per-time-step
    # history it keeps two rows of STATE_COUNT_MAX values per batch element in
    # scratch_alpha and alternates between them (row index = t % 2).
    # Output: neg_log_likelihood[batch].
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    target_len = tl.load(target_lengths + batch)
    state_count = target_len * 2 + 1
    # valid_state: states that exist for this sample.
    # stored_state: lanes that have a slot in a scratch row.
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    # Even states are blanks; odd state s carries target (s - 1) // 2.
    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        # 1D targets: per-sample start offset supplied by the caller.
        target_origin = tl.load(target_offsets + batch)
        target_ptrs = targets + target_origin + target_safe_index
    else:
        # 2D targets: rows of MAX_TARGET labels.
        target_origin = batch * MAX_TARGET
        target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    input_len = tl.load(input_lengths + batch)
    # Time step 0: only state 0 and, for a non-empty target, state 1 can be
    # occupied.
    init_state = (states == 0) | ((states == 1) & (target_len > 0))
    init_logp = tl.load(
        log_probs + batch * C + labels,
        mask=init_state & stored_state & (input_len > 0),
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(
        init_state & valid_state & (input_len > 0), init_logp, -float("inf")
    )
    scratch_batch = scratch_alpha + batch * 2 * STATE_COUNT_MAX
    tl.store(scratch_batch + states, alpha, mask=stored_state)
    # NOTE(review): presumably makes the row visible to neighbouring lanes
    # before they read it below; confirm for backends where the fallback
    # barrier is a no-op.
    _debug_barrier()

    for t in tl.range(1, T):
        prev_base = scratch_batch + ((t - 1) % 2) * STATE_COUNT_MAX
        cur_base = scratch_batch + (t % 2) * STATE_COUNT_MAX
        # Previous-step alpha at states s, s - 1 and s - 2.
        prev0 = tl.load(prev_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )
        prev1 = tl.load(
            prev_base + tl.where(states > 0, states - 1, 0),
            mask=(states > 0) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        prev2 = tl.load(
            prev_base + tl.where(states > 1, states - 2, 0),
            mask=(states > 1) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)

        # The s-2 -> s transition is only allowed for non-blank states whose
        # label differs from the previous target label.
        prev_target_index = tl.where(target_index > 0, target_index - 1, 0)
        prev_target_value = tl.load(
            targets + target_origin + prev_target_index,
            mask=target_mask & (target_index > 0),
            other=BLANK,
        )
        skip_allowed = (
            (~is_blank_state) & (target_index > 0) & (target_value != prev_target_value)
        )

        acc = _logaddexp3(prev0, prev1, prev2, skip_allowed)
        logp = tl.load(
            log_probs + t * N * C + batch * C + labels,
            mask=valid_state & (t < input_len),
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(valid_state & (t < input_len), acc + logp, -float("inf"))
        # Nothing is written once t reaches input_len, so the row for time
        # step input_len - 1 is still intact when it is read after the loop.
        tl.store(cur_base + states, alpha, mask=stored_state & (t < input_len))
        _debug_barrier()

    if input_len <= 0:
        # No input frames: only an empty target has finite (zero) loss.
        loss = tl.where(target_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_base = scratch_batch + ((input_len - 1) % 2) * STATE_COUNT_MAX
        # A path may end in the last state or, for a non-empty target, in the
        # state before it.
        last = tl.load(final_base + state_count - 1).to(tl.float32)
        prev_last = tl.load(
            final_base + tl.where(target_len > 0, state_count - 2, 0),
            mask=target_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        loss = -_logaddexp(last, prev_last)

    tl.store(neg_log_likelihood + batch, loss)
@libentry()
@triton.jit
def _ctc_loss_forward_full_length_reduce_kernel(
    log_probs,
    targets,
    target_lengths,
    target_offsets,
    contrib,
    scratch_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    REDUCTION: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    # CTC forward recursion for the case where every sample uses all T time
    # steps: there is no input_lengths argument and no per-step length
    # masking.  One program handles one batch element and writes its share of
    # the reduced loss to contrib[batch]; for REDUCTION == 1 that share is
    # already divided by max(target_len, 1) and by N.
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    target_len = tl.load(target_lengths + batch)
    state_count = target_len * 2 + 1
    # valid_state: states that exist for this sample.
    # stored_state: lanes that have a slot in a scratch row.
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    # Even states are blanks; odd state s carries target (s - 1) // 2.
    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        # 1D targets: per-sample start offset supplied by the caller.
        target_origin = tl.load(target_offsets + batch)
        target_ptrs = targets + target_origin + target_safe_index
    else:
        # 2D targets: rows of MAX_TARGET labels.
        target_origin = batch * MAX_TARGET
        target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # Time step 0: only state 0 and, for a non-empty target, state 1 can be
    # occupied.
    init_state = (states == 0) | ((states == 1) & (target_len > 0))
    init_logp = tl.load(
        log_probs + batch * C + labels,
        mask=init_state & stored_state,
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(init_state & valid_state, init_logp, -float("inf"))
    # Two scratch rows per batch element, used alternately (row = t % 2).
    scratch_batch = scratch_alpha + batch * 2 * STATE_COUNT_MAX
    tl.store(scratch_batch + states, alpha, mask=stored_state)
    # NOTE(review): presumably makes the row visible to neighbouring lanes
    # before they read it below; confirm for backends where the fallback
    # barrier is a no-op.
    _debug_barrier()

    for t in tl.range(1, T):
        prev_base = scratch_batch + ((t - 1) % 2) * STATE_COUNT_MAX
        cur_base = scratch_batch + (t % 2) * STATE_COUNT_MAX
        # Previous-step alpha at states s, s - 1 and s - 2.
        prev0 = tl.load(prev_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )
        prev1 = tl.load(
            prev_base + tl.where(states > 0, states - 1, 0),
            mask=(states > 0) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        prev2 = tl.load(
            prev_base + tl.where(states > 1, states - 2, 0),
            mask=(states > 1) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)

        # The s-2 -> s transition is only allowed for non-blank states whose
        # label differs from the previous target label.
        prev_target_index = tl.where(target_index > 0, target_index - 1, 0)
        prev_target_value = tl.load(
            targets + target_origin + prev_target_index,
            mask=target_mask & (target_index > 0),
            other=BLANK,
        )
        skip_allowed = (
            (~is_blank_state) & (target_index > 0) & (target_value != prev_target_value)
        )

        acc = _logaddexp3(prev0, prev1, prev2, skip_allowed)
        logp = tl.load(
            log_probs + t * N * C + batch * C + labels,
            mask=valid_state,
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(valid_state, acc + logp, -float("inf"))
        tl.store(cur_base + states, alpha, mask=stored_state)
        _debug_barrier()

    # T is a compile-time constant, so only one of these branches is taken
    # for a given specialization.
    if T <= 0:
        loss = tl.where(target_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_base = scratch_batch + ((T - 1) % 2) * STATE_COUNT_MAX
        # A path may end in the last state or, for a non-empty target, in the
        # state before it.
        last = tl.load(final_base + state_count - 1).to(tl.float32)
        prev_last = tl.load(
            final_base + tl.where(target_len > 0, state_count - 2, 0),
            mask=target_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        loss = -_logaddexp(last, prev_last)

    # REDUCTION == 1 is the "mean" code: pre-scale so that summing contrib
    # over the batch gives the mean of loss / max(target_len, 1).
    if REDUCTION == 1:
        loss = loss / tl.maximum(target_len, 1).to(tl.float32) / N
    tl.store(contrib + batch, loss)
@libentry()
@triton.jit
def _ctc_loss_init_grad_kernel(
    log_probs,
    input_lengths,
    target_lengths,
    neg_log_likelihood,
    grad_output,
    grad_input,
    total: tl.constexpr,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    REDUCTION: tl.constexpr,
    ZERO_INFINITY: tl.constexpr,
    BLOCK: tl.constexpr,
):
    # Elementwise initialisation of grad_input: every one of the `total`
    # elements of log_probs gets scale * exp(log_prob) when its time step lies
    # inside the sample's input length and 0 otherwise.
    # _ctc_loss_backward_kernel later adds its contribution on top with
    # atomic adds.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < total
    # Decompose the flat offset using the t * N * C + batch * C + class
    # addressing that the other kernels use.
    batch = (offsets // C) % N
    t = offsets // (N * C)

    input_len = tl.load(input_lengths + batch, mask=mask, other=0)
    target_len = tl.load(target_lengths + batch, mask=mask, other=1)
    nll = tl.load(neg_log_likelihood + batch, mask=mask, other=0.0).to(tl.float32)

    # REDUCTION == 0 ("none"): one upstream gradient per batch element.
    # Otherwise a single upstream gradient, additionally divided by
    # max(target_len, 1) * N for REDUCTION == 1 ("mean").
    if REDUCTION == 0:
        scale = tl.load(grad_output + batch, mask=mask, other=0.0).to(tl.float32)
    else:
        scale = tl.load(grad_output).to(tl.float32)
        if REDUCTION == 1:
            denom = tl.maximum(target_len, 1).to(tl.float32) * N
            scale = scale / denom

    # With ZERO_INFINITY, samples whose loss is +inf get a zero gradient.
    if ZERO_INFINITY:
        scale = tl.where(nll == float("inf"), 0.0, scale)

    logp = tl.load(log_probs + offsets, mask=mask, other=-float("inf")).to(tl.float32)
    grad = tl.where((t < input_len) & mask, tl.exp(logp) * scale, 0.0)
    nan_grad = float("nan")
    # In-range elements with log_prob == -inf and a non-zero scale are set to
    # NaN.
    grad = tl.where(
        (t < input_len) & mask & (scale != 0.0) & (logp == -float("inf")),
        nan_grad,
        grad,
    )
    # Without ZERO_INFINITY, in-range elements of a sample with +inf loss are
    # set to NaN as well.
    if not ZERO_INFINITY:
        grad = tl.where((t < input_len) & mask & (nll == float("inf")), nan_grad, grad)
    tl.store(grad_input + offsets, grad, mask=mask)
@libentry()
@triton.jit
def _ctc_loss_backward_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    log_alpha,
    grad_output,
    grad_input,
    scratch_beta,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    REDUCTION: tl.constexpr,
    ZERO_INFINITY: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    # CTC backward (beta) recursion.  One program handles one batch element
    # and walks time backwards from input_len - 1.  At each step it combines
    # the stored alpha row with the current beta row into a per-state
    # posterior and atomically adds -scale * posterior to grad_input at the
    # state's label.  Beta is kept in two alternating scratch rows per batch
    # element.
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    input_len = tl.load(input_lengths + batch)
    target_len = tl.load(target_lengths + batch)
    nll = tl.load(neg_log_likelihood + batch).to(tl.float32)
    state_count = target_len * 2 + 1
    # valid_state: states that exist for this sample.
    # stored_state: lanes that have a slot in a scratch / log_alpha row.
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    # Label of state s: BLANK for even states, target (s - 1) // 2 otherwise.
    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        # 1D targets: per-sample start offset supplied by the caller.
        target_origin = tl.load(target_offsets + batch)
        target_ptrs = targets + target_origin + target_safe_index
    else:
        # 2D targets: rows of MAX_TARGET labels.
        target_origin = batch * MAX_TARGET
        target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # Label of state s + 1, the target of the "advance by one" transition.
    state1 = states + 1
    is_blank_state1 = (state1 % 2) == 0
    target_index1 = (state1 - 1) // 2
    target_mask1 = (target_index1 >= 0) & (target_index1 < target_len)
    target_safe_index1 = tl.where(target_mask1, target_index1, 0)
    target_ptrs1 = targets + target_origin + target_safe_index1
    target_value1 = tl.load(target_ptrs1, mask=target_mask1, other=BLANK)
    labels1 = tl.where(is_blank_state1, BLANK, target_value1)

    # Label of state s + 2, the target of the "skip" transition.  It is only
    # used under skip_allowed, which requires s to be a non-blank state.
    state2 = states + 2
    target_index2 = (state2 - 1) // 2
    target_mask2 = (target_index2 >= 0) & (target_index2 < target_len)
    target_safe_index2 = tl.where(target_mask2, target_index2, 0)
    target_ptrs2 = targets + target_origin + target_safe_index2
    target_value2 = tl.load(target_ptrs2, mask=target_mask2, other=BLANK)
    labels2 = target_value2

    # Upstream gradient: per batch element for REDUCTION == 0 ("none"), a
    # single value otherwise, additionally divided by max(target_len, 1) * N
    # for REDUCTION == 1 ("mean").
    if REDUCTION == 0:
        scale = tl.load(grad_output + batch).to(tl.float32)
    else:
        scale = tl.load(grad_output).to(tl.float32)
        if REDUCTION == 1:
            denom = tl.maximum(target_len, 1).to(tl.float32) * N
            scale = scale / denom

    # With ZERO_INFINITY, samples whose loss is +inf contribute nothing.
    if ZERO_INFINITY:
        scale = tl.where(nll == float("inf"), 0.0, scale)

    # Beta at the last used time step: 0 (log 1) for the final state and, for
    # a non-empty target, the state before it; -inf everywhere else.
    beta_init = tl.where(
        ((states == state_count - 1) | ((states == state_count - 2) & (target_len > 0)))
        & valid_state
        & (input_len > 0),
        0.0,
        -float("inf"),
    )
    scratch_batch = scratch_beta + batch * 2 * STATE_COUNT_MAX
    tl.store(scratch_batch + states, beta_init, mask=stored_state)
    # NOTE(review): presumably makes the row visible to neighbouring lanes
    # before they read it below; confirm for backends where the fallback
    # barrier is a no-op.
    _debug_barrier()
    # When scale is 0 the posterior is forced to 0 below, so the placeholder
    # value 0.0 is never observable.
    log_likelihood = tl.where(scale != 0.0, -nll, 0.0)

    for step in tl.range(0, T):
        # step counts backwards from the sample's last used time step; once t
        # goes negative the iteration is inactive.
        t = input_len - 1 - step
        active = t >= 0
        safe_t = tl.where(active, t, 0)
        beta_base = scratch_batch + (step % 2) * STATE_COUNT_MAX
        next_beta_base = scratch_batch + ((step + 1) % 2) * STATE_COUNT_MAX
        beta = tl.load(beta_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )

        alpha_t = tl.load(
            log_alpha + batch * T * STATE_COUNT_MAX + safe_t * STATE_COUNT_MAX + states,
            mask=active & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        # Per-state posterior: exp(alpha + beta - log-likelihood).
        log_post = alpha_t + beta - log_likelihood
        posterior = tl.where(
            active & valid_state & (scale != 0.0),
            tl.exp(log_post),
            0.0,
        )
        # Several states can map to the same label (every blank state does),
        # hence the atomic accumulation.
        tl.atomic_add(
            grad_input + safe_t * N * C + batch * C + labels,
            -scale * posterior,
            sem="relaxed",
            mask=active & valid_state & stored_state,
        )

        # Candidates for the beta row of time step t - 1: stay in s, advance
        # to s + 1, or skip to s + 2, each weighted by the log-probability of
        # the destination state's label at time t.
        stay = beta + tl.load(
            log_probs + safe_t * N * C + batch * C + labels,
            mask=active & valid_state,
            other=-float("inf"),
        ).to(tl.float32)
        next1 = tl.load(
            beta_base + states + 1,
            mask=(states + 1 < state_count) & stored_state,
            other=-float("inf"),
        ).to(tl.float32) + tl.load(
            log_probs + safe_t * N * C + batch * C + labels1,
            mask=active & (states + 1 < state_count) & stored_state,
            other=-float("inf"),
        ).to(
            tl.float32
        )
        # Skipping is only allowed from a non-blank state to a state with a
        # different label.
        skip_allowed = (
            (~is_blank_state)
            & (states + 2 < state_count)
            & (target_value != target_value2)
        )
        next2 = tl.load(
            beta_base + states + 2,
            mask=(states + 2 < state_count) & stored_state,
            other=-float("inf"),
        ).to(tl.float32) + tl.load(
            log_probs + safe_t * N * C + batch * C + labels2,
            mask=active & skip_allowed & stored_state,
            other=-float("inf"),
        ).to(
            tl.float32
        )

        beta_next = _logaddexp3(stay, next1, next2, skip_allowed)
        tl.store(
            next_beta_base + states,
            tl.where(active, beta_next, -float("inf")),
            mask=stored_state,
        )
        _debug_barrier()
def _reduction_enum(reduction):
    """Translate a reduction specifier into its integer code.

    Strings must be one of ``"none"``, ``"mean"`` or ``"sum"``; anything that
    is not a string is passed through ``int()`` unchanged in meaning.

    Raises:
        ValueError: if ``reduction`` is a string other than the three above.
    """
    if not isinstance(reduction, str):
        return int(reduction)

    codes_by_name = {
        "none": _REDUCTION_NONE,
        "mean": _REDUCTION_MEAN,
        "sum": _REDUCTION_SUM,
    }
    code = codes_by_name.get(reduction)
    if code is None:
        raise ValueError(
            "ctc_loss expected reduction to be one of 'none', 'mean', or 'sum', "
            f"but got {reduction!r}"
        )
    return code
# Torch dtypes accepted for length tensors and integral targets.  torch.bool
# is deliberately not listed here.
_INTEGRAL_DTYPES = {
    torch.int64,
    torch.int32,
    torch.int16,
    torch.int8,
    torch.uint8,
}


def _is_integral_dtype(dtype):
    """Return True when ``dtype`` is one of the dtypes in _INTEGRAL_DTYPES."""
    accepted = _INTEGRAL_DTYPES
    return dtype in accepted
def _lengths_to_tensor(lengths, device, name):
    """Convert ``lengths`` into a flat, contiguous int64 tensor on ``device``.

    Args:
        lengths: a tensor, or anything ``torch.tensor`` accepts.
        device: device the result is placed on.
        name: argument name used in the error message.

    Raises:
        RuntimeError: if the values do not have an integral dtype.
    """
    from_tensor = torch.is_tensor(lengths)
    # A tensor is rejected before it is moved; other inputs can only be
    # checked once torch has inferred a dtype for them.
    if from_tensor and not _is_integral_dtype(lengths.dtype):
        raise RuntimeError(f"{name} must be integral")
    if from_tensor:
        converted = lengths.to(device=device)
    else:
        converted = torch.tensor(lengths, device=device)
    if not _is_integral_dtype(converted.dtype):
        raise RuntimeError(f"{name} must be integral")

    # .to() hands back the same tensor when the dtype is already int64.
    as_long = converted.to(dtype=torch.long)
    # reshape(-1) turns a 0-d tensor into a single-element 1-D tensor and
    # flattens everything else.
    return as_long.reshape(-1).contiguous()
def _length_stats(lengths):
    """Return ``(min, max, sum)`` of ``lengths`` as a tuple of Python ints.

    Results for tensors are memoised in ``_LENGTH_STATS_CACHE`` under a key
    built from the tensor's device, data pointer, element count and version
    counter.  The tensor itself is stored next to the result, which keeps it
    alive for as long as the entry exists.  Non-tensor inputs are converted
    with ``torch.as_tensor`` and are never cached.

    Args:
        lengths: a tensor or a sequence of numbers.

    Returns:
        tuple[int, int, int]: minimum, maximum and sum of the values.
    """
    key = None
    if torch.is_tensor(lengths):
        key = (
            lengths.device.type,
            lengths.device.index,
            lengths.data_ptr(),
            lengths.numel(),
            lengths._version,
        )
        cached = _LENGTH_STATS_CACHE.get(key)
        if cached is not None:
            return cached[1]
        # [sunrise fix] Could not run 'aten::min' with arguments from the
        # 'ptpu' backend, so the reductions run on a host copy.  The copy is
        # made once and shared by all three reductions.
        host_lengths = lengths.cpu()
    else:
        host_lengths = torch.as_tensor(lengths)

    stats = (
        int(host_lengths.min()),
        int(host_lengths.max()),
        int(host_lengths.sum()),
    )
    if key is not None:
        # Crude bound on memory use: drop everything once the limit is hit.
        if len(_LENGTH_STATS_CACHE) >= _LENGTH_STATS_CACHE_LIMIT:
            _LENGTH_STATS_CACHE.clear()
        _LENGTH_STATS_CACHE[key] = (lengths, stats)
    return stats
def _compute_dtype(dtype):
    """Return the dtype to compute in for inputs of ``dtype``.

    float16 and bfloat16 are widened to float32; every other dtype is
    returned unchanged.
    """
    is_half_precision = dtype in (torch.float16, torch.bfloat16)
    return torch.float32 if is_half_precision else dtype
class CtcLossFunction(torch.autograd.Function):
    """Autograd function computing the CTC loss with the Triton kernels above.

    ``forward`` validates and normalises the arguments, then takes one of
    three routes:

    * no gradient needed, batched input, ``zero_infinity`` off, "mean" or
      "sum" reduction, and every input length equal to T (> 0): a single
      kernel that also pre-scales each sample's loss for the reduction;
    * no gradient needed otherwise: the forward recursion with a two-row
      scratch buffer;
    * gradient needed: the forward recursion storing alpha for every time
      step, which ``backward`` reads again.
    """

    @staticmethod
    def _finalize_loss(
        raw_neg_log_likelihood,
        target_lengths,
        reduction,
        zero_infinity,
        unbatched,
        original_dtype,
    ):
        """Apply ``zero_infinity``, the reduction and the final dtype cast.

        Shared by both per-sample-loss routes of ``forward`` so they cannot
        drift apart.

        Args:
            raw_neg_log_likelihood: per-sample losses written by the kernel.
            target_lengths: per-sample target lengths (int64 tensor).
            reduction: one of the ``_REDUCTION_*`` codes.
            zero_infinity: replace infinite per-sample losses by zero.
            unbatched: the input had no batch dimension; only affects the
                shape of the "none" result.
            original_dtype: dtype of the returned tensor.
        """
        neg_log_likelihood = raw_neg_log_likelihood
        if zero_infinity:
            neg_log_likelihood = torch.where(
                torch.isinf(neg_log_likelihood),
                torch.zeros(
                    (),
                    dtype=neg_log_likelihood.dtype,
                    device=neg_log_likelihood.device,
                ),
                neg_log_likelihood,
            )

        if reduction == _REDUCTION_NONE:
            output = neg_log_likelihood
            if unbatched:
                output = output.squeeze(0)
        elif reduction == _REDUCTION_SUM:
            output = neg_log_likelihood.sum()
        else:
            # [sunrise fix] Could not run 'aten::min' & 'aten::mean' with
            # arguments from the 'ptpu' backend, so clamp_min and mean run on
            # the host and the results are moved back.
            denom = target_lengths.cpu().clamp_min(1).to(target_lengths.device)
            output = (
                (neg_log_likelihood / denom).cpu().mean().to(neg_log_likelihood.device)
            )

        if output.dtype != original_dtype:
            output = output.to(original_dtype)
        return output

    @staticmethod
    def forward(
        ctx,
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank=0,
        reduction="mean",
        zero_infinity=False,
    ):
        """Compute the CTC loss.

        Args:
            log_probs: floating-point tensor, 3D with the batch in dimension 1
                or 2D for a single unbatched sample.
            targets: 1D concatenated or 2D padded labels; floating-point
                targets are cast to int64.
            input_lengths, target_lengths: tensors or sequences of integers,
                one entry per batch element.
            blank: index of the blank label, in ``[0, log_probs.shape[-1])``.
            reduction: "none", "mean", "sum" or the matching integer code.
            zero_infinity: replace infinite losses (and their gradients) by
                zero.

        Returns:
            The loss in the dtype of ``log_probs``: per-sample values for
            "none", a scalar otherwise.

        Raises:
            ValueError: for an unknown reduction.
            RuntimeError: for invalid shapes, dtypes, lengths or blank index.
        """
        reduction = _reduction_enum(reduction)
        if reduction not in (_REDUCTION_NONE, _REDUCTION_MEAN, _REDUCTION_SUM):
            raise ValueError(f"ctc_loss got invalid reduction enum {reduction}")

        if log_probs.ndim not in (2, 3):
            raise RuntimeError(
                "ctc_loss expects log_probs to be a 2D or 3D tensor, "
                f"but got {log_probs.ndim}D"
            )
        if not torch.is_floating_point(log_probs):
            raise RuntimeError(f'"ctc_loss" not implemented for {log_probs.dtype}')
        if blank < 0 or blank >= log_probs.shape[-1]:
            raise RuntimeError("blank must be in label range")

        original_dtype = log_probs.dtype
        compute_dtype = _compute_dtype(original_dtype)
        unbatched = log_probs.ndim == 2
        batch_size = 1 if unbatched else log_probs.shape[1]
        device = log_probs.device

        # The kernels address log_probs as a contiguous 3D tensor.
        work_log_probs = log_probs.unsqueeze(1) if unbatched else log_probs
        work_log_probs = work_log_probs.contiguous()
        if work_log_probs.dtype != compute_dtype:
            work_log_probs = work_log_probs.to(compute_dtype)
        time_steps = work_log_probs.shape[0]
        num_classes = work_log_probs.shape[2]

        if torch.is_floating_point(targets):
            work_targets = targets.to(dtype=torch.long).contiguous()
        elif _is_integral_dtype(targets.dtype):
            work_targets = targets.contiguous()
        else:
            raise RuntimeError("ctc_loss targets must be integral or floating point")
        work_input_lengths = _lengths_to_tensor(input_lengths, device, "input_lengths")
        work_target_lengths = _lengths_to_tensor(
            target_lengths, device, "target_lengths"
        )
        if work_input_lengths.numel() != batch_size:
            raise RuntimeError(
                f"ctc_loss expected input_lengths to have size {batch_size}, "
                f"but got {work_input_lengths.numel()}"
            )
        if work_target_lengths.numel() != batch_size:
            raise RuntimeError(
                f"ctc_loss expected target_lengths to have size {batch_size}, "
                f"but got {work_target_lengths.numel()}"
            )
        min_input_length, max_input_length, _ = _length_stats(work_input_lengths)
        min_target_length, max_target, total_target_length = _length_stats(
            work_target_lengths
        )
        if min_input_length < 0 or max_input_length > time_steps:
            raise RuntimeError("ctc_loss input_lengths must be in [0, T]")
        if min_target_length < 0:
            raise RuntimeError("ctc_loss target_lengths must be non-negative")

        # Widest blank-interleaved state sequence in the batch.
        state_count_max = 2 * max_target + 1
        target_stride = max_target
        if work_targets.ndim == 1:
            target_1d = True
            if total_target_length != work_targets.numel():
                raise RuntimeError(
                    "ctc_loss expected concatenated targets length to equal "
                    "sum(target_lengths)"
                )
            # Exclusive prefix sum: where each sample's labels start.
            work_target_offsets = (
                work_target_lengths.cumsum(0) - work_target_lengths
            ).contiguous()
        elif work_targets.ndim == 2:
            target_1d = False
            if max_target > work_targets.shape[1]:
                raise RuntimeError(
                    "ctc_loss target_lengths cannot exceed padded target width"
                )
            target_stride = work_targets.shape[1]
            # Placeholder: the kernels only read offsets for 1D targets.
            work_target_offsets = work_target_lengths
        else:
            raise RuntimeError(
                "ctc_loss expects targets to be a 1D concatenated tensor or a "
                f"2D padded tensor, but got {work_targets.ndim}D"
            )

        needs_log_probs_grad = ctx.needs_input_grad[0]
        block_s = triton.next_power_of_2(state_count_max)

        if not needs_log_probs_grad:
            if (
                not unbatched
                and not zero_infinity
                and reduction in (_REDUCTION_MEAN, _REDUCTION_SUM)
                and min_input_length == time_steps
                and time_steps > 0
            ):
                # Every sample uses all T frames: one kernel computes the
                # per-sample, reduction-scaled contributions.
                contrib = torch.empty(
                    (batch_size,), dtype=torch.float32, device=device
                )
                scratch_alpha = torch.empty(
                    (batch_size, 2, state_count_max),
                    dtype=torch.float32,
                    device=device,
                )
                with torch_device_fn.device(device):
                    _ctc_loss_forward_full_length_reduce_kernel[(batch_size,)](
                        work_log_probs,
                        work_targets,
                        work_target_lengths,
                        work_target_offsets,
                        contrib,
                        scratch_alpha,
                        time_steps,
                        batch_size,
                        num_classes,
                        target_stride,
                        state_count_max,
                        blank,
                        target_1d,
                        reduction,
                        block_s,
                    )
                output = contrib.sum()
                if output.dtype != original_dtype:
                    output = output.to(original_dtype)
                return output

            raw_neg_log_likelihood = torch.empty(
                (batch_size,), dtype=torch.float32, device=device
            )
            scratch_alpha = torch.empty(
                (batch_size, 2, state_count_max),
                dtype=torch.float32,
                device=device,
            )
            with torch_device_fn.device(device):
                _ctc_loss_forward_no_grad_kernel[(batch_size,)](
                    work_log_probs,
                    work_targets,
                    work_input_lengths,
                    work_target_lengths,
                    work_target_offsets,
                    raw_neg_log_likelihood,
                    scratch_alpha,
                    time_steps,
                    batch_size,
                    num_classes,
                    target_stride,
                    state_count_max,
                    blank,
                    target_1d,
                    block_s,
                )
            return CtcLossFunction._finalize_loss(
                raw_neg_log_likelihood,
                work_target_lengths,
                reduction,
                zero_infinity,
                unbatched,
                original_dtype,
            )

        raw_neg_log_likelihood = torch.empty(
            (batch_size,), dtype=torch.float32, device=device
        )

        # Alpha for every (sample, time step, state); backward reads it again.
        log_alpha = torch.empty(
            (batch_size, time_steps, state_count_max),
            dtype=torch.float32,
            device=device,
        )
        with torch_device_fn.device(device):
            _ctc_loss_forward_kernel[(batch_size,)](
                work_log_probs,
                work_targets,
                work_input_lengths,
                work_target_lengths,
                work_target_offsets,
                raw_neg_log_likelihood,
                log_alpha,
                time_steps,
                batch_size,
                num_classes,
                target_stride,
                state_count_max,
                blank,
                target_1d,
                block_s,
            )
        output = CtcLossFunction._finalize_loss(
            raw_neg_log_likelihood,
            work_target_lengths,
            reduction,
            zero_infinity,
            unbatched,
            original_dtype,
        )

        # The raw (not zero_infinity-adjusted) losses are saved: the backward
        # kernels apply ZERO_INFINITY themselves.
        ctx.save_for_backward(
            work_log_probs,
            work_targets,
            work_input_lengths,
            work_target_lengths,
            work_target_offsets,
            raw_neg_log_likelihood,
            log_alpha,
        )
        ctx.blank = blank
        ctx.reduction = reduction
        ctx.zero_infinity = zero_infinity
        ctx.unbatched = unbatched
        ctx.batch_size = batch_size
        ctx.original_dtype = original_dtype
        ctx.max_target = target_stride
        ctx.state_count_max = state_count_max
        ctx.target_1d = target_1d

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to ``log_probs``.

        The gradient is built in two passes: an elementwise kernel seeds it,
        then the beta-recursion kernel accumulates its contribution on top.
        All other inputs of ``forward`` receive ``None``.
        """
        (
            work_log_probs,
            work_targets,
            work_input_lengths,
            work_target_lengths,
            work_target_offsets,
            neg_log_likelihood,
            log_alpha,
        ) = ctx.saved_tensors

        grad_output = grad_output.contiguous()

        grad_log_probs = torch.empty_like(work_log_probs)
        total = work_log_probs.numel()
        block = 256
        # Both launches run under the device guard of the saved inputs.
        with torch_device_fn.device(work_log_probs.device):
            _ctc_loss_init_grad_kernel[(triton.cdiv(total, block),)](
                work_log_probs,
                work_input_lengths,
                work_target_lengths,
                neg_log_likelihood,
                grad_output,
                grad_log_probs,
                total,
                work_log_probs.shape[0],
                ctx.batch_size,
                work_log_probs.shape[2],
                ctx.reduction,
                ctx.zero_infinity,
                block,
            )

            scratch_beta = torch.empty(
                (ctx.batch_size, 2, ctx.state_count_max),
                dtype=torch.float32,
                device=work_log_probs.device,
            )
            block_s = triton.next_power_of_2(ctx.state_count_max)
            _ctc_loss_backward_kernel[(ctx.batch_size,)](
                work_log_probs,
                work_targets,
                work_input_lengths,
                work_target_lengths,
                work_target_offsets,
                neg_log_likelihood,
                log_alpha,
                grad_output,
                grad_log_probs,
                scratch_beta,
                work_log_probs.shape[0],
                ctx.batch_size,
                work_log_probs.shape[2],
                ctx.max_target,
                ctx.state_count_max,
                ctx.blank,
                ctx.target_1d,
                ctx.reduction,
                ctx.zero_infinity,
                block_s,
            )

        # Undo the batch dimension and dtype widening applied in forward.
        if ctx.unbatched:
            grad_log_probs = grad_log_probs.squeeze(1)
        if grad_log_probs.dtype != ctx.original_dtype:
            grad_log_probs = grad_log_probs.to(ctx.original_dtype)

        return grad_log_probs, None, None, None, None, None, None
def ctc_loss(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    blank=0,
    reduction="mean",
    zero_infinity=False,
):
    """Compute the CTC loss by delegating to ``CtcLossFunction``.

    All arguments are forwarded positionally and unchanged; see
    ``CtcLossFunction.forward`` for their meaning.
    """
    logger.debug("GEMS CTC LOSS")
    forwarded = (
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank,
        reduction,
        zero_infinity,
    )
    return CtcLossFunction.apply(*forwarded)