Coverage for src/flag_gems/runtime/backend/_mthreads/ops/log_softmax.py: 0%
242 statements
Generated by coverage.py v7.6.9 at 2026-03-29 04:01 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a
    # (for a multiple of b this yields a - b, not a).
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def log_softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Log-softmax along a non-innermost axis: the input is viewed as a
    # contiguous (M, N, K) volume and the reduction runs over N.
    # Grid layout (see the `log_softmax` wrapper): axis 0 -> M, axis 1 -> K tiles.
    pid_k = tle.program_id(1)
    pid_m = tle.program_id(0)
    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)
    if ONE_TILE_PER_CTA:
        # The whole reduction axis fits in one (TILE_N, TILE_K) tile:
        # classic single-pass max/sum/log formulation.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & (k_offsets < K)
        input_ptrs = input_ptr + offset
        # -inf padding is the identity for max and contributes exp(-inf)=0 to z.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m = tl.max(inp, 0)  # per-k max over N, for numerical stability
        e = tl.exp(inp - m[None, :])
        z = tl.sum(e, 0)
        out = inp - m[None, :] - tl.log(z)[None, :]
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online softmax, pass 1: maintain running max `m` and running sum `z`
        # of exp(x - m) across N-tiles, rescaling z whenever m grows.
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            m_new = tl.maximum(m, inp)
            # Skip the update while m_new is still -inf: avoids the
            # nan that exp(-inf - (-inf)) would produce.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # Fold the TILE_N partial lanes into one scalar per k column.
        m_reduced = tl.max(m, 0)  # (TILE_K,)
        z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
        m = m_reduced
        # Pass 2: emit x - m - log(z). Tiles are visited in reverse order,
        # starting from the last tile-aligned offset — presumably so the
        # most recently loaded tile is revisited first (NOTE(review):
        # cache-reuse rationale inferred, not documented here).
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            o = inp - m[None, :] - tl.log(z)[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)
def log_softmax_heur_tile_m(args):
    """Pick TILE_M (rows handled per CTA) for the inner kernel.

    Small reduction sizes N leave room to batch several rows per program;
    the row budget shrinks as N grows and as the row count M shrinks.
    """
    rows, cols = args["M"], args["N"]
    if cols > 1024:
        # Large reduction axis: one row per CTA.
        return 1
    # (min_rows, tile) thresholds, largest first.
    schedule = ((4096, 8), (1024, 4)) if cols <= 256 else ((4096, 4), (1024, 2))
    for min_rows, tile in schedule:
        if rows >= min_rows:
            return tile
    return 1
def log_softmax_heur_tile_n_inner(args):
    """Pick TILE_N (reduction-axis tile width) for the inner kernel."""
    n, m = args["N"], args["M"]
    if n > (32 * 1024):
        # Huge rows: fixed tile, the kernel loops over the axis.
        return 4096
    if n <= 32 and m > 1000:
        # Many tiny rows: widen the tile to 32 lanes.
        return 32
    tile_n = triton.next_power_of_2(n)
    if 1024 < n <= 8192:
        # One row per CTA here; cap the tile so the kernel loops,
        # keeping register pressure down.
        return min(tile_n, 2048)
    return tile_n
def log_softmax_heur_one_tile_per_cta(args):
    """True when a single TILE_N tile covers the whole reduction axis N."""
    return args["N"] <= args["TILE_N"]
def log_softmax_heur_num_warps_inner(args):
    """Scale num_warps with the TILE_M x TILE_N element footprint."""
    elements = args["TILE_M"] * args["TILE_N"]
    if elements >= 4096:
        return 16
    if elements >= 2048:
        return 8
    return 4
@libentry()
@triton.heuristics(
    {
        "TILE_M": log_softmax_heur_tile_m,
        "TILE_N": log_softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": log_softmax_heur_one_tile_per_cta,
        "num_warps": log_softmax_heur_num_warps_inner,
    }
)
@triton.jit
def log_softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Log-softmax over the innermost axis: the input is viewed as a
    # contiguous (M, N) matrix and the reduction runs over N.
    # Each program handles TILE_M rows; grid is (cdiv(M, TILE_M),).
    pid_m = tle.program_id(0)
    m_offset = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # N fits in one tile: single load, direct reduction along axis 1.
        n_offsets = tl.arange(0, TILE_N)
        offset = m_offset[:, None] * N + n_offsets[None, :]
        mask = (m_offset[:, None] < M) & (n_offsets[None, :] < N)
        input_ptrs = input_ptr + offset
        # -inf padding is the identity for max and adds exp(-inf)=0 to z.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        m = tl.max(inp, 1)  # per-row max for numerical stability
        e = tl.exp(inp - m[:, None])
        z = tl.sum(e, 1)
        out = inp - m[:, None] - tl.log(z)[:, None]
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Online softmax, pass 1: running max `m` and running sum `z` of
        # exp(x - m), rescaling z whenever the max grows.
        m = tl.full([TILE_M, TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_M, TILE_N], value=0.0, dtype=tl.float32)
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offset = m_offset[:, None] * N + n_offsets[None, :]
            mask = (m_offset[:, None] < M) & (n_offsets[None, :] < N)
            inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            m_new = tl.maximum(m, inp)
            # Skip the update while m_new is still -inf: avoids the nan
            # that exp(-inf - (-inf)) would produce.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # Fold the TILE_N partial lanes into one scalar per row.
        m_reduced = tl.max(m, 1)
        z = tl.sum(z * tl.exp(m - m_reduced[:, None]), 1)
        m = m_reduced
        # Pass 2: reload each tile and emit x - m - log(z).
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offset = m_offset[:, None] * N + n_offsets[None, :]
            mask = (m_offset[:, None] < M) & (n_offsets[None, :] < N)
            inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            out = inp - m[:, None] - tl.log(z)[:, None]
            tl.store(output_ptr + offset, out, mask=mask)
206# ------------------------ backward -------------------------------
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def log_softmax_backward_kernel_non_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Backward of log-softmax along a non-innermost axis. Layout is a
    # contiguous (M, N, K) volume with the reduction over N:
    #   in_grad = out_grad - exp(out) * sum_N(out_grad)
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)
    if ONE_TILE_PER_CTA:
        # Whole reduction axis in one tile: single load + reduce + store.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        # Masked lanes load 0, the identity for the sum below.
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        scale = tl.sum(out_grad_tile, axis=0)  # sum of upstream grads over N
        in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[None, :]
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate sum_N(out_grad) tile by tile.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_grad_tile
            offsets_n += TILE_N
            offsets += TILE_N * K
        scale = tl.sum(scale, axis=0)  # (TILE_K)
        # Pass 2: re-walk the tiles and emit the gradient.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[None, :]
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            offsets_n += TILE_N
            offsets += TILE_N * K
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def log_softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Backward of log-softmax over the innermost axis. Layout is a
    # contiguous (M, N) matrix with the reduction over N:
    #   in_grad = out_grad - exp(out) * sum_N(out_grad)
    pid_m = tle.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        # Whole row in one tile: single load + row-reduce + store.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        # Masked lanes load 0, the identity for the sum below.
        out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
        scale = tl.sum(out_grad_tile, 1)  # per-row sum of upstream grads
        in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[:, None]
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate sum_N(out_grad) tile by tile.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float32)
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            scale += out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)
        # Pass 2: re-walk the tiles and emit the gradient. `out` is read
        # with evict_first since each tile is used exactly once here.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[:, None]
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N
def log_softmax(self, dim, half_to_float=False):
    """Compute log_softmax of `self` along `dim` using Triton kernels.

    The tensor is viewed as (M, N, K) with N = shape[dim] the reduction
    axis; the inner kernel handles the innermost-axis case (K == 1) and
    the non-inner kernel handles the rest. When `half_to_float` is set
    the result is produced in float32 regardless of the input dtype.
    """
    logger.debug("GEMS_MTHREADS LOG_SOFTMAX")
    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    N = self.shape[dim]
    M = 1
    for size in self.shape[:dim]:  # product of the leading dims
        M *= size
    self = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # product of the trailing dims
    with torch_device_fn.device(self.device):
        if K > 1:
            # Reduction axis is not innermost: parallelize over (M, K tiles).
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            log_softmax_kernel_non_inner[grid](out, self, M, N, K)
        else:
            # Innermost-axis reduction: parallelize over row tiles.
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            log_softmax_kernel_inner[grid](out, self, M, N)
    return out
def log_softmax_backward(grad_output, output, dim, input_dtype):
    """Backward of log_softmax: in_grad = grad_out - exp(out) * sum(grad_out, dim).

    `output` is the forward log_softmax result and `grad_output` the
    upstream gradient; the result is produced in `input_dtype`. Both
    tensors are viewed as (M, N, K) with N = shape[dim].
    """
    logger.debug("GEMS_MTHREADS LOG_SOFTMAX BACKWARD")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    # Fix: the backward kernels index out_ptr with contiguous (M, N, K)
    # strides, so `output` must be made contiguous just like `grad_output`;
    # previously a non-contiguous `output` would be read with wrong strides.
    output = output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            # Reduction axis is not innermost: parallelize over (M, K tiles).
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            log_softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
        else:
            # Innermost-axis reduction: parallelize over row tiles.
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            log_softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    return in_grad