Coverage for src/flag_gems/experimental_ops/_log_softmax_backward_data.py: 0%
68 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import math
3import torch
4import triton
5import triton.language as tl
@triton.jit
def _log_softmax_bwd_kernel(
    grad_ptr,  # pointer to grad_output (dL/dy), row-major [M, K]
    y_logsm_ptr,  # pointer to output of log_softmax (y = log_softmax(x)), same layout
    grad_in_ptr,  # pointer to grad_input (dL/dx), same layout
    M,  # number of rows (product of all dims except reduction dim)
    K,  # length of the reduction dimension
    BLOCK_SIZE: tl.constexpr,  # chunk width used to tile the K dimension
):
    """Row-wise backward of log_softmax: dL/dx = dL/dy - exp(y) * sum_j(dL/dy_j).

    One program instance processes one row of the [M, K] view; the row is
    walked in BLOCK_SIZE-wide chunks, so K may exceed BLOCK_SIZE.
    """
    pid = tl.program_id(axis=0)
    row_start = pid * K  # flat offset of this program's row
    offs = tl.arange(0, BLOCK_SIZE)

    # First pass: compute s = sum_j grad_output[j] along dim.
    # Accumulate in float32 regardless of storage dtype for accuracy.
    s = tl.zeros((), dtype=tl.float32)
    num_chunks = tl.cdiv(K, BLOCK_SIZE)
    for chunk in range(0, num_chunks):
        cols = chunk * BLOCK_SIZE + offs
        mask = cols < K  # guard the ragged final chunk
        go_chunk = tl.load(grad_ptr + row_start + cols, mask=mask, other=0.0)
        go32 = go_chunk.to(tl.float32)
        s += tl.sum(go32, axis=0)

    # Second pass: grad_input = grad_output - exp(output) * s.
    # exp(log_softmax(x)) == softmax(x), so no renormalization is needed.
    for chunk in range(0, num_chunks):
        cols = chunk * BLOCK_SIZE + offs
        mask = cols < K
        go_chunk = tl.load(grad_ptr + row_start + cols, mask=mask, other=0.0)
        go32 = go_chunk.to(tl.float32)
        y_chunk = tl.load(y_logsm_ptr + row_start + cols, mask=mask, other=0.0)
        y32 = y_chunk.to(tl.float32)
        sm = tl.exp(y32)
        gi32 = go32 - sm * s
        gi = gi32.to(go_chunk.dtype)  # cast back to storage dtype before storing
        tl.store(grad_in_ptr + row_start + cols, gi, mask=mask)
45def _normalize_dim(dim: int, ndim: int) -> int:
46 if dim < 0:
47 dim += ndim
48 return dim
51def _choose_block_size(K: int) -> int:
52 if K <= 1:
53 return 1
54 bs = 1 << (int(math.ceil(math.log2(K))))
55 return min(1024, max(1, bs))
def _log_softmax_backward_data_impl(
    grad_output: torch.Tensor, output: torch.Tensor, dim: int
):
    """Compute grad_input for log_softmax via the Triton kernel.

    Implements grad_input = grad_output - exp(output) * grad_output.sum(dim).

    Args:
        grad_output: gradient w.r.t. the log_softmax output (dL/dy), CUDA tensor.
        output: forward log_softmax result (y = log_softmax(x)), same
            shape/dtype/device as grad_output.
        dim: reduction dimension; may be negative.

    Returns:
        A new tensor with the same shape and dtype as grad_output.

    Raises:
        AssertionError: on shape/device/dtype mismatch or non-CUDA inputs.
    """
    assert (
        grad_output.shape == output.shape
    ), "grad_output and output must have the same shape"
    assert (
        grad_output.device.type == "cuda" and output.device.type == "cuda"
    ), "Inputs must be CUDA tensors"
    assert grad_output.device == output.device, "Inputs must be on the same device"
    assert grad_output.dtype == output.dtype, "Inputs must have the same dtype"
    # NOTE(review): the original also asserted
    #   is_contiguous(memory_format=torch.contiguous_format) == is_contiguous()
    # but contiguous_format is the default memory format for is_contiguous(),
    # so both sides were the same call and the check was always True (a no-op).
    # It has been removed; the movedim().contiguous() below normalizes any
    # input layout regardless.

    if grad_output.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        # Fallback for unsupported dtype (e.g., float64), compute via PyTorch:
        # grad_input = grad_output - exp(output) * sum(grad_output, dim, keepdim)
        s = grad_output.sum(dim=dim, keepdim=True)
        return grad_output - output.exp() * s

    dim = _normalize_dim(dim, grad_output.ndim)

    # Move reduction dim to the last position for a contiguous 2D [M, K] layout.
    go_last = torch.movedim(grad_output, dim, -1).contiguous()
    y_last = torch.movedim(output, dim, -1).contiguous()
    K = go_last.shape[-1]
    if go_last.numel() == 0 or K == 0:
        # Empty input: nothing to reduce; preserve shape/dtype/device.
        return grad_output.clone()

    M = go_last.numel() // K

    go_2d = go_last.view(M, K)
    y_2d = y_last.view(M, K)
    gi_2d = torch.empty_like(go_2d)

    BLOCK_SIZE = _choose_block_size(K)
    grid = (M,)  # one kernel program per row

    _log_softmax_bwd_kernel[grid](
        go_2d,
        y_2d,
        gi_2d,
        M,
        K,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Undo the dimension permutation to restore the caller's layout.
    gi_last = gi_2d.view_as(go_last)
    grad_input = torch.movedim(gi_last, -1, dim)
    return grad_input
def _log_softmax_backward_data(
    grad_output: torch.Tensor, output: torch.Tensor, dim: int, input_dtype: torch.dtype
):
    """Functional entry point for log_softmax backward.

    `input_dtype` is accepted for signature compatibility but not consulted;
    the computation is delegated to `_log_softmax_backward_data_impl`.
    """
    return _log_softmax_backward_data_impl(
        grad_output=grad_output, output=output, dim=dim
    )
def _log_softmax_backward_data_out(
    grad_output: torch.Tensor,
    output: torch.Tensor,
    dim: int,
    input_dtype: torch.dtype,
    out: torch.Tensor,
):
    """Out-variant of log_softmax backward: writes the result into `out`.

    Computes the gradient with `_log_softmax_backward_data_impl`, copies it
    into the caller-provided `out` tensor, and returns `out`.
    """
    grad_input = _log_softmax_backward_data_impl(grad_output, output, dim)
    out.copy_(grad_input)
    return out