Coverage for src/flag_gems/ops/softmax.py: 31%
222 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax over the N axis of a logically (M, N, K) contiguous tensor,
    # where the reduction axis N is NOT the innermost one (K > 1).
    # Launch grid: (M, cdiv(K, TILE_K)) — one program per (m, k-tile) pair.
    pid_k = tle.program_id(1)
    pid_m = tle.program_id(0)
    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)
    if ONE_TILE_PER_CTA:
        # The whole reduction axis fits in one tile: single-pass softmax.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & (k_offsets < K)
        input_ptrs = input_ptr + offset
        # Out-of-bounds lanes read -inf so they contribute exp(-inf) = 0.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        m = tl.max(inp, 0)  # per-k-column max over N, for numerical stability
        e = tl.exp(inp - m[None, :])
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online (streaming) softmax: sweep the N axis tile by tile while
        # maintaining a running max `m` and running normalizer `z`.
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)
        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            # When every element seen so far is -inf, keep z unchanged:
            # exp(m - m_new) would be exp(-inf - -inf) = NaN otherwise.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # Collapse the per-lane running stats to one value per k column.
        m_reduced = tl.max(m, 0)  # (TILE_K,)
        z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
        m = m_reduced
        # specialization does not improve performance in this example, as tested
        previous_multiple = prev_multiple_of(N, TILE_N)
        # Second sweep writes normalized outputs; tiles are visited from the
        # last tile (start offset previous_multiple) downward to 0.
        for start_n in range(0, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            o = tl.exp(inp - m[None, :]) / z[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)
@triton.jit
def next_multiple_of(a, b):
    # the smallest x >= a such that x % b == 0
    # Fix: the original called tl.cidv, which does not exist in
    # triton.language; the ceiling-division primitive is tl.cdiv.
    return tl.cdiv(a, b) * b
@triton.jit
def prev_multiple_of(a, b):
    # the largest x < a such that x % b == 0
    # (cdiv(a, b) - 1) * b  ==  cdiv(a, b) * b - b
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax over the innermost axis of a logically (M, N) contiguous
    # tensor. Launch grid: (M,) — one program per row.
    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # The whole row fits in one tile: single-pass softmax.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # Cast to the output dtype up front; the reduction then runs in
        # that precision rather than the input's.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)  # row max, for numerical stability
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online (streaming) softmax over the row: running max m, running
        # normalizer z, accumulated in float32.
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N
        output_ptr += pid_m * N
        previous_multiple = prev_multiple_of(N, TILE_N)
        # Full tiles: no mask needed, slightly cheaper loads.
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # specialize the last iteration
        # (the possibly-partial tail tile needs a bounds mask)
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # Collapse the per-lane running stats to row scalars.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced
        previous_multiple = prev_multiple_of(N, TILE_N)
        # Output pass walks the row backwards, starting from the tile that
        # was loaded last above (presumably for cache reuse — see history).
        # specialize the first iteration
        # (runs exactly once: handles the possibly-partial tile at
        # previous_multiple, the only one that needs a mask)
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining full tiles, still descending, no mask required.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)
162# ------------------------ backward -------------------------------
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax VJP for a logically (M, N, K) layout with reduction axis N not
    # innermost: in_grad = out * (out_grad - sum_n(out * out_grad)).
    # Launch grid: (M, cdiv(K, TILE_K)).
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)
    if ONE_TILE_PER_CTA:
        # Whole N axis fits in one tile: load, reduce, store once.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        # Masked lanes load 0 (tl.load default), so they do not perturb the sum.
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Two-pass variant. Pass 1: accumulate scale = sum_n(out * out_grad)
        # tile by tile along N.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_tile * out_grad_tile
            # Advance both the bounds check and the pointers by one N tile.
            offsets_n += TILE_N
            offsets += TILE_N * K
        scale = tl.sum(scale, axis=0)  # (TILE_K)
        # Pass 2: re-read the same tiles and write the gradient.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            offsets_n += TILE_N
            offsets += TILE_N * K
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax VJP over the innermost axis of a logically (M, N) layout:
    # in_grad = out * (out_grad - sum_n(out * out_grad)).
    # Launch grid: (cdiv(M, TILE_M),) — each program handles TILE_M rows.
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # Whole rows fit in one tile: load, reduce along N, store once.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        # Masked lanes load 0 (tl.load default), so they do not perturb the sum.
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Two-pass variant. Pass 1: accumulate scale = sum_n(out * out_grad).
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float32)
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_last: out is re-read in pass 2, keep it resident.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_last"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)
        # Pass 2: re-read the same tiles and write the gradient.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_first: this pass is the last use of these tiles.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N
def softmax(self, dim, half_to_float=False):
    """Compute softmax of ``self`` along ``dim`` via the Triton kernels above.

    Args:
        self: input tensor; made contiguous internally before launch.
        dim: axis to normalize over; negative values are wrapped.
        half_to_float: when True, the output is produced in float32.

    Returns:
        A new tensor of the same shape holding the softmax probabilities.
    """
    logger.debug("GEMS SOFTMAX")
    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    N = self.shape[dim]
    # Collapse leading dims into M and trailing dims into K so the kernels
    # see a logical (M, N, K) contiguous layout.
    M = 1
    for size in self.shape[:dim]:
        M *= size  # pre_dim
    self = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # post_dim
    with torch_device_fn.device(self.device):
        if K > 1:
            # Reduction axis is not innermost: tile the trailing K axis too.
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_kernel_non_inner[grid](out, self, M, N, K)
        else:
            # Reduction axis is innermost: one program per row.
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](out, self, M, N)
    return out
def softmax_backward(grad_output, output, dim, input_dtype):
    """Softmax VJP: in_grad = out * (grad_out - sum(out * grad_out, dim)).

    Args:
        grad_output: upstream gradient; made contiguous internally.
        output: forward softmax result (assumed contiguous — it is produced
            by ``softmax`` above; verify if called from elsewhere).
        dim: axis the forward softmax normalized over; negatives are wrapped.
        input_dtype: dtype for the returned gradient tensor.

    Returns:
        The gradient with respect to the softmax input.
    """
    logger.debug("GEMS SOFTMAX VJP")
    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    N = output.shape[dim]
    # Collapse leading dims into M and trailing dims into K, matching the
    # (M, N, K) view used by the forward pass.
    M = 1
    for size in output.shape[:dim]:
        M *= size
    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N
    with torch_device_fn.device(in_grad.device):
        if K > 1:
            # Non-inner reduction: tile the trailing K axis as well.
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](output, grad_output, in_grad, M, N, K)
        else:
            # Inner reduction: tile rows across programs.
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](output, grad_output, in_grad, M, N)
    return in_grad