Coverage for src/flag_gems/ops/_safe_softmax.py: 73%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
# Module-scoped logger, named after this module, used for op-dispatch debug traces.
logger = logging.getLogger(name=__name__)
# Row-wise "safe" softmax over a row-major (n_rows, n_cols) buffer.
#
# Each program handles one row (the launcher uses a 1-D grid of n_rows). A row
# whose elements are all -inf is written as zeros instead of NaN.
#
# Columns are processed in BLOCK_SIZE-wide tiles, so rows of any length are
# fully covered. Previously only the first BLOCK_SIZE columns were touched,
# which left the rest of the output uninitialized.
#
# Pass 1 is an online softmax reduction. Each lane keeps a running max and a
# running sum of exp(x - max), rescaling the sum whenever the max grows. The
# lanes are then reduced to a single row max and row sum.
# Pass 2 reloads the row and writes exp(x - row_max) / row_sum.
#
# Args:
#   input_ptr:  row-major input; every tile is upcast to fp32 after loading.
#   output_ptr: row-major output; tl.store casts the fp32 result to the
#               pointer's element type.
#   n_rows:     number of rows (not read here; it sizes the launch grid).
#   n_cols:     row length.
#   BLOCK_SIZE: tile width (power of two, compile-time constant).
@triton.jit
def _safe_softmax_kernel(
    input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr
):
    # Use a 64-bit row offset: row_id * n_cols overflows int32 once the
    # buffer holds more than 2**31 - 1 elements.
    row_id = tl.program_id(0).to(tl.int64)
    row_start = row_id * n_cols
    in_row = input_ptr + row_start
    out_row = output_ptr + row_start
    lane = tl.arange(0, BLOCK_SIZE)

    # Pass 1: per-lane running max and rescaled running sum.
    lane_max = tl.full([BLOCK_SIZE], -float("inf"), dtype=tl.float32)
    lane_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lane
        x = tl.load(in_row + cols, mask=cols < n_cols, other=-float("inf"))
        x = x.to(tl.float32)
        new_max = tl.maximum(lane_max, x)
        # Shift by 0 while a lane has only seen -inf, so that the
        # expression (-inf) - (-inf) never turns the sum into NaN.
        shift = tl.where(new_max == -float("inf"), 0.0, new_max)
        lane_sum = lane_sum * tl.exp(lane_max - shift) + tl.exp(x - shift)
        lane_max = new_max

    # Combine the lanes into a single row max and row sum.
    row_max = tl.max(lane_max, axis=0)
    all_neginf = row_max == -float("inf")
    row_shift = tl.where(all_neginf, 0.0, row_max)
    row_sum = tl.sum(lane_sum * tl.exp(lane_max - row_shift), axis=0)

    # Pass 2: normalize and store. Rows that are entirely -inf have
    # row_sum == 0 (giving 0/0 = NaN), so they are forced to zero.
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lane
        mask = cols < n_cols
        x = tl.load(in_row + cols, mask=mask, other=-float("inf"))
        x = x.to(tl.float32)
        softmax = tl.exp(x - row_shift) / row_sum
        softmax = tl.where(all_neginf, 0.0, softmax)
        tl.store(out_row + cols, softmax, mask=mask)
def _safe_softmax(x: torch.Tensor, dim: int = -1, dtype: torch.dtype = None):
    """Compute softmax of ``x`` along ``dim``, giving zeros (not NaN) for all -inf slices.

    Args:
        x: CUDA tensor with at least one dimension.
        dim: reduction axis; negative values count from the end.
        dtype: dtype of the result; defaults to ``x.dtype``.

    Returns:
        A tensor with ``x``'s shape and dtype ``dtype`` (or ``x.dtype``).

    Raises:
        AssertionError: if ``x`` is not on CUDA, is 0-dimensional, or ``dim``
            is out of range.
    """
    logger.debug("GEMS _SAFE_SOFTMAX")
    assert x.is_cuda, "Input tensor must be on CUDA device"
    assert x.ndim >= 1, "Input tensor must have at least 1 dimension"

    dim = dim if dim >= 0 else x.ndim + dim
    assert 0 <= dim < x.ndim, "Invalid dim for softmax"

    out_dtype = x.dtype if dtype is None else dtype

    # Nothing to compute. Returning early also avoids numel // 0 when the
    # reduced axis is empty, and avoids launching a zero-size grid.
    if x.numel() == 0:
        return torch.empty(x.shape, dtype=out_dtype, device=x.device)

    # The kernel reduces over the last, contiguous axis. Swapping ``dim`` with
    # the last axis is its own inverse, so the same transpose restores the
    # caller's layout at the end.
    moved = dim != x.ndim - 1
    y = (x.transpose(dim, -1) if moved else x).contiguous()

    n_cols = y.shape[-1]
    n_rows = y.numel() // n_cols

    # The kernel upcasts each loaded tile to fp32, and tl.store casts the
    # result to the output pointer's element type. These dtypes can therefore
    # be read and written directly instead of going through full fp32 staging
    # copies. Any other dtype keeps the original fp32 round trip.
    kernel_dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    src = y if y.dtype in kernel_dtypes else y.float()
    dst_dtype = out_dtype if out_dtype in kernel_dtypes else torch.float32
    out = torch.empty(y.shape, dtype=dst_dtype, device=y.device)

    # Tile width is capped at 4096 columns, as before.
    # NOTE(review): rows wider than one tile are fully computed only because
    # _safe_softmax_kernel iterates over column tiles; keep the two in sync.
    BLOCK_SIZE = min(4096, triton.next_power_of_2(n_cols))
    _safe_softmax_kernel[(n_rows,)](
        src, out, n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE
    )

    if out.dtype != out_dtype:
        out = out.to(out_dtype)
    return out.transpose(dim, -1) if moved else out