Coverage for src/flag_gems/runtime/backend/_ascend/ops/softmax.py: 0%
213 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax along the N axis of a contiguous (M, N, K) view, for K > 1.

    Grid: (M, cdiv(K, TILE_K)).  Each program handles one m-slice and one
    TILE_K-wide strip of the K axis.  When ONE_TILE_PER_CTA the whole N axis
    fits in a single (TILE_N, TILE_K) tile; otherwise an online (streaming)
    softmax accumulates a running max and running sum over N in tiles.
    """
    pid_k = tle.program_id(1)
    pid_m = tle.program_id(0)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # Single-tile path: load, reduce, normalize, store in one pass.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & (k_offsets < K)
        input_ptrs = input_ptr + offset
        # Out-of-range lanes load -inf so they contribute nothing to max/sum.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        m = tl.max(inp, 0)
        e = tl.exp(inp - m[None, :])
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online softmax state: running max `m` and running normalizer `z`,
        # accumulated in float32 regardless of input dtype.
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            # Avoid exp(-inf - (-inf)) = NaN when every value seen so far is -inf.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the TILE_N partial rows into per-k scalars.
        m_reduced = tl.max(m, 0)  # (TILE_K,)
        z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
        m = m_reduced

        # specialization does not improve performance in this example, as tested
        # Second pass writes the normalized output, walking tiles back-to-front
        # (the tail tile — the only one needing its n-mask — is visited first).
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            # NOTE(review): k_offsets[None, :] here vs bare k_offsets above —
            # equivalent by broadcasting, just inconsistent style.
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            o = tl.exp(inp - m[None, :]) / z[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)
@triton.jit
def next_multiple_of(a, b):
    """Return the smallest x >= a such that x % b == 0."""
    # FIX: was ``tl.cidv`` — a typo; the Triton ceiling-division helper is
    # ``tl.cdiv``.  The original would fail at JIT compile time if called.
    return tl.cdiv(a, b) * b
@triton.jit
def prev_multiple_of(a, b):
    """Return the largest x < a such that x % b == 0 (assumes a, b > 0)."""
    # ceil(a / b) - 1 full strides of b lands on the multiple just below a.
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax along the innermost (contiguous) axis of an (M, N) view.

    Grid: (M,).  Each program normalizes one row.  When ONE_TILE_PER_CTA the
    row fits in one TILE_N tile; otherwise an online softmax streams the row
    in TILE_N chunks (two reduction loops, then two write-back loops).
    """
    pid_m = tle.program_id(0)
    if ONE_TILE_PER_CTA:
        # Single-tile path: plain max / exp / sum softmax.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # Cast to the output dtype up front; padding lanes hold -inf.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online softmax state (float32): running max m, running normalizer z.
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        # Rebase pointers to this program's row.
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        previous_multiple = prev_multiple_of(N, TILE_N)
        # Full tiles: no mask needed.
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # specialize the last iteration (tail tile, masked)
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Collapse the TILE_N partial lanes into row-level scalars.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced

        previous_multiple = prev_multiple_of(N, TILE_N)
        # specialize the first iteration: write the tail tile (needs a mask)
        # before streaming the remaining full tiles back-to-front.
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        # Remaining tiles are full, so loads/stores need no mask.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)
162# ------------------------ backward -------------------------------
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax VJP along the N axis of a (M, N, K) view, for K > 1.

    Computes in_grad = out * (out_grad - sum_n(out * out_grad)) per (m, k).
    Grid: (M, cdiv(K, TILE_K)).
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        # Whole N axis in one tile: reduce and write in a single pass.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        out_tile = tl.load(out_ptr + offsets, mask=mask)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask)
        # Per-k dot product of softmax output and upstream gradient.
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate partial products over N tiles (float32 accumulator).
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for off in range(0, N, TILE_N):
            offsets_n = tl.arange(0, TILE_N) + off
            offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask)
            scale += out_tile * out_grad_tile
            # offsets_n += TILE_N
            # offsets += TILE_N * K
        scale = tl.sum(scale, axis=0)  # (TILE_K)

        # Pass 2: re-read tiles and write the gradient.
        for off in range(0, N, TILE_N):
            offsets_n = tl.arange(0, TILE_N) + off
            offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            # offsets_n += TILE_N
            # offsets += TILE_N * K
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax VJP along the innermost axis of an (M, N) view.

    Computes in_grad = out * (out_grad - sum_n(out * out_grad)) per row.
    Grid: (cdiv(M, TILE_M),) — each program handles TILE_M rows.
    """
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # Whole row in one tile: reduce and write in a single pass.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        out_tile = tl.load(out_ptr + offsets, mask=mask)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask)
        # Per-row dot product of softmax output and upstream gradient.
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate the per-row dot product over N tiles (float32).
        # scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float32)
        scale = tl.zeros([TILE_M], dtype=tl.float32)

        for off in range(0, N, TILE_N):
            n_offsets = tl.arange(0, TILE_N) + off
            offsets = m_offsets[:, None] * N + n_offsets
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_last: `out` is re-read in pass 2, so keep it cached.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_last"
            )
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask)
            # scale += out_tile * out_grad_tile
            scale += tl.sum(out_tile * out_grad_tile, axis=1)
            # n_offsets += TILE_N
            # offsets += TILE_N
        # scale = tl.sum(scale, 1) # (TILE_M,)

        # Pass 2: re-read tiles and write the gradient.
        for off in range(0, N, TILE_N):
            n_offsets = tl.arange(0, TILE_N) + off
            offsets = m_offsets[:, None] * N + n_offsets
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # evict_first: each tile is consumed exactly once on this pass.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            )
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            # n_offsets += TILE_N
            # offsets += TILE_N
def softmax(self, dim, half_to_float=False):
    """Softmax of ``self`` along ``dim``.

    The tensor is viewed as (M, N, K): M collapses the dims before ``dim``,
    N is the reduction dim, K collapses the dims after it.  K == 1 selects
    the contiguous-innermost kernel; otherwise the strided kernel is used.
    If ``half_to_float`` the result is computed/stored in float32.
    """
    logger.debug("GEMS_ASCEND SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim

    N = self.shape[dim]
    M = 1
    for size in self.shape[:dim]:  # pre_dim
        M *= size

    self = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # post_dim

    with torch_device_fn.device(self.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_kernel_non_inner[grid](out, self, M, N, K)
        else:
            softmax_kernel_inner[(M, 1, 1)](out, self, M, N)
    return out
def softmax_backward(grad_output, output, dim, input_dtype):
    """VJP of softmax: grad wrt the input, given the forward ``output``.

    Uses the same (M, N, K) collapse as the forward pass; K == 1 selects the
    inner kernel.  The result is allocated with ``input_dtype``.
    """
    logger.debug("GEMS_ASCEND SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim

    N = output.shape[dim]
    M = 1
    for size in output.shape[:dim]:
        M *= size

    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output, grad_output, in_grad, M, N, K
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](
                output, grad_output, in_grad, M, N
            )
    return in_grad