Coverage for src/flag_gems/runtime/backend/_sunrise/ops/softmax.py: 0%
247 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.zeros import zero_
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as ext
13logger = logging.getLogger(__name__)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax over the N axis of data addressed as m * N * K + n * K + k.
    # Program (pid_m, pid_k) owns slab pid_m and TILE_K consecutive k positions
    # and reduces over n for each of them.
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)
    k_in_bounds = k_offsets < K
    slab_base = pid_m * N * K

    if ONE_TILE_PER_CTA:
        # A single (TILE_N, TILE_K) tile is loaded, reduced along axis 0 and
        # written back. Masked lanes read -inf, so they contribute exp(-inf)
        # to the sum and never win the max.
        n_offsets = tl.arange(0, TILE_N)
        tile_offsets = slab_base + n_offsets[:, None] * K + k_offsets
        tile_mask = (n_offsets[:, None] < N) & k_in_bounds
        x = tl.load(input_ptr + tile_offsets, mask=tile_mask, other=-float("inf"))
        col_max = tl.max(x, 0)
        numer = tl.exp(x - col_max[None, :])
        denom = tl.sum(numer, 0)
        tl.store(output_ptr + tile_offsets, numer / denom, mask=tile_mask)
    else:
        # Pass 1: keep a running max and a running sum per tile lane, rescaling
        # the sum whenever the max grows.
        lane_max = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        lane_sum = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            tile_offsets = slab_base + n_offsets[:, None] * K + k_offsets
            tile_mask = (n_offsets[:, None] < N) & k_in_bounds
            x = tl.load(input_ptr + tile_offsets, mask=tile_mask, other=-float("inf"))
            new_max = tl.maximum(lane_max, x)
            # A lane that has only seen -inf so far keeps its sum unchanged;
            # otherwise exp(-inf - (-inf)) would leak into the accumulator.
            still_neg_inf = new_max == float("-inf")
            rescaled = lane_sum * tl.exp(lane_max - new_max) + tl.exp(x - new_max)
            lane_sum = tl.where(still_neg_inf, lane_sum, rescaled)
            lane_max = new_max

        # Fold the per-lane statistics along axis 0 into one max / sum per k.
        col_max = tl.max(lane_max, 0)  # (TILE_K,)
        col_sum = tl.sum(lane_sum * tl.exp(lane_max - col_max[None, :]), 0)  # (TILE_K,)

        # Pass 2: normalise and store, visiting the tiles from the last one
        # back to the first.
        # specialization does not improve performance in this example, as tested
        last_tile_start = prev_multiple_of(N, TILE_N)
        for start_n in range(0, N, TILE_N):
            n_offsets = (last_tile_start - start_n) + tl.arange(0, TILE_N)
            tile_offsets = slab_base + n_offsets[:, None] * K + k_offsets
            tile_mask = (n_offsets[:, None] < N) & k_in_bounds[None, :]
            x = tl.load(input_ptr + tile_offsets, mask=tile_mask, other=-float("inf"))
            y = tl.exp(x - col_max[None, :]) / col_sum[None, :]
            tl.store(output_ptr + tile_offsets, y, mask=tile_mask)
@triton.jit
def next_multiple_of(a, b):
    # Smallest multiple of b that is >= a: ceil-divide a by b, then scale back.
    # (The previous spelling `tl.cidv` does not exist in triton.language.)
    return tl.cdiv(a, b) * b
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a: round a up to a
    # multiple of b, then step back by one b.
    rounded_up = tl.cdiv(a, b) * b
    return rounded_up - b
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax over rows of length N stored back to back; program pid_m handles
    # the row that starts at element pid_m * N.
    pid_m = ext.program_id(0)
    if ONE_TILE_PER_CTA:
        # The row is loaded as one tile; masked lanes read -inf.
        cols = tl.arange(0, TILE_N)
        in_row = cols < N
        row_offsets = pid_m * N + cols
        x = tl.load(input_ptr + row_offsets, mask=in_row, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        row_max = tl.max(x, 0)
        numer = tl.exp(x - row_max)
        denom = tl.sum(numer, 0)
        tl.store(output_ptr + row_offsets, numer / denom, mask=in_row)
    else:
        lane_max = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        lane_sum = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        row_in_ptr = input_ptr + pid_m * N
        row_out_ptr = output_ptr + pid_m * N

        # Pass 1: running max / rescaled running sum per lane. Every tile that
        # starts before last_tile_start lies fully inside the row, so it is
        # loaded without a mask.
        last_tile_start = prev_multiple_of(N, TILE_N)
        for start_n in range(0, last_tile_start, TILE_N):
            cols = start_n + tl.arange(0, TILE_N)
            x = tl.load(row_in_ptr + cols)
            new_max = tl.maximum(lane_max, x)
            # it is possible that there are -inf's in the input; such a lane
            # keeps its sum unchanged instead of absorbing exp(-inf - (-inf))
            still_neg_inf = new_max == float("-inf")
            rescaled = lane_sum * tl.exp(lane_max - new_max) + tl.exp(x - new_max)
            lane_sum = tl.where(still_neg_inf, lane_sum, rescaled)
            lane_max = new_max
        # specialize the last iteration: only this tile can run past N
        for start_n in range(last_tile_start, N, TILE_N):
            cols = start_n + tl.arange(0, TILE_N)
            in_row = cols < N
            x = tl.load(row_in_ptr + cols, mask=in_row, other=-float("inf"))
            new_max = tl.maximum(lane_max, x)
            still_neg_inf = new_max == float("-inf")
            rescaled = lane_sum * tl.exp(lane_max - new_max) + tl.exp(x - new_max)
            lane_sum = tl.where(still_neg_inf, lane_sum, rescaled)
            lane_max = new_max

        # Fold the lanes into a single max and sum for the row.
        row_max = tl.max(lane_max, 0)
        row_sum = tl.sum(lane_sum * tl.exp(lane_max - row_max), 0)

        # Pass 2: normalise and store, walking the tiles from the last one back
        # to the first.
        # specialize the first iteration: it is the (possibly partial) last tile
        for start_n in range(0, TILE_N, TILE_N):
            cols = (last_tile_start - start_n) + tl.arange(0, TILE_N)
            in_row = cols < N
            x = tl.load(
                row_in_ptr + cols,
                mask=in_row,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            y = tl.exp(x - row_max) / row_sum
            tl.store(row_out_ptr + cols, y, mask=in_row)
        # The remaining tiles are full, so no mask is applied.
        for start_n in range(TILE_N, N, TILE_N):
            cols = (last_tile_start - start_n) + tl.arange(0, TILE_N)
            x = tl.load(row_in_ptr + cols, eviction_policy="evict_first")
            y = tl.exp(x - row_max) / row_sum
            tl.store(row_out_ptr + cols, y)
163# ------------------------ backward -------------------------------
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax backward over the N axis of data addressed as
    # m * N * K + n * K + k:
    #     in_grad = out * (out_grad - sum_n(out * out_grad))
    # Program (pid_m, pid_k) owns slab pid_m and TILE_K consecutive k positions.
    #
    # Every masked load passes other=0.0: masked lanes take part in the
    # reductions below, so they must contribute exactly zero rather than
    # whatever value a masked-off load happens to produce.
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate out * out_grad per lane, then reduce along axis 0.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            scale += out_tile * out_grad_tile
            offsets_n += TILE_N
            offsets += TILE_N * K
        scale = tl.sum(scale, axis=0)  # (TILE_K)

        # Pass 2: restart at the first tile and emit the gradient.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            offsets_n += TILE_N
            offsets += TILE_N * K
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax backward over rows of length N stored back to back:
    #     in_grad = out * (out_grad - sum_n(out * out_grad))
    # Program pid_m handles TILE_M consecutive rows.
    #
    # Every masked load passes other=0.0: masked lanes take part in the row
    # reductions below, so they must contribute exactly zero rather than
    # whatever value a masked-off load happens to produce.
    pid_m = ext.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate out * out_grad per lane, then reduce along axis 1.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float32)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, other=0.0, eviction_policy="evict_last"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Pass 2: restart at the first tile and emit the gradient.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, other=0.0, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N
def softmax_out(self, dim, half_to_float=False, *, out):
    """Compute softmax of ``self`` along ``dim`` into ``out`` and return it.

    Args:
        self: Input tensor; a contiguous copy is taken before the launch.
        dim: Axis to normalise over, in ``[-self.ndim, self.ndim)``.
        half_to_float: When true the expected output dtype is ``torch.float32``,
            otherwise it is ``self.dtype``.
        out: Destination tensor. If its shape differs from ``self.shape`` a
            resized replacement is created (see ``_match_shape``) and that
            replacement is what gets written and returned. It is handed to the
            kernel as-is, and the kernels address it with contiguous offsets.

    Returns:
        The tensor that was written (``out`` or its resized replacement).

    Raises:
        AssertionError: ``dim`` is out of range.
        RuntimeError: ``out.dtype`` does not match the expected dtype
            (only checked for non-empty input).
    """
    logger.debug("GEMS SOFTMAX_OUT")

    assert -self.ndim <= dim < self.ndim, "Invalid dim"

    input_shape = self.shape

    def _match_shape(buf):
        # [sunrise fix][PTPU] out.resize_(shape) is not supported on the
        # device, so the resize goes through a CPU copy that is then moved
        # back to the original device.
        if tuple(buf.shape) == tuple(input_shape):
            return buf
        return buf.cpu().resize_(input_shape).to(buf.device)

    if self.numel() == 0:
        out = _match_shape(out)
        zero_(out)
        return out

    dim %= self.ndim
    # View the input as (M, N, K): M = product of dims before `dim`,
    # N = extent of `dim`, K = product of dims after it.
    N = self.shape[dim]
    M = 1
    for extent in self.shape[:dim]:
        M *= extent

    self = self.contiguous()
    expected_dtype = self.dtype
    if half_to_float:
        expected_dtype = torch.float32

    out = _match_shape(out)
    if out.dtype != expected_dtype:
        raise RuntimeError(
            f"_softmax.out: expected out dtype {expected_dtype}, got {out.dtype}"
        )
    K = self.numel() // M // N

    with torch_device_fn.device(self.device):
        if K > 1:
            # Reduction axis is not innermost: 2-D launch over (M, K tiles).
            def non_inner_grid(meta):
                return (M, triton.cdiv(K, meta["TILE_K"]), 1)

            softmax_kernel_non_inner[non_inner_grid](out, self, M, N, K)
        else:
            # Reduction axis is innermost: one program per row.
            softmax_kernel_inner[(M, 1, 1)](out, self, M, N)
    return out
def softmax(self, dim, half_to_float=False):
    """Return the softmax of ``self`` along ``dim`` in a newly allocated tensor.

    Args:
        self: Input tensor.
        dim: Axis to normalise over, in ``[-self.ndim, self.ndim)``.
        half_to_float: When true the result is ``torch.float32``; otherwise it
            has ``self.dtype``. This now also applies to empty inputs, which
            previously always came back in ``self.dtype``.

    Returns:
        A tensor with the shape of ``self`` holding the softmax values.

    Raises:
        AssertionError: ``dim`` is out of range.
    """
    logger.debug("GEMS SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"

    dtype = torch.float32 if half_to_float else self.dtype

    if self.numel() == 0:
        out = torch.empty(list(self.shape), dtype=dtype, device=self.device)
        zero_(out)
        return out

    # The kernels write the result with contiguous offsets, so the buffer must
    # be contiguous. A plain empty_like() would copy the strides of a dense but
    # non-contiguous input (e.g. a transposed view) and scramble the result.
    out = torch.empty_like(self, dtype=dtype, memory_format=torch.contiguous_format)
    return softmax_out(self, dim, half_to_float, out=out)
def softmax_backward_out(grad_output, output, dim, input_dtype, *, grad_input):
    """Compute the softmax input gradient into ``grad_input`` and return it.

    Args:
        grad_output: Gradient with respect to the softmax output.
        output: The softmax forward result.
        dim: Axis the softmax was taken over, in ``[-output.ndim, output.ndim)``.
        input_dtype: Dtype ``grad_input`` is required to have.
        grad_input: Destination tensor. If its shape differs from
            ``output.shape`` a resized replacement is created and that
            replacement is what gets written and returned.

    Returns:
        The tensor that was written (``grad_input`` or its resized replacement).

    Raises:
        AssertionError: ``dim`` is out of range.
        RuntimeError: ``grad_input.dtype`` differs from ``input_dtype``.
    """
    logger.debug("GEMS SOFTMAX_BACKWARD_OUT")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim

    if tuple(grad_input.shape) != tuple(output.shape):
        # grad_input.resize_(output.shape) is not supported on the device
        # ([sunrise fix][PTPU]); resize through a CPU copy and move it back.
        grad_input = grad_input.cpu().resize_(output.shape).to(grad_input.device)
    if grad_input.dtype != input_dtype:
        raise RuntimeError(
            f"_softmax_backward_data.out: expected grad_input dtype {input_dtype}, got {grad_input.dtype}"
        )

    # Nothing to compute for an empty tensor. Without this guard the K
    # computation below divides by zero when M or N is 0, and a kernel would
    # be launched over an empty buffer when only a trailing dim is 0.
    if output.numel() == 0:
        return grad_input

    # View the data as (M, N, K): M = product of dims before `dim`,
    # N = extent of `dim`, K = product of dims after it.
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]
    K = output.numel() // M // N

    # Both kernels address their inputs with contiguous offsets, so both
    # inputs must be contiguous (a no-op when they already are).
    output = output.contiguous()
    grad_output = grad_output.contiguous()

    with torch_device_fn.device(grad_input.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                grad_input,
                M,
                N,
                K,
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                grad_input,
                M,
                N,
            )
    return grad_input
def softmax_backward(grad_output, output, dim, input_dtype):
    """Return the softmax input gradient in a newly allocated tensor.

    Args:
        grad_output: Gradient with respect to the softmax output.
        output: The softmax forward result.
        dim: Axis the softmax was taken over.
        input_dtype: Dtype of the returned gradient.

    Returns:
        A tensor shaped like ``output`` with dtype ``input_dtype``.
    """
    logger.debug("GEMS SOFTMAX_BACKWARD")
    # The kernels write the gradient with contiguous offsets, so allocate a
    # contiguous buffer rather than inheriting the strides of a dense but
    # non-contiguous `output`, which a plain empty_like() would do.
    in_grad = torch.empty_like(
        output, dtype=input_dtype, memory_format=torch.contiguous_format
    )
    return softmax_backward_out(
        grad_output, output, dim, input_dtype, grad_input=in_grad
    )