Coverage for src/flag_gems/experimental_ops/softplus.py: 0%
69 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def softplus_kernel(
    x_ptr,  # *Pointer* to input tensor
    out_ptr,  # *Pointer* to output tensor
    n_elements,  # Number of elements
    beta,  # beta scalar (float32)
    threshold,  # threshold scalar (float32)
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise softplus: y = log(1 + exp(beta * x)) / beta, with a
    linear passthrough (y = x) once beta * x exceeds `threshold`.

    Each program instance processes one contiguous BLOCK_SIZE-wide slice.
    Math is performed in float32 and cast back to the input dtype on store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    x_fp32 = x.to(tl.float32)

    z = x_fp32 * beta
    # compute softplus in a numerically stable way:
    #   if z > threshold => x                     (avoids using the overflowed exp)
    #   else             => log1p(exp(z)) / beta
    # log1p keeps precision for z << 0: there, 1 + exp(z) rounds to exactly 1.0
    # in float32, so log(1 + exp_z) would collapse to 0 instead of ~exp(z)/beta.
    exp_z = tl.exp(z)
    sp = tl.math.log1p(exp_z) / beta
    y_fp32 = tl.where(z > threshold, x_fp32, sp)

    y = y_fp32.to(x.dtype)
    tl.store(out_ptr + offsets, y, mask=mask)
def _softplus_launch(x: torch.Tensor, beta: float, threshold: float, out: torch.Tensor):
    """Validate the flat buffers and launch `softplus_kernel` over them.

    Both `x` and `out` must be contiguous CUDA tensors with matching dtype
    (float16 / bfloat16 / float32) and element count. Returns `out`.
    """
    assert x.is_cuda and out.is_cuda, "Inputs must be CUDA tensors"
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert out.is_contiguous(), "Output tensor must be contiguous"
    assert x.numel() == out.numel(), "Input and output must have the same number of elements"
    assert x.dtype in (torch.float16, torch.bfloat16, torch.float32), "Supported dtypes: float16, bfloat16, float32"
    assert out.dtype == x.dtype, "Output dtype must match input dtype"

    total = x.numel()

    # One 1-D program per BLOCK_SIZE chunk of the flattened tensor.
    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    softplus_kernel[grid](
        x, out, total, float(beta), float(threshold), BLOCK_SIZE=1024
    )
    return out
57def _parse_softplus_args(args, kwargs, expect_out: bool = False):
58 # ATen signature: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor
59 # ATen signature: softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!)
60 x = None
61 if len(args) >= 1:
62 x = args[0]
63 else:
64 x = kwargs.get("self", kwargs.get("input", None))
65 if x is None:
66 raise ValueError("softplus expects 'self' tensor as the first argument")
68 beta = kwargs.get("beta", 1.0)
69 if len(args) >= 2:
70 beta = args[1]
71 threshold = kwargs.get("threshold", 20.0)
72 if len(args) >= 3:
73 threshold = args[2]
75 out = None
76 if expect_out:
77 if "out" in kwargs:
78 out = kwargs["out"]
79 elif len(args) >= 4:
80 out = args[3]
82 return x, float(beta), float(threshold), out
def softplus(*args, **kwargs):
    """Out-of-place softplus matching the ATen `softplus` signature.

    Accepts positional or keyword arguments, flattens a contiguous copy of
    the input, and returns a freshly allocated tensor of the same shape.
    """
    x, beta, threshold, _ = _parse_softplus_args(args, kwargs, expect_out=False)
    x = x if x.is_contiguous() else x.contiguous()
    result = torch.empty_like(x)
    _softplus_launch(x.view(-1), beta, threshold, result.view(-1))
    return result
def softplus_out(*args, **kwargs):
    """In-place-style softplus matching the ATen `softplus.out` signature.

    Writes the result into `out` (allocating one when absent) and returns it.
    When `out` cannot be fed to the kernel directly — non-contiguous, or its
    shape/dtype/device disagrees with the input — the kernel writes into a
    scratch buffer that is then copied back into `out`.
    """
    x, beta, threshold, out = _parse_softplus_args(args, kwargs, expect_out=True)
    if out is None:
        out = torch.empty_like(x)

    # The kernel requires contiguous buffers on both sides.
    x_c = x if x.is_contiguous() else x.contiguous()
    direct = (
        out.is_contiguous()
        and out.shape == x.shape
        and out.dtype == x.dtype
        and out.device == x.device
    )
    scratch = out if direct else torch.empty_like(x_c)

    _softplus_launch(x_c.view(-1), beta, threshold, scratch.view(-1))

    if not direct:
        # Kernel wrote into the scratch buffer; mirror it into the caller's `out`.
        out.copy_(scratch)
    return out