Coverage for src/flag_gems/runtime/backend/_cambricon/ops/celu.py: 0%
15 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import triton
4import triton.language as tl
6from ..utils.pointwise_dynamic import pointwise_dynamic
8logger = logging.getLogger(__name__)
@pointwise_dynamic(is_tensor=[True, False, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def celu_forward_kernel(x, alpha, inplace):
    """Element-wise CELU: ``x`` where ``x > 0``, else ``alpha * (exp(x / alpha) - 1)``.

    Per the ``is_tensor`` list on the decorator, only ``x`` is declared as a
    tensor; ``alpha`` and ``inplace`` are non-tensor arguments. ``inplace`` is
    part of the signature but is not read in this body.
    """
    # Value selected wherever the ``x > 0`` test is false.
    negative_part = alpha * (tl.exp(x / alpha) - 1)
    return tl.where(x > 0, x, negative_part)
def celu(A, alpha=1.0):
    """Apply CELU to ``A`` through ``celu_forward_kernel`` and return its result.

    Args:
        A: Input forwarded to the kernel as its first argument.
        alpha: CELU scale factor forwarded to the kernel; defaults to 1.0.

    Returns:
        Whatever ``celu_forward_kernel`` returns when called with the
        in-place flag set to ``False`` and no explicit output argument.
    """
    logger.debug("GEMS_CAMBRICON CELU")
    inplace = False
    return celu_forward_kernel(A, alpha, inplace)
def celu_(A, alpha=1.0):
    """In-place flavoured CELU: run ``celu_forward_kernel`` with ``A`` as ``out0``.

    Args:
        A: Input forwarded to the kernel and also passed as its ``out0``
            keyword argument.
        alpha: CELU scale factor forwarded to the kernel; defaults to 1.0.

    Returns:
        Whatever ``celu_forward_kernel`` returns for this call.

    NOTE(review): presumably ``out0=A`` makes the kernel write its result back
    into ``A`` -- that is defined by ``pointwise_dynamic``; confirm there.
    """
    logger.debug("GEMS_CAMBRICON CELU_")
    inplace = True
    return celu_forward_kernel(A, alpha, inplace, out0=A)