Coverage for src/flag_gems/ops/celu.py: 93%
15 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
8logger = logging.getLogger(__name__)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def celu_forward_kernel(x, alpha):
    # Elementwise CELU: keep x where x > 0, otherwise use
    # alpha * (exp(x / alpha) - 1).
    # Per the decorator arguments, x is the tensor operand and alpha is a
    # non-tensor scalar; alpha must be non-zero because it is used as a divisor.
    negative_branch = alpha * (tl.exp(x / alpha) - 1)
    return tl.where(x > 0, x, negative_branch)
def celu(A, alpha=1.0):
    """Apply CELU to ``A`` and return the result of ``celu_forward_kernel``.

    The kernel yields ``x`` where ``x > 0`` and
    ``alpha * (exp(x / alpha) - 1)`` elsewhere.

    Args:
        A: Input forwarded as the tensor operand of the kernel.
        alpha: Non-zero scalar scale of the negative branch. Defaults to 1.0.

    Returns:
        Whatever ``celu_forward_kernel(A, alpha)`` returns.

    Raises:
        RuntimeError: If ``alpha`` is 0. The kernel divides by ``alpha``, so a
            zero value would silently produce inf/NaN; the error mirrors the
            check PyTorch performs for CELU.
    """
    # Reject the degenerate divisor up front instead of launching a kernel
    # that can only produce inf/NaN for non-positive inputs.
    if alpha == 0:
        raise RuntimeError("ZeroDivisionError: alpha cannot be 0 for CELU")
    logger.debug("GEMS CELU")
    return celu_forward_kernel(A, alpha)
def celu_(A, alpha=1.0):
    """Apply CELU to ``A`` with ``A`` itself passed as the kernel output.

    The kernel yields ``x`` where ``x > 0`` and
    ``alpha * (exp(x / alpha) - 1)`` elsewhere; ``A`` is supplied as
    ``out0`` so the result is written to it.

    Args:
        A: Input forwarded as both the tensor operand and ``out0``.
        alpha: Non-zero scalar scale of the negative branch. Defaults to 1.0.

    Returns:
        Whatever ``celu_forward_kernel(A, alpha, out0=A)`` returns.

    Raises:
        RuntimeError: If ``alpha`` is 0. The kernel divides by ``alpha``;
            failing before the launch keeps ``A`` from being overwritten
            with inf/NaN. The error mirrors the check PyTorch performs for
            CELU.
    """
    # Validate before dispatch: once the kernel runs, A has already been
    # overwritten and the original values cannot be recovered.
    if alpha == 0:
        raise RuntimeError("ZeroDivisionError: alpha cannot be 0 for CELU")
    logger.debug("GEMS CELU_")
    return celu_forward_kernel(A, alpha, out0=A)