Coverage for src/flag_gems/runtime/backend/_mthreads/ops/tanh.py: 0% (34 statements)

# This custom op requires MUSA device capability >= 31.
# Whether the op is enabled is determined by distinguishing the op
# registration for different architectures.
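#
# (Illustrative sketch only, not code from this module: the names
# `device_capability` and `register_op` below are hypothetical, standing in
# for whatever the _mthreads backend actually uses to gate registration.)
#
#     if device_capability >= 31:
#         register_op("tanh", tanh)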

import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic, tl_extra_shim

pow = tl_extra_shim.pow
fast_tanh = tl_extra_shim.fast_tanh

logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)


@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def tanh_forward(x):
    # Upcast to float32 before applying the device fast-tanh intrinsic.
    return fast_tanh(x.to(tl.float32))


@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def tanh_backward(y, dy):
    return dy * (1.0 - pow(y.to(tl.float32), 2))
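
# Note: tanh_backward relies on the identity d/dx tanh(x) = 1 - tanh(x)^2,
# so the gradient is computed from the saved forward output y = tanh(x)
# alone; the original input is not needed.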


class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        logger.debug("GEMS_MTHREADS TANH FORWARD")
        if A.requires_grad:
            # Compute and save the float32 output so backward can reuse it
            # at full precision, then cast back to the input dtype.
            out = tanh_forward(A.to(torch.float32))
            ctx.save_for_backward(out)
            return out.to(A.dtype)
        return tanh_forward(A)

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS_MTHREADS TANH BACKWARD")
        (out,) = ctx.saved_tensors
        in_grad = tanh_backward(out, out_grad)
        return in_grad


def tanh(A):
    return Tanh.apply(A)
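
# Usage sketch (assumptions: flag_gems is installed and a MUSA device is
# available; the "musa" device string comes from the torch_musa extension):
#
#     x = torch.randn(1024, device="musa", requires_grad=True)
#     y = tanh(x)               # forward pass via the Triton kernel
#     y.sum().backward()        # gradient flows through tanh_backward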