Coverage for src/flag_gems/runtime/backend/_metax/ops/tanh.py: 0%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

8 

# Module logger, namespaced under the flag_gems logger hierarchy.
logger = logging.getLogger("flag_gems." + __name__)
# Backend-specific math primitives resolved through the runtime shim layer.
# NOTE(review): `pow` shadows the builtin of the same name at module scope —
# presumably intentional so @triton.jit bodies in this module resolve it;
# confirm nothing in this file needs the builtin `pow`.
pow = tl_extra_shim.pow
_tanh = tl_extra_shim.tanh

12 

13 

@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def tanh_forward(x):
    """Elementwise tanh, computed in float32 for numerical accuracy."""
    x_f32 = x.to(tl.float32)
    return _tanh(x_f32)

18 

19 

@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def tanh_backward(y, dy):
    """Elementwise tanh gradient: dy * (1 - y*y), where y = tanh(x)."""
    one_minus_y2 = 1.0 - y * y
    return dy * one_minus_y2

24 

25 

@triton.jit
def tanh_backward_custom_kernel(
    x_ptr: tl.tensor,  # pointer to the saved forward output (tanh values)
    y_ptr: tl.tensor,  # pointer to the incoming gradient (all strides zero)
    output_ptr: tl.tensor,  # pointer to the result buffer
    n_elements: int,  # total number of elements to process
    BLOCK_SIZE: tl.constexpr,  # elements per program; constexpr so it can size shapes
):
    """Compute y * (1 - x*x) where y is a single broadcast gradient value.

    Launched on a 1D grid; each program instance handles one BLOCK_SIZE
    chunk of the flattened tensor.
    """
    # Identify which chunk this program instance owns.
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Mask guards the final partial block when n_elements is not a
    # multiple of BLOCK_SIZE.
    in_bounds = idx < n_elements
    x = tl.load(x_ptr + idx, mask=in_bounds)

    # The gradient tensor has all-zero strides, so one unmasked load of
    # its single underlying element is sufficient.
    y = tl.load(y_ptr)

    result = y * (1 - x * x)
    tl.store(output_ptr + idx, result, mask=in_bounds)

53 

54 

def tanh_backward_custom(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Fused tanh backward for the special case of a stride-0 gradient.

    Args:
        x: saved forward output, i.e. tanh of the original input (CUDA tensor).
        y: incoming gradient whose strides are all zero, so the kernel reads
           it once and broadcasts the single value.

    Returns:
        A tensor shaped like ``x`` holding ``y * (1 - x * x)``.

    Raises:
        AssertionError: if any tensor involved is not on a CUDA device.
    """
    # Preallocate the output on the same device/dtype as x.
    output = torch.empty_like(x)
    # A bare `assert` is stripped under `python -O`; raise explicitly so the
    # device check always runs (same exception type as the original assert).
    if not (x.is_cuda and y.is_cuda and output.is_cuda):
        raise AssertionError("tanh_backward_custom requires CUDA tensors")

    n_elements = output.numel()

    # 1D launch grid sized from the kernel's compile-time block size.
    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # NOTE(review): the kernel indexes x linearly, so x is presumably
    # expected to be contiguous — confirm at call sites.
    tanh_backward_custom_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

64 

65 

class Tanh(torch.autograd.Function):
    """Autograd wrapper around the Triton tanh kernels.

    The forward pass computes in float32 for accuracy. When gradients are
    required, the float32 output is saved for backward; backward dispatches
    to a fused kernel when the incoming gradient is a stride-0 broadcast.
    """

    @staticmethod
    def forward(ctx, A):
        logger.debug("METAX GEMS TANH FORWARD")
        # Idiomatic truthiness test (was `A.requires_grad is True`).
        if A.requires_grad:
            # Save the float32 result so backward uses full precision,
            # then cast back to the caller's dtype.
            out = tanh_forward(A.to(torch.float32))
            ctx.save_for_backward(out)
            return out.to(A.dtype)
        else:
            out = tanh_forward(A)
            return out

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("METAX GEMS TANH BACKWARD")
        (out,) = ctx.saved_tensors

        # A gradient with all-zero strides is a single broadcast value,
        # which the custom kernel can read with one load.
        is_grad_stride_0 = all(s == 0 for s in out_grad.stride())

        # Temporary plan: take the fast path only when the element count
        # divides evenly into the kernel's 1024-element blocks.
        if is_grad_stride_0 and out_grad.numel() % 1024 == 0:
            return tanh_backward_custom(out, out_grad)

        return tanh_backward(out, out_grad)

96 

97 

def tanh(A):
    """Apply the autograd-aware Triton tanh to tensor ``A``."""
    result = Tanh.apply(A)
    return result