Coverage for src/flag_gems/experimental_ops/_safe_softmax.py: 0%
53 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _safe_softmax(input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr):
    """Row-wise numerically stable softmax; one program handles one whole row.

    "Safe" means a row whose entries are all -inf produces an all-zero output
    row instead of NaN (plain softmax would compute 0/0 there).

    NOTE(review): assumes n_cols <= BLOCK_SIZE — columns past BLOCK_SIZE are
    masked out of both the load and the store, so the host wrapper must
    guarantee the whole row fits in one tile.
    """
    row_id = tl.program_id(0)  # one program per row
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols  # guard the tail when n_cols < BLOCK_SIZE
    row_offset = row_id * n_cols  # rows are laid out contiguously
    # Masked-out lanes load -inf so they cannot win the max or add to the sum.
    x = tl.load(input_ptr + row_offset + cols, mask=mask, other=-float("inf"))
    x_fp32 = x.to(tl.float32)  # reduce in fp32 for numerical stability
    x_max = tl.max(x_fp32, axis=0)
    # Max of -inf means every (unmasked) entry is -inf: exp/sum would yield NaN.
    all_neginf = x_max == -float("inf")
    x_shifted = x_fp32 - x_max  # shift by the row max so exp() cannot overflow
    exp_x = tl.exp(x_shifted)
    sum_exp = tl.sum(exp_x, axis=0)
    softmax = exp_x / sum_exp
    # Safe-softmax convention: map an all--inf row to zeros rather than NaN.
    softmax = tl.where(all_neginf, tl.zeros([BLOCK_SIZE], dtype=tl.float32), softmax)
    tl.store(output_ptr + row_offset + cols, softmax, mask=mask)
# Preserve the raw Triton kernel handle before the Python wrapper below
# rebinds the module-level name `_safe_softmax`.
_safe_softmax_kernel = _safe_softmax
# Maximum columns the kernel can reduce: it processes one row per program
# in a single BLOCK_SIZE-wide tile.
_MAX_BLOCK_SIZE = 4096


def _next_pow2(v: int) -> int:
    """Return the smallest power of two >= v (1 for v <= 1)."""
    if v <= 1:
        return 1
    return 1 << (v - 1).bit_length()


def _safe_softmax(x: torch.Tensor, dim: int = -1, dtype: torch.dtype = None):
    """Numerically stable softmax along ``dim`` using the Triton kernel above.

    Unlike plain softmax, a slice whose entries are all -inf yields zeros
    instead of NaN.

    Args:
        x: CUDA tensor with at least one dimension.
        dim: dimension to normalize over (may be negative).
        dtype: optional output dtype; defaults to ``x.dtype``.

    Returns:
        Tensor of the same shape as ``x``.

    Raises:
        ValueError: if the softmax dimension is wider than ``_MAX_BLOCK_SIZE``
            columns — the kernel reduces a row in a single tile, so larger
            rows cannot be handled by this launch configuration.
    """
    assert x.is_cuda, "Input tensor must be on CUDA device"
    assert x.ndim >= 1, "Input tensor must have at least 1 dimension"

    # Normalize a negative dim and validate it.
    dim = dim if dim >= 0 else x.ndim + dim
    assert 0 <= dim < x.ndim, "Invalid dim for softmax"

    # Move the softmax dim to the innermost position so each row is contiguous
    # in memory, matching the kernel's `row_id * n_cols` addressing.
    if dim != x.ndim - 1:
        perm = list(range(x.ndim))
        perm[dim], perm[-1] = perm[-1], perm[dim]
        y = x.permute(perm).contiguous()
        inv_perm = [0] * x.ndim
        for i, p in enumerate(perm):
            inv_perm[p] = i
    else:
        y = x.contiguous()
        inv_perm = None

    out_dtype = dtype if dtype is not None else x.dtype
    n_cols = y.shape[-1]

    # Degenerate empty softmax dim: nothing to normalize (the original code
    # raised ZeroDivisionError here).
    if n_cols == 0:
        out = torch.empty_like(y, dtype=out_dtype)
        return out.permute(inv_perm) if inv_perm is not None else out

    n_rows = y.numel() // n_cols

    # BUG FIX: the old code clamped BLOCK_SIZE to min(4096, next_pow2(n_cols))
    # but still launched the single-tile kernel, silently computing softmax
    # over only the first 4096 columns and leaving the rest of the output
    # uninitialized (from empty_like). Fail loudly instead.
    if n_cols > _MAX_BLOCK_SIZE:
        raise ValueError(
            f"softmax dim of size {n_cols} exceeds the supported maximum "
            f"of {_MAX_BLOCK_SIZE} columns"
        )
    BLOCK_SIZE = _next_pow2(n_cols)

    y_fp32 = y.float()  # kernel computes in fp32
    out_fp32 = torch.empty_like(y_fp32)

    if n_rows > 0:  # skip a zero-size grid launch for empty tensors
        _safe_softmax_kernel[(n_rows,)](
            y_fp32, out_fp32, n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE
        )

    # Cast back to the requested (or original) dtype and restore shape/layout.
    out = out_fp32.to(out_dtype).view(*y.shape)
    if inv_perm is not None:
        out = out.permute(inv_perm)

    return out