Coverage for src/flag_gems/ops/log_softmax.py: 49%
98 statements
import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr = 8,
    BLOCK_N: tl.constexpr = 256,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # TODO(chenfeiyu): consider float64 and add a utility function to get the accumulator type
    # Online (streaming) reduction over the N axis: m tracks the running maximum,
    # z tracks the sum of exponentials rescaled to the current maximum.
    m = tl.full([BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32)
    z = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m_new = tl.maximum(inp, m)
        all_neg_inf = m_new == float("-inf")
        z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
        m = m_new

    # Collapse the per-lane running statistics into one max and one normalizer per row.
    m_reduced = tl.max(m, 1)
    z = tl.sum(z * tl.exp(m - m_reduced[:, None]), 1)
    m = m_reduced

    # Second pass: log_softmax(x) = x - max(x) - log(sum(exp(x - max(x)))).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        o = inp - m[:, None] - tl.log(z[:, None])
        tl.store(output_ptr + offset, o, mask=mask)
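# A minimal eager-mode sketch (not part of the original module) of the math the kernel
# above implements: a numerically stable log-softmax along one axis. The helper name
# `_log_softmax_reference` is hypothetical and exists only for illustration.
def _log_softmax_reference(x: torch.Tensor, dim: int) -> torch.Tensor:
    # log_softmax(x) = x - max(x) - log(sum(exp(x - max(x)))), reduced along `dim`
    m = x.amax(dim=dim, keepdim=True)
    z = (x - m).exp().sum(dim=dim, keepdim=True)
    return x - m - z.log()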
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # First pass: accumulate the sum of the upstream gradient along the N axis.
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # Second pass: in_grad = out_grad - softmax(x) * sum(out_grad), where
    # softmax(x) = exp(out) because out = log_softmax(x).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)
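# A minimal sketch (not part of the original module) of the vector-Jacobian product that
# the backward kernel materializes, assuming `out = log_softmax(x)` along `dim` and
# `out_grad` is the upstream gradient. The helper name is hypothetical.
def _log_softmax_backward_reference(out_grad: torch.Tensor, out: torch.Tensor, dim: int) -> torch.Tensor:
    # d/dx of log_softmax applied to dy: dy - softmax(x) * sum(dy, dim), with softmax(x) = exp(out)
    return out_grad - out.exp() * out_grad.sum(dim=dim, keepdim=True)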
def log_softmax(self, dim, half_to_float=False):
    logger.debug("GEMS LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    # View the contiguous input as (M, N, K): M collapses the dims before `dim`,
    # N is the reduced dim, K collapses the dims after it.
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    inp = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    K = inp.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        log_softmax_kernel[grid](
            out,
            inp,
            M,
            N,
            K,
            num_warps=8,
        )
    return out
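# Hypothetical usage sketch (not part of the original module): run the wrapper on a
# supported device and compare against torch.nn.functional.log_softmax. The function
# name, shape, device, and tolerances are assumptions made for illustration.
def _check_log_softmax_forward(shape=(4, 1000), dim=1, device="cuda"):
    x = torch.randn(shape, device=device, dtype=torch.float16)
    out = log_softmax(x, dim, half_to_float=True)
    ref = torch.nn.functional.log_softmax(x.float(), dim=dim)
    return torch.allclose(out, ref, atol=1e-4, rtol=1e-4)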
def log_softmax_backward(grad_output, output, dim, input_dtype):
    logger.debug("GEMS LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    # Same (M, N, K) flattening as the forward wrapper, derived from `output`.
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(in_grad.device):
        log_softmax_backward_kernel[grid](
            output,
            grad_output,
            in_grad,
            M,
            N,
            K,
        )
    return in_grad
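# Hypothetical end-to-end sketch (not part of the original module): check the backward
# wrapper against autograd through torch's reference log_softmax. Device, shape, and
# tolerances are assumptions made for illustration.
def _check_log_softmax_backward(shape=(4, 1000), dim=1, device="cuda"):
    x = torch.randn(shape, device=device, requires_grad=True)
    grad_out = torch.randn(shape, device=device)
    ref_out = torch.nn.functional.log_softmax(x, dim=dim)
    (ref_grad,) = torch.autograd.grad(ref_out, x, grad_out)
    out = log_softmax(x.detach(), dim)
    in_grad = log_softmax_backward(grad_out, out, dim, x.dtype)
    return torch.allclose(in_grad, ref_grad, atol=1e-5, rtol=1e-5)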