Coverage for src/flag_gems/experimental_ops/_safe_softmax.py: 0%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

# One Triton program per row: numerically stable softmax over a single row of
# a logically (n_rows, n_cols) contiguous buffer. The whole row is loaded in
# one block, so BLOCK_SIZE must be a power of two with BLOCK_SIZE >= n_cols —
# the launching wrapper is responsible for guaranteeing that.
@triton.jit
def _safe_softmax(input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr):
    row_id = tl.program_id(0)  # one program <-> one row
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols  # guard lanes beyond the real row width

    row_offset = row_id * n_cols
    # Padding lanes read -inf so they contribute exp(-inf) == 0 to the sum.
    x = tl.load(input_ptr + row_offset + cols, mask=mask, other=-float("inf"))
    x_fp32 = x.to(tl.float32)

    x_max = tl.max(x_fp32, axis=0)
    # A row that is entirely -inf (e.g. a fully masked attention row) has
    # max == -inf; the shifted values below become NaN for such rows, which
    # is why they are patched to zero before the store.
    all_neginf = x_max == -float("inf")

    # Subtract the row max before exponentiating for numerical stability.
    x_shifted = x_fp32 - x_max
    exp_x = tl.exp(x_shifted)
    sum_exp = tl.sum(exp_x, axis=0)
    softmax = exp_x / sum_exp

    # "Safe" semantics: all--inf rows yield 0.0 instead of NaN.
    softmax = tl.where(all_neginf, tl.zeros([BLOCK_SIZE], dtype=tl.float32), softmax)

    tl.store(output_ptr + row_offset + cols, softmax, mask=mask)

27 

28 

29# Preserve kernel handle before defining wrapper with the same name 

30_safe_softmax_kernel = _safe_softmax 

31 

32 

33def _safe_softmax(x: torch.Tensor, dim: int = -1, dtype: torch.dtype = None): 

34 assert x.is_cuda, "Input tensor must be on CUDA device" 

35 assert x.ndim >= 1, "Input tensor must have at least 1 dimension" 

36 

37 dim = dim if dim >= 0 else x.ndim + dim 

38 assert 0 <= dim < x.ndim, "Invalid dim for softmax" 

39 

40 if dim != x.ndim - 1: 

41 perm = list(range(x.ndim)) 

42 perm[dim], perm[-1] = perm[-1], perm[dim] 

43 y = x.permute(perm).contiguous() 

44 inv_perm = [0] * x.ndim 

45 for i, p in enumerate(perm): 

46 inv_perm[p] = i 

47 else: 

48 y = x.contiguous() 

49 inv_perm = None 

50 

51 n_cols = y.shape[-1] 

52 n_rows = y.numel() // n_cols 

53 

54 y_fp32 = y.float() 

55 out_fp32 = torch.empty_like(y_fp32) 

56 

57 def _next_pow2(v: int) -> int: 

58 if v <= 1: 

59 return 1 

60 return 1 << (v - 1).bit_length() 

61 

62 BLOCK_SIZE = min(4096, _next_pow2(n_cols)) 

63 grid = lambda meta: (n_rows,) 

64 

65 _safe_softmax_kernel[grid](y_fp32, out_fp32, n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE) 

66 

67 out = out_fp32 

68 if dtype is not None: 

69 out = out.to(dtype) 

70 else: 

71 out = out.to(x.dtype) 

72 

73 out = out.view(*y.shape) 

74 if inv_perm is not None: 

75 out = out.permute(inv_perm) 

76 

77 return out