Coverage for src/flag_gems/runtime/backend/_metax/ops/tanh.py: 0%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic, tl_extra_shim
9logger = logging.getLogger("flag_gems." + __name__)
10pow = tl_extra_shim.pow
11_tanh = tl_extra_shim.tanh
# Elementwise tanh forward kernel. Integer inputs are promoted to float
# ("INT_TO_FLOAT"); the value is upcast to float32 before applying the
# shimmed tanh for numerical accuracy.
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def tanh_forward(x):
    return _tanh(x.to(tl.float32))
# Elementwise tanh backward kernel. Given the forward output y = tanh(x)
# and the upstream gradient dy, d/dx tanh(x) = 1 - tanh(x)^2, so the
# input gradient is dy * (1 - y * y).
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def tanh_backward(y, dy):
    return dy * (1.0 - y * y)
# Custom tanh backward kernel for the special case where the upstream
# gradient has all-zero strides (every element aliases one value): the
# gradient is loaded once per program instead of once per element.
@triton.jit
def tanh_backward_custom_kernel(
    x_ptr: tl.tensor,  # *Pointer* to the saved forward output (tanh(x)).
    y_ptr: tl.tensor,  # *Pointer* to the stride-0 upstream gradient.
    output_ptr: tl.tensor,  # *Pointer* to the output (input-gradient) vector.
    n_elements: int,  # Number of elements in the vector.
    BLOCK_SIZE: tl.constexpr,  # Elements each program processes.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # 1D launch grid: each program handles one contiguous block of elements.
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask guards loads/stores when n_elements is not a multiple of
    # BLOCK_SIZE.
    mask = offsets < n_elements
    # Load the block of saved forward outputs from DRAM.
    x = tl.load(x_ptr + offsets, mask=mask)
    # No offset or mask needed: the caller only dispatches here when every
    # stride of the gradient is 0, so one scalar load covers all elements.
    y = tl.load(y_ptr)
    # tanh'(x) = 1 - tanh(x)^2; here x already holds tanh(x).
    output = y * (1 - x * x)
    # Write the input gradient back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
def tanh_backward_custom(x: torch.Tensor, y: torch.Tensor):
    """Launch the custom backward kernel over ``x`` with broadcast grad ``y``.

    ``x`` is the saved forward output and ``y`` the stride-0 upstream
    gradient; the result has the same shape and dtype as ``x``.
    """
    # Preallocate the result tensor.
    result = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and result.is_cuda
    total = result.numel()

    # Grid size is however many BLOCK_SIZE-wide programs cover `total`.
    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    tanh_backward_custom_kernel[grid](x, y, result, total, BLOCK_SIZE=1024)
    return result
class Tanh(torch.autograd.Function):
    """Autograd-aware tanh built on the Triton pointwise kernels.

    Forward computes tanh in float32 for accuracy; backward reuses the
    saved forward output since d/dx tanh(x) = 1 - tanh(x)^2.
    """

    @staticmethod
    def forward(ctx, A):
        logger.debug("METAX GEMS TANH FORWARD")
        if A.requires_grad:  # idiomatic truthiness instead of `is True`
            # Compute and save the float32 result so backward is done in
            # full precision, but return in the caller's dtype.
            out = tanh_forward(A.to(torch.float32))
            ctx.save_for_backward(out)
            return out.to(A.dtype)
        else:
            return tanh_forward(A)

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("METAX GEMS TANH BACKWARD")
        (out,) = ctx.saved_tensors

        # A gradient whose strides are all zero is a single broadcast value,
        # so the custom kernel can load it once instead of per element. An
        # empty stride tuple (0-dim tensor) also qualifies, matching the
        # original flag-and-loop check.
        is_grad_stride_0 = all(s == 0 for s in out_grad.stride())

        # Temporary plan: take the fast path only when the element count
        # divides evenly by the kernel's fixed BLOCK_SIZE of 1024.
        if is_grad_stride_0 and out_grad.numel() % 1024 == 0:
            return tanh_backward_custom(out, out_grad)
        return tanh_backward(out, out_grad)
def tanh(A):
    """Public entry point: elementwise tanh of tensor ``A`` via the autograd Function."""
    return Tanh.apply(A)