Coverage for src/flag_gems/ops/softmax.py: 33%
228 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.zeros import zero_
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger(__name__)
# Softmax over a non-innermost axis. The input is viewed as (M, N, K) with the
# reduction over N; K is the contiguous post-dim, so element (m, n, k) lives at
# offset m*N*K + n*K + k. Launch grid: (M, cdiv(K, TILE_K)).
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,  # *out: same shape/layout as input
    input_ptr,
    M,  # product of dims before the softmax dim
    N,  # size of the softmax (reduction) dim
    K,  # product of dims after the softmax dim
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,  # True when one (TILE_N, TILE_K) tile covers all of N
):
    pid_k = tle.program_id(1)
    pid_m = tle.program_id(0)
    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)
    if ONE_TILE_PER_CTA:
        # Single-pass path: the whole reduction axis fits in one tile, so do a
        # plain max-subtracted softmax over axis 0 (the N axis of the tile).
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & (k_offsets < K)
        input_ptrs = input_ptr + offset
        # Masked lanes load -inf so they contribute exp(-inf) = 0 to the sum.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        m = tl.max(inp, 0)
        e = tl.exp(inp - m[None, :])
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Two-pass online softmax: first pass maintains, per (n, k) lane, a
        # running max m and a running sum z of exp(x - m), rescaling z whenever
        # the max increases; the lanes are then reduced along N.
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            # Guard all -inf lanes: rescaling would compute exp(-inf - -inf) = nan.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # Fold the per-lane partials into one max / one sum per k column.
        m_reduced = tl.max(m, 0)  # (TILE_K,)
        z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
        m = m_reduced

        # specialization does not improve performance in this example, as tested
        # Second pass writes the normalized outputs, walking the tiles from the
        # last tile start (previous_multiple) down to 0 so the (possibly partial)
        # tail tile — likely still cached — is touched first.
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            o = tl.exp(inp - m[None, :]) / z[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)
@triton.jit
def next_multiple_of(a, b):
    # The smallest x >= a with x % b == 0, i.e. ceil(a / b) * b.
    # Fix: the original called tl.cidv, a typo — the Triton API is tl.cdiv.
    return tl.cdiv(a, b) * b
@triton.jit
def prev_multiple_of(a, b):
    # The largest x < a with x % b == 0 (for a > 0): round a up to a
    # multiple of b, then step back one full b.
    rounded_up = tl.cdiv(a, b) * b
    return rounded_up - b
# Softmax over the innermost (contiguous) axis. Input viewed as (M, N); each
# program handles one row. Launch grid: (M,).
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,  # number of rows (product of dims before the softmax dim)
    N,  # row length = size of the softmax dim
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,  # True when one TILE_N tile covers the whole row
):
    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # Single-pass path: whole row in one tile, plain max-subtracted softmax.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # NOTE: computed in the OUTPUT dtype (cast on load), unlike the
        # multi-tile path below which accumulates in float32.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Two-pass online softmax in float32: per-lane running max m and
        # rescaled running sum z, folded into scalars before the second pass.
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        # Full tiles need no mask; the tail (if any) is handled separately below.
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # specialize the last iteration: the (possibly partial) tail tile, masked
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # Reduce the per-lane partials to a scalar max and scalar sum.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced

        previous_multiple = prev_multiple_of(N, TILE_N)
        # Second pass walks tiles from the end of the row backwards (tile start
        # previous_multiple first), presumably to reuse the still-cached tail.
        # specialize the first iteration: the tail tile, masked
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining tiles are full, so no mask is needed.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)
# ------------------------ backward -------------------------------
# Softmax backward over a non-innermost axis, input viewed as (M, N, K):
#   in_grad = out * (out_grad - sum(out * out_grad, dim=N))
# Launch grid: (M, cdiv(K, TILE_K)).
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    out_ptr,  # forward softmax output
    out_grad_ptr,  # gradient w.r.t. the softmax output
    in_grad_ptr,  # result: gradient w.r.t. the softmax input
    M,
    N,  # softmax (reduction) dim
    K,  # contiguous post-dim
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,  # True when one tile covers all of N
):
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # Single tile: compute the per-k scale in one shot, then the gradient.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate out * out_grad over N (masked loads contribute 0),
        # advancing offsets in place by one tile per iteration.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_tile * out_grad_tile
            offsets_n += TILE_N
            offsets += TILE_N * K
        scale = tl.sum(scale, axis=0)  # (TILE_K)

        # Pass 2: re-read both operands and emit the gradient per tile.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            offsets_n += TILE_N
            offsets += TILE_N * K
# Softmax backward over the innermost axis, input viewed as (M, N):
#   in_grad = out * (out_grad - sum(out * out_grad, dim=N))
# Each program handles TILE_M rows. Launch grid: (cdiv(M, TILE_M),).
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,  # forward softmax output
    out_grad_ptr,  # gradient w.r.t. the softmax output
    in_grad_ptr,  # result: gradient w.r.t. the softmax input
    M,
    N,  # softmax (reduction) dim, innermost/contiguous
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,  # True when one tile covers all of N
):
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # Single tile: per-row scale in one shot, then the gradient.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate out * out_grad over N in float32. `out` is marked
        # evict_last since pass 2 reloads it.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float32)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_last"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Pass 2: re-read and emit the gradient; `out` won't be needed again,
        # so it is marked evict_first.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N
def softmax(self, dim, half_to_float=False):
    """Compute softmax of ``self`` along ``dim``.

    Args:
        self: input tensor.
        dim: axis to normalize over; negative values are supported.
        half_to_float: when True, the result is returned in float32
            regardless of the input dtype.

    Returns:
        A tensor of the same shape as ``self``.
    """
    logger.debug("GEMS SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"

    # Resolve the output dtype first so the empty-tensor path honors
    # half_to_float too (the original returned self.dtype there).
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype

    # Special handling for empty tensors: nothing to normalize, return an
    # all-zero tensor of the same shape.
    if self.numel() == 0:
        out_shape = list(self.shape)
        out = torch.empty(out_shape, dtype=dtype, device=self.device)
        zero_(out)
        return out

    dim = dim % self.ndim
    # Collapse the layout to (M, N, K): M = product of dims before `dim`,
    # N = the reduced dim, K = product of dims after `dim`.
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]  # pre_dim
    self = self.contiguous()  # kernels index with raw (M, N, K) offsets
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # post_dim

    with torch_device_fn.device(self.device):
        if K > 1:
            # Non-inner softmax dim: tile over K as well.
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_kernel_non_inner[grid](
                out,
                self,
                M,
                N,
                K,
            )
        else:
            # Innermost (contiguous) softmax dim: one program per row.
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](
                out,
                self,
                M,
                N,
            )
    return out
def softmax_backward(grad_output, output, dim, input_dtype):
    """Backward of softmax: ``dx = y * (dy - sum(y * dy, dim))``.

    Args:
        grad_output: gradient w.r.t. the softmax output.
        output: the forward softmax result.
        dim: axis the forward softmax was taken over; may be negative.
        input_dtype: dtype of the forward input; the gradient is returned in it.

    Returns:
        Gradient w.r.t. the softmax input, same shape as ``output``.
    """
    logger.debug("GEMS SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    # Collapse the layout to (M, N, K) around the reduced dim, as in forward.
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    # The kernels index with raw contiguous (M, N, K) offsets, so BOTH operands
    # must be contiguous; the original only made grad_output contiguous, which
    # silently miscomputes for a non-contiguous `output`.
    grad_output = grad_output.contiguous()
    output = output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    return in_grad