Coverage for src/flag_gems/experimental_ops/softplus.py: 0%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def softplus_kernel(
    x_ptr,  # *Pointer* to input tensor
    out_ptr,  # *Pointer* to output tensor
    n_elements,  # Number of elements
    beta,  # beta scalar (float32)
    threshold,  # threshold scalar (float32)
    BLOCK_SIZE: tl.constexpr,  # elements processed per program instance
):
    """Elementwise softplus: y = log(1 + exp(beta * x)) / beta, switching to the
    linear pass-through y = x once beta * x exceeds `threshold` (the ATen
    softplus contract). Launched over a 1-D grid; each program handles one
    contiguous BLOCK_SIZE chunk of the flattened input.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Tail guard: lanes past n_elements are masked on both load and store.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    # Compute in float32 regardless of input dtype (fp16/bf16 inputs included)
    # for accuracy; cast back at the end.
    x_fp32 = x.to(tl.float32)

    z = x_fp32 * beta
    # compute softplus in a numerically stable way:
    # if z > threshold => x
    # else => log(1 + exp(z)) / beta
    # NOTE: exp(z) can overflow to +inf when z > threshold, but that lane's
    # log(1+exp)/beta result is discarded by tl.where below, so no NaN leaks.
    exp_z = tl.exp(z)
    sp = tl.log(1.0 + exp_z) / beta
    y_fp32 = tl.where(z > threshold, x_fp32, sp)

    # Cast back to the input dtype before storing.
    y = y_fp32.to(x.dtype)
    tl.store(out_ptr + offsets, y, mask=mask)

33 

34 

35def _softplus_launch(x: torch.Tensor, beta: float, threshold: float, out: torch.Tensor): 

36 assert x.is_cuda and out.is_cuda, "Inputs must be CUDA tensors" 

37 assert x.is_contiguous(), "Input tensor must be contiguous" 

38 assert out.is_contiguous(), "Output tensor must be contiguous" 

39 assert ( 

40 x.numel() == out.numel() 

41 ), "Input and output must have the same number of elements" 

42 assert x.dtype in ( 

43 torch.float16, 

44 torch.bfloat16, 

45 torch.float32, 

46 ), "Supported dtypes: float16, bfloat16, float32" 

47 assert out.dtype == x.dtype, "Output dtype must match input dtype" 

48 

49 n_elements = x.numel() 

50 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

51 softplus_kernel[grid]( 

52 x, out, n_elements, float(beta), float(threshold), BLOCK_SIZE=1024 

53 ) 

54 return out 

55 

56 

57def _parse_softplus_args(args, kwargs, expect_out: bool = False): 

58 # ATen signature: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor 

59 # ATen signature: softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!) 

60 x = None 

61 if len(args) >= 1: 

62 x = args[0] 

63 else: 

64 x = kwargs.get("self", kwargs.get("input", None)) 

65 if x is None: 

66 raise ValueError("softplus expects 'self' tensor as the first argument") 

67 

68 beta = kwargs.get("beta", 1.0) 

69 if len(args) >= 2: 

70 beta = args[1] 

71 threshold = kwargs.get("threshold", 20.0) 

72 if len(args) >= 3: 

73 threshold = args[2] 

74 

75 out = None 

76 if expect_out: 

77 if "out" in kwargs: 

78 out = kwargs["out"] 

79 elif len(args) >= 4: 

80 out = args[3] 

81 

82 return x, float(beta), float(threshold), out 

83 

84 

def softplus(*args, **kwargs):
    """Triton-backed softplus matching the ATen functional variant.

    Accepts ATen-style positional/keyword args (self, beta=1, threshold=20)
    and returns a freshly-allocated tensor.
    """
    x, beta, threshold, _ = _parse_softplus_args(args, kwargs, expect_out=False)
    # The kernel requires a contiguous flat buffer; copy only if needed.
    x_c = x if x.is_contiguous() else x.contiguous()
    result = torch.empty_like(x_c)
    _softplus_launch(x_c.view(-1), beta, threshold, result.view(-1))
    return result

92 

93 

def softplus_out(*args, **kwargs):
    """Triton-backed softplus matching the ATen `.out` variant.

    Writes the result into `out` (allocating one if absent) and returns it.
    When `out` cannot be used directly by the kernel (non-contiguous, or a
    shape/dtype/device mismatch with the input), the kernel runs on a scratch
    buffer whose contents are then copied back into `out`.
    """
    x, beta, threshold, out = _parse_softplus_args(args, kwargs, expect_out=True)
    if out is None:
        out = torch.empty_like(x)

    # Ensure contiguous buffers for kernel execution
    x_c = x if x.is_contiguous() else x.contiguous()

    directly_usable = (
        out.is_contiguous()
        and out.shape == x.shape
        and out.dtype == x.dtype
        and out.device == x.device
    )
    scratch = out if directly_usable else torch.empty_like(x_c)

    _softplus_launch(x_c.view(-1), beta, threshold, scratch.view(-1))

    # Copy back only when a scratch buffer was substituted for `out`.
    if scratch is not out:
        out.copy_(scratch)
    return out