Coverage for src/flag_gems/ops/softplus.py: 71%
17 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
# Module-scoped logger, named after this module so its output can be filtered per op.
logger = logging.getLogger(name=__name__)
@pointwise_dynamic(is_tensor=[True, False, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def softplus_forward(x, beta, threshold):
    """Elementwise softplus: ``log(1 + exp(beta * x)) / beta``.

    Where ``beta * x > threshold`` the linear value ``x`` (i.e. ``z / beta``)
    is returned instead, mirroring the beta/threshold semantics of
    ``torch.nn.functional.softplus``.

    Args:
        x: tensor operand (the only tensor argument per ``is_tensor``).
        beta: non-tensor scalar scale applied before and divided out after.
        threshold: non-tensor scalar; above it the result is taken as linear.

    Returns:
        The softplus values, computed in float32 and cast back to ``x.dtype``.
    """
    # NOTE(review): every input is evaluated at float32 precision, including
    # float64 inputs, before the cast back to x.dtype. Confirm this is intended.
    x_fp = x.to(tl.float32)
    z = x_fp * beta
    # log(1 + exp(z)) == max(z, 0) + log(1 + exp(-|z|)). The right-hand form
    # only exponentiates non-positive values, so it cannot overflow to inf even
    # when a caller passes a threshold above float32's exp limit (~88.7). The
    # naive form returned inf for threshold >= z > ~88.7.
    # NOTE(review): no log1p is used, so for very negative z the log term
    # rounds to 0 (absolute error below ~6e-8).
    stable = tl.maximum(z, 0.0) + tl.log(1 + tl.exp(-tl.abs(z)))
    soft_z = tl.where(z > threshold, z, stable)
    return (soft_z / beta).to(x.dtype)
def softplus(self, beta=1.0, threshold=20.0):
    """Apply softplus to ``self`` using the ``softplus_forward`` Triton kernel.

    Args:
        self: the input operand. It is named ``self`` presumably to mirror the
            ATen operator schema this op overrides (confirm against the
            registration site).
        beta: scale factor; the result is ``log(1 + exp(beta * x)) / beta``.
        threshold: above ``beta * x > threshold`` the result is linear (``x``).

    Returns:
        Whatever ``softplus_forward`` produces for these arguments.
    """
    logger.debug("GEMS SOFTPLUS FORWARD")
    return softplus_forward(self, beta, threshold)