Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/celu.py: 0%

20 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5from _kunlunxin.utils.codegen_config_utils import CodeGenConfig 

6 

7from ..utils.pointwise_dynamic import pointwise_dynamic 

8 

9logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

10 

11 

# Code-generation settings built for this op. The four positional values are
# forwarded to CodeGenConfig as-is; their meaning is defined by that class,
# which is not visible from this file.
# NOTE(review): config_ is not referenced by the code visible here -- confirm
# whether it was meant to be handed to pointwise_dynamic.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    isCloseVectorization=True,  # TODO: Wait LLVM FIX
)

20 

21 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def celu_forward_kernel(x, alpha):
    # Element-wise CELU, evaluated in the split form
    #     max(0, x) + alpha * (exp(min(0, x) / alpha) - 1)
    # (chosen by the original author to keep the instruction count small).
    #
    # x:     input element(s), the tensor operand of the pointwise wrapper.
    # alpha: scalar CELU parameter (declared non-tensor in is_tensor above).
    # Returns the CELU value for each element of x.
    inv_alpha = 1.0 / alpha

    pos_part = tl.maximum(0.0, x)

    # min(0, x), picked with a select rather than computed as `x - pos_part`:
    # for x = +inf the subtraction is inf - inf = NaN, which turned the whole
    # result into NaN instead of +inf. The select returns 0 for every positive
    # x (including +inf) and x itself otherwise, so a NaN input (for which
    # `x > 0.0` is false) still flows through and yields NaN.
    neg_part_input = tl.where(x > 0.0, 0.0, x)

    return pos_part + alpha * (tl.exp(neg_part_input * inv_alpha) - 1.0)

33 

34 

def celu(A, alpha=1.0):
    """Apply the CELU kernel to ``A`` and return the kernel's result.

    Args:
        A: Input forwarded as the first (tensor) argument of
            ``celu_forward_kernel``.
        alpha: Scalar forwarded as the kernel's ``alpha``; defaults to 1.0.

    Returns:
        Whatever ``celu_forward_kernel(A, alpha)`` returns.
    """
    logger.debug("GEMS CELU")
    result = celu_forward_kernel(A, alpha)
    return result

38 

39 

def celu_(A, alpha=1.0):
    """Apply the CELU kernel to ``A``, passing ``A`` as the output as well.

    Args:
        A: Input forwarded as the first (tensor) argument of
            ``celu_forward_kernel`` and also supplied as ``out0``.
        alpha: Scalar forwarded as the kernel's ``alpha``; defaults to 1.0.

    Returns:
        Whatever ``celu_forward_kernel(A, alpha, out0=A)`` returns.
    """
    logger.debug("GEMS CELU_")
    # NOTE(review): out0=A presumably makes the generated kernel write its
    # result into A (in-place variant) -- confirm against pointwise_dynamic.
    updated = celu_forward_kernel(A, alpha, out0=A)
    return updated