# src/flag_gems/runtime/backend/_mthreads/ops/tanh.py

# This custom op requires MUSA device capability >= 31.
# We determine whether to enable it by distinguishing the op registration for
# different archs.
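#
# A minimal sketch of what that gating could look like (hypothetical: the
# capability query and the registration hook below are illustrative, not the
# actual FlagGems API):
#
#     major, minor = torch.musa.get_device_capability()
#     if 10 * major + minor >= 31:
#         register_op("tanh", tanh)  # enable the custom op only on this arch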


import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic, tl_extra_shim

# Backend-specific math intrinsics exposed through the Triton extra shim.
pow = tl_extra_shim.pow
fast_tanh = tl_extra_shim.fast_tanh

logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)



@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def tanh_forward(x):
    # Compute in fp32 so fast_tanh keeps full precision for low-precision inputs.
    return fast_tanh(x.to(tl.float32))
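# pointwise_dynamic wraps this scalar function in a generated elementwise kernel
# plus its launch wrapper; the (0, "INT_TO_FLOAT") promotion rule promotes
# integer inputs to a floating dtype before the computation runs (a summary of
# FlagGems' pointwise utilities, not a verbatim docstring).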



@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def tanh_backward(y, dy):
    return dy * (1.0 - pow(y.to(tl.float32), 2))
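# tanh_backward relies on the identity tanh'(x) = 1 - tanh(x)^2: because the
# saved tensor is y = tanh(x), the input gradient is dy * (1 - y^2) and the
# original input x never needs to be kept.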



class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        logger.debug("GEMS_MTHREADS TANH FORWARD")
        if A.requires_grad:
            # Save the output (not the input): the backward pass only needs y.
            out = tanh_forward(A.to(torch.float32))
            ctx.save_for_backward(out)
            return out.to(A.dtype)
        else:
            out = tanh_forward(A)
            return out

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS_MTHREADS TANH BACKWARD")
        (out,) = ctx.saved_tensors
        in_grad = tanh_backward(out, out_grad)
        return in_grad



def tanh(A):
    return Tanh.apply(A)
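

# A minimal smoke test, assuming a MUSA device is available and registered with
# PyTorch under the "musa" device string (both assumptions; adjust for your
# runtime). Tolerances are loose because fast_tanh may trade a little accuracy
# for speed compared to the eager reference.
if __name__ == "__main__":
    x = torch.randn(1024, device="musa", requires_grad=True)
    y = tanh(x)
    y.sum().backward()

    # Compare results and gradients against PyTorch's eager implementation.
    x_ref = x.detach().clone().requires_grad_(True)
    y_ref = torch.tanh(x_ref)
    y_ref.sum().backward()
    torch.testing.assert_close(y, y_ref, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-4, atol=1e-4)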