Coverage for src/flag_gems/runtime/backend/_hygon/ops/gelu.py: 0%

84 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

8 

# Module-level logger, keyed to this module's dotted path.
logger = logging.getLogger(__name__)

# Elementwise math primitives resolved through the runtime shim layer —
# presumably so each backend can substitute its own implementation; confirm
# against tl_extra_shim. NOTE: `pow` shadows the Python builtin in this module.
erf = tl_extra_shim.erf
exp = tl_extra_shim.exp
pow = tl_extra_shim.pow
tanh = tl_extra_shim.tanh

14 

15 

@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), evaluated in fp32
    # so low-precision inputs (fp16/bf16) don't lose accuracy.
    inp = x.to(tl.float32)
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    return 0.5 * inp * (1 + erf(inp * inv_sqrt2))

23 

24 

@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    # Tanh-approximated GELU:
    #   0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
    # evaluated in fp32 so low-precision inputs don't lose accuracy.
    x_fp32 = x.to(tl.float32)
    # 0.79788456 = math.sqrt(2 / math.pi)
    # Fix: dropped the redundant `x_fp32.to(tl.float32)` inside pow() —
    # x_fp32 is already fp32 from the cast above, so the inner cast was a no-op.
    output = 0.5 * x_fp32 * (
        1 + tanh(x_fp32 * 0.79788456 * (1 + 0.044715 * pow(x_fp32, 2)))
    )
    return output

38 

39 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    # Gradient of exact GELU times the upstream gradient `dy`:
    #   d/dx [0.5*x*(1+erf(x/sqrt(2)))]
    #     = x * pdf(x) + 0.5 * erf(x/sqrt(2)) + 0.5
    # where pdf(x) = exp(-x^2/2) / sqrt(2*pi). Evaluated in fp32.
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    x_f = x.to(tl.float32)
    grad_local = (
        inv_sqrt2pi * x_f * exp(-pow(inv_sqrt2 * x_f, 2))
        + 0.5 * erf(inv_sqrt2 * x_f)
        + 0.5
    )
    return grad_local * dy

53 

54 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    # Gradient of tanh-approximated GELU times the upstream gradient `dy`.
    x_f = x.to(tl.float32)
    # 0.79788456 = math.sqrt(2 / math.pi)
    t = tanh(0.79788456 * x_f * (1 + 0.044715 * pow(x_f, 2)))
    # d/du tanh(u) = 1 - tanh(u)^2
    sech_sq = 1 - pow(t, 2)
    # 0.1070322243 = 3 * 0.044715 * math.sqrt(2 / math.pi)
    grad_local = 0.5 * x_f * (
        sech_sq * (0.79788456 + 0.1070322243 * pow(x_f, 2))
    ) + 0.5 * (1 + t)
    return grad_local * dy

66 

67 

class Gelu(torch.autograd.Function):
    """Autograd function dispatching between exact and tanh-approximated GELU."""

    @staticmethod
    def forward(ctx, A, approximate):
        logger.debug("GEMS GELU FORWARD")
        # Select the forward kernel by the `approximate` mode string.
        kernel = gelu_tanh if approximate == "tanh" else gelu_none
        result = kernel(A)
        # The input is needed to evaluate the analytic gradient in backward.
        ctx.save_for_backward(A)
        ctx.approximate = approximate
        return result

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS GELU BACKWARD")
        (inp,) = ctx.saved_tensors
        # Backward kernel must match the mode used in forward.
        kernel = (
            gelu_backward_tanh
            if ctx.approximate == "tanh"
            else gelu_backward_none
        )
        # Second return is None: `approximate` is a non-tensor argument.
        return kernel(inp, out_grad), None

90 

91 

def gelu(A, *, approximate="none"):
    """Apply GELU to `A`; `approximate` is "none" (exact) or "tanh"."""
    return Gelu.apply(A, approximate)

94 

95 

class InplaceGelu(torch.autograd.Function):
    """In-place GELU: overwrites the input tensor with its activation."""

    @staticmethod
    def forward(ctx, A, approximate):
        logger.debug("GEMS GELU_ FORWARD")
        # Clone before the in-place write so backward still sees the
        # original input values.
        ctx.save_for_backward(A.clone())
        # Tell autograd the input tensor is modified in place.
        ctx.mark_dirty(A)
        ctx.approximate = approximate

        # Dispatch by mode and write the result into A itself (out0=A).
        kernel = gelu_tanh if approximate == "tanh" else gelu_none
        return kernel(A, out0=A)

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("GEMS GELU_ BACKWARD")
        (saved_input,) = ctx.saved_tensors
        # Backward kernel must match the mode used in forward.
        kernel = (
            gelu_backward_tanh
            if ctx.approximate == "tanh"
            else gelu_backward_none
        )
        # Second return is None: `approximate` is a non-tensor argument.
        return kernel(saved_input, out_grad), None

120 

121 

def gelu_(A, *, approximate="none"):
    """In-place GELU: mutates `A` and returns it, matching torch's `gelu_`."""
    InplaceGelu.apply(A, approximate)
    return A