Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/sigmoid.py: 0%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5from _kunlunxin.utils.codegen_config_utils import CodeGenConfig 

6 

7from flag_gems.utils import tl_extra_shim 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

# Module logger: a child of the package-wide "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Backend exp2 helper from tl_extra_shim. NOTE(review): not referenced by any kernel in this module (sigmoid_forward uses tl.exp); kept in case other modules import it.
exp2 = tl_extra_shim.exp2

13 

14 

# Code-generation settings for the forward sigmoid kernel (passed to
# pointwise_dynamic via config=config_). The four leading positional values
# presumably are max tile size, max grid size per axis, max warps per CTA and
# a block-pointer preference -- confirm against CodeGenConfig's signature.
config_ = CodeGenConfig(
    512,
    (65536,) * 3,  # same limit on each of the three grid axes
    32,
    True,
    prefer_1d_tile=True,
    buffer_size_limit=4096,
    isCloseVectorization=True,
    kunlunAutoGrid=True,
    unroll_num=8,
)

26 

27 

@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")], config=config_)
@triton.jit
def sigmoid_forward(x):
    # Elementwise logistic sigmoid: 1 / (1 + exp(-x)).
    # - Argument 0 uses the "INT_TO_FLOAT" promotion rule, so integer inputs
    #   yield a floating-point result.
    # - The operand is explicitly upcast to float32, so the exp and division
    #   are evaluated in float32 whatever the input dtype.
    # - Uses tl.exp directly; there is no exp2/log2(e) rewrite here.
    # - Code generation is tuned by the module-level config_.
    return 1 / (1 + tl.exp(-x.to(tl.float32)))

36 

37 

@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def sigmoid_backward_kernel(dy, y):
    # Sigmoid gradient expressed through the forward output y:
    #     grad_in = dy * (1 - y) * y
    # Both operands are upcast to float32 first. The product is kept in
    # left-to-right order (dy * (1 - y)) * y so rounding is unchanged.
    out_f32 = y.to(tl.float32)
    grad_f32 = dy.to(tl.float32)
    one_minus_out = 1.0 - out_f32
    return grad_f32 * one_minus_out * out_f32

44 

45 

def sigmoid(self):
    """Out-of-place logistic sigmoid.

    Args:
        self: input tensor (named ``self`` to mirror the aten operator).

    Returns:
        The new tensor produced by the ``sigmoid_forward`` pointwise kernel.
    """
    logger.debug("GEMS SIGMOID FORWARD")
    return sigmoid_forward(self)

50 

51 

def sigmoid_backward(grad_output, output):
    """Backward pass of sigmoid, computed from the forward result.

    Args:
        grad_output: upstream gradient with respect to the sigmoid output.
        output: the tensor returned by the forward sigmoid.

    Returns:
        ``grad_output * (1 - output) * output`` from ``sigmoid_backward_kernel``.
    """
    logger.debug("GEMS SIGMOID BACKWARD")
    return sigmoid_backward_kernel(grad_output, output)

56 

57 

def sigmoid_(A):
    """In-place logistic sigmoid.

    Runs the ``sigmoid_forward`` kernel with ``out0=A``, so the result is
    written back into ``A``.

    Args:
        A: tensor to overwrite with ``sigmoid(A)``.

    Returns:
        Whatever ``sigmoid_forward(A, out0=A)`` returns (the output buffer).

    Raises:
        RuntimeError: if ``A`` is not a floating-point tensor. Sigmoid yields
            values in (0, 1), which cannot be stored in an integer/bool
            buffer without silent truncation; PyTorch rejects this case too.
    """
    # Validate before launching: otherwise the float result would be
    # truncated when stored into an integer/bool tensor.
    if not A.is_floating_point():
        raise RuntimeError(
            f"sigmoid_: in-place sigmoid requires a floating-point tensor, got {A.dtype}"
        )
    logger.debug("GEMS SIGMOID_ FORWARD")
    out = sigmoid_forward(A, out0=A)
    return out