Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/celu.py: 0%
20 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import triton
4import triton.language as tl
5from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
7from ..utils.pointwise_dynamic import pointwise_dynamic
# Child logger under the shared "flag_gems" root logger; leading dots are
# stripped from the module name before it is used as the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# CodeGenConfig instance for this module.
# NOTE(review): the meaning of the four positional arguments is defined by
# CodeGenConfig, whose signature is not visible here -- confirm against it.
# NOTE(review): `config_` is not referenced in the visible part of this file.
config_ = CodeGenConfig(
    512,
    (65536,) * 3,  # same value as (65536, 65536, 65536)
    32,
    True,
    prefer_1d_tile=True,
    isCloseVectorization=True,  # TODO: Wait LLVM FIX
)
# CELU forward, computed without branching on the output as
#     max(0, x) + alpha * (exp(min(0, x) / alpha) - 1)
# to keep the instruction count small.
# Per the decorator arguments, `x` is the tensor operand and `alpha` is a
# non-tensor scalar; the result dtype follows "DEFAULT" promotion of operand 0.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def celu_forward_kernel(x, alpha):
    # Multiply by the reciprocal instead of dividing by alpha per element.
    inv_alpha = 1.0 / alpha

    # Positive branch: max(0, x).
    pos_part = tl.maximum(0.0, x)

    # Negative branch input: min(0, x).
    # This is selected explicitly rather than computed as `x - pos_part`,
    # because for x = +inf the subtraction is inf - inf = NaN and would turn
    # the whole result into NaN instead of +inf. NaN inputs still propagate:
    # `NaN > 0` is false, so the NaN itself is selected and flows through exp.
    neg_part_input = tl.where(x > 0, 0.0, x)

    return pos_part + alpha * (tl.exp(neg_part_input * inv_alpha) - 1.0)
def celu(A, alpha=1.0):
    """Apply CELU to ``A`` by delegating to ``celu_forward_kernel``.

    Args:
        A: Input forwarded as the kernel's first operand.
        alpha: CELU scale parameter forwarded to the kernel. Defaults to 1.0.

    Returns:
        The value returned by ``celu_forward_kernel(A, alpha)``.
    """
    logger.debug("GEMS CELU")
    result = celu_forward_kernel(A, alpha)
    return result
def celu_(A, alpha=1.0):
    """Apply CELU to ``A``, passing ``A`` itself as the kernel's ``out0``.

    Args:
        A: Input forwarded as the kernel's first operand and as ``out0``.
            Presumably this makes the operation in-place on ``A`` -- confirm
            against ``pointwise_dynamic``'s handling of ``out0``.
        alpha: CELU scale parameter forwarded to the kernel. Defaults to 1.0.

    Returns:
        The value returned by ``celu_forward_kernel(A, alpha, out0=A)``.
    """
    logger.debug("GEMS CELU_")
    result = celu_forward_kernel(A, alpha, out0=A)
    return result