Coverage for src/flag_gems/ops/_safe_softmax.py: 73%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8logger = logging.getLogger(__name__) 

9 

10 

# Row-wise "safe" softmax: one program instance handles one row of a
# contiguous (n_rows, n_cols) buffer. A row whose elements are all -inf
# produces zeros instead of NaN.
#
# The row is processed in tiles of BLOCK_SIZE columns so that rows wider than
# BLOCK_SIZE are fully covered (a single tile would leave the tail of the row
# unwritten). Three passes are made over the row: max, sum of exponentials,
# and the final normalised write.
@triton.jit
def _safe_softmax_kernel(
    input_ptr, output_ptr, n_rows, n_cols, BLOCK_SIZE: tl.constexpr
):
    # Promote to int64 before multiplying so the row offset cannot wrap for
    # tensors with 2**31 or more elements.
    row_id = tl.program_id(0).to(tl.int64)
    row_offset = row_id * n_cols
    row_in = input_ptr + row_offset
    row_out = output_ptr + row_offset
    lanes = tl.arange(0, BLOCK_SIZE)

    # Pass 1: per-lane running maximum, plus a flag recording whether any
    # element differs from -inf. Out-of-range lanes load -inf, so they affect
    # neither the maximum nor the flag.
    lane_max = tl.full([BLOCK_SIZE], -float("inf"), dtype=tl.float32)
    lane_flag = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lanes
        mask = cols < n_cols
        x = tl.load(row_in + cols, mask=mask, other=-float("inf")).to(tl.float32)
        lane_max = tl.maximum(lane_max, x)
        lane_flag = tl.maximum(lane_flag, (x != -float("inf")).to(tl.int32))
    x_max = tl.max(lane_max, axis=0)
    # True only when every element of the row equals -inf. NaN compares
    # unequal to -inf, so NaN rows are not treated as "all -inf".
    all_neginf = tl.max(lane_flag, axis=0) == 0

    # Pass 2: sum of exp(x - max). Padding lanes contribute exp(-inf) == 0.
    lane_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lanes
        mask = cols < n_cols
        x = tl.load(row_in + cols, mask=mask, other=-float("inf")).to(tl.float32)
        lane_sum += tl.exp(x - x_max)
    sum_exp = tl.sum(lane_sum, axis=0)

    # Pass 3: normalise and store. For an all -inf row the arithmetic above
    # yields NaN (-inf - -inf), which is replaced by zeros here.
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + lanes
        mask = cols < n_cols
        x = tl.load(row_in + cols, mask=mask, other=-float("inf")).to(tl.float32)
        softmax = tl.exp(x - x_max) / sum_exp
        softmax = tl.where(
            all_neginf, tl.zeros([BLOCK_SIZE], dtype=tl.float32), softmax
        )
        tl.store(row_out + cols, softmax, mask=mask)

34 

35 

def _safe_softmax(x: torch.Tensor, dim: int = -1, dtype: torch.dtype = None):
    """Compute a "safe" softmax of ``x`` along ``dim`` with a Triton kernel.

    The reduction dimension is moved to the last axis, the data is made
    contiguous and converted to float32, and ``_safe_softmax_kernel`` is
    launched with one program per row. The kernel writes zeros (instead of
    NaN) for rows that consist entirely of ``-inf``.

    Args:
        x: CUDA tensor with at least one dimension.
        dim: Dimension to normalise over; negative values count from the end.
        dtype: Optional dtype of the result. When ``None`` the result has the
            dtype of ``x``. The computation itself is always done in float32.

    Returns:
        A tensor with the same shape as ``x``. When ``dim`` is not the last
        dimension the result is a transposed (non-contiguous) view.

    Raises:
        AssertionError: If ``x`` is not on a CUDA device, has no dimensions,
            or ``dim`` is out of range.
    """
    logger.debug("GEMS _SAFE_SOFTMAX")
    # Kept as asserts so callers observe the same exception type as before.
    assert x.is_cuda, "Input tensor must be on CUDA device"
    assert x.ndim >= 1, "Input tensor must have at least 1 dimension"

    dim = dim if dim >= 0 else x.ndim + dim
    assert 0 <= dim < x.ndim, "Invalid dim for softmax"

    out_dtype = x.dtype if dtype is None else dtype

    # Empty tensors have nothing to compute. Returning early also avoids a
    # division by zero below when the softmax dimension itself has size 0.
    if x.numel() == 0:
        return torch.empty_like(x, dtype=out_dtype)

    # Swapping `dim` with the last axis is its own inverse, so the same
    # transpose is applied again to the result.
    needs_swap = dim != x.ndim - 1
    y = (x.transpose(dim, -1) if needs_swap else x).contiguous()

    n_cols = y.shape[-1]
    n_rows = y.numel() // n_cols

    y_fp32 = y.float()
    out_fp32 = torch.empty_like(y_fp32)

    # NOTE: the tile width is capped at 4096, so for wider rows the kernel
    # must iterate over column tiles to cover the full row.
    BLOCK_SIZE = min(4096, triton.next_power_of_2(n_cols))

    _safe_softmax_kernel[(n_rows,)](
        y_fp32, out_fp32, n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE
    )

    out = out_fp32.to(out_dtype)
    if needs_swap:
        out = out.transpose(dim, -1)

    return out