Coverage for src/flag_gems/runtime/backend/_cambricon/ops/tanh.py: 0%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import tl_extra_shim 

7 

8from ..utils.pointwise_dynamic import pointwise_dynamic 

9 

# Backend tanh implementation picked from the shim; called inside the Triton
# kernels below.
_tanh = tl_extra_shim.fast_tanh
# Re-exported from the shim. NOTE(review): this name shadows the builtin `pow`
# inside this module and is not used by the code visible in this file; kept in
# case other modules import it from here.
pow = tl_extra_shim.pow

# Child of the package-wide "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def tanh_kernel(x, inplace):
    # Elementwise tanh. The operand is widened to float32 before calling the
    # backend's fast_tanh; integer inputs use the INT_TO_FLOAT promotion rule.
    # `inplace` is declared non-tensor (is_tensor=[True, False]) and is never
    # read here. Callers pass it positionally as the second argument.
    x_f32 = x.to(tl.float32)
    return _tanh(x_f32)

19 

20 

@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def tanh_backward_kernel(y, dy):
    # Gradient of tanh from its saved forward output: dx = dy * (1 - y^2).
    # Both operands are widened to float32 before the arithmetic.
    y_f32 = y.to(tl.float32)
    dy_f32 = dy.to(tl.float32)
    return dy_f32 * (1.0 - y_f32 * y_f32)

26 

27 

def tanh(self):
    """Return the elementwise hyperbolic tangent of ``self``.

    Dispatches to ``tanh_kernel`` without an explicit output tensor, with the
    kernel's ``inplace`` flag set to ``False``.
    """
    logger.debug("GEMS_CAMBRICON TANH FORWARD")
    return tanh_kernel(self, False)

32 

33 

def tanh_backward(grad_output, output):
    """Return the gradient of tanh w.r.t. its input.

    ``output`` is the result of the forward tanh and ``grad_output`` is the
    incoming gradient. The kernel takes them in ``(y, dy)`` order, so the
    forward output is passed first.
    """
    logger.debug("GEMS_CAMBRICON TANH BACKWARD")
    return tanh_backward_kernel(output, grad_output)

38 

39 

def tanh_(A):
    """In-place variant of ``tanh``: apply tanh to ``A`` and return the result.

    ``A`` is passed both as the input and as ``out0``. Presumably this makes
    ``pointwise_dynamic`` write the result into ``A``'s storage (confirm
    against its implementation). The kernel's ``inplace`` flag is ``True``.
    """
    logger.debug("GEMS_CAMBRICON TANH_ FORWARD")
    return tanh_kernel(A, True, out0=A)