Coverage for src/flag_gems/fused/gelu_and_mul.py: 50%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

8 

# Backend-specific math shims resolved by flag_gems; bound at module level so
# the @triton.jit kernels below can reference them at compile time.
erf = tl_extra_shim.erf
pow = tl_extra_shim.pow  # NOTE: intentionally shadows the builtin `pow` for kernel use
tanh = tl_extra_shim.tanh
logger = logging.getLogger(__name__)

13 

14 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_none_and_mul_kernel(x, y):
    """Fused gelu(x) * y using the exact (erf-based) GELU formulation."""
    RCP_SQRT_2: tl.constexpr = 0.7071067811  # 1 / sqrt(2)
    x_f = x.to(tl.float32)
    # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), computed in fp32 for accuracy.
    gelu_x = 0.5 * x_f * (1 + erf(x_f * RCP_SQRT_2))
    return gelu_x * y

22 

23 

@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_none_and_mul_grad_kernel(x, y, dgrad):
    # Backward of out = gelu(x) * y with the exact (erf-based) GELU.
    # Returns (dx, dy):
    #   dy = dgrad * gelu(x)
    #   dx = dgrad * y * gelu'(x)
    # where
    #   gelu'(x) = 0.5*(1 + erf(x/sqrt(2))) + x * exp(-x^2/2) / sqrt(2*pi)
    #            = 0.5*(1 + erf(x/sqrt(2)) + x * sqrt(2/pi) * exp(-x^2/2))
    # (2/sqrt(2*pi) == sqrt(2/pi), hence COEFF below.)
    RCP_SQRT_2: tl.constexpr = 0.7071067811  # 1 / sqrt(2)
    COEFF: tl.constexpr = 0.7978845608028654  # sqrt(2 / pi)

    x_fp32 = x.to(tl.float32)
    # Recompute the forward activation; needed for dy.
    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))

    # NOTE(review): y and dgrad are used without an explicit fp32 cast here
    # (unlike the tanh grad kernel, which casts y) — triton type promotion
    # applies; confirm this precision difference is intended.
    d_gelu = dgrad * y
    dx = (
        d_gelu
        * 0.5
        * (
            1.0
            + erf(x_fp32 * RCP_SQRT_2)
            + x_fp32 * COEFF * tl.exp(-0.5 * x_fp32 * x_fp32)
        )
    )

    dy = dgrad * x_gelu

    return dx, dy

49 

50 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_tanh_and_mul_kernel(x, y):
    """Fused gelu(x) * y using the tanh approximation of GELU.

    gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
    """
    x_fp32 = x.to(tl.float32)
    # 0.79788456 ~= sqrt(2 / pi). A redundant `.to(tl.float32)` on the
    # already-fp32 operand inside pow() was removed (it was a no-op cast).
    x_gelu = (
        0.5
        * x_fp32
        * (
            1
            + tanh(x_fp32 * 0.79788456 * (1 + 0.044715 * pow(x_fp32, 2)))
        )
    )
    return x_gelu * y

64 

65 

@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_tanh_and_mul_grad_kernel(x, y, dgrad):
    # Backward of out = gelu_tanh(x) * y, where
    #   gelu_tanh(x) = 0.5 * x * (1 + tanh(u)),  u = sqrt(2/pi) * (x + 0.044715*x^3)
    # Returns (dx, dy):
    #   dy = dgrad * gelu_tanh(x)
    #   dx = dgrad * y * d/dx gelu_tanh(x)
    x_fp32 = x.to(tl.float32)
    y_fp32 = y.to(tl.float32)

    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = x_fp32 * x_fp32 * x_fp32
    tanh_arg = sqrt_2_over_pi * (x_fp32 + 0.044715 * a_cubed)
    tanh_result = tanh(tanh_arg)
    # Recompute the forward activation; needed for dy.
    geglu_a = 0.5 * x_fp32 * (1 + tanh_result)
    dy = geglu_a * dgrad

    # Product rule: d/dx [0.5*x*(1+tanh(u))] = term1 + term2, with
    # tanh'(u) = 1 - tanh(u)^2 and du/dx = sqrt(2/pi)*(1 + 3*0.044715*x^2).
    term1 = 0.5 * (1 + tanh_result)
    tanh_sq = tanh_result * tanh_result
    term2 = (
        0.5
        * x_fp32
        * (1 - tanh_sq)
        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_fp32 * x_fp32))
    )
    dx = dgrad * y_fp32 * (term1 + term2)

    return dx, dy

92 

93 

class GeluAndMul(torch.autograd.Function):
    """Autograd wrapper for the fused gelu(x) * y forward/backward kernels."""

    @staticmethod
    def forward(ctx, x, y, approximate="none"):
        """Return gelu(x, approximate) * y; saves x, y for backward."""
        logger.debug("GEMS GELU AND MUL FORWARD")
        ctx.save_for_backward(x, y)
        ctx.approximate = approximate
        if approximate == "none":
            return gelu_none_and_mul_kernel(x, y)
        if approximate == "tanh":
            return gelu_tanh_and_mul_kernel(x, y)
        raise ValueError(f"Invalid approximate value: {approximate}")

    @staticmethod
    def backward(ctx, dgrad):
        """Return (dx, dy, None); the None matches the `approximate` arg."""
        logger.debug("GEMS GELU AND MUL BACKWARD")
        x, y = ctx.saved_tensors
        grad_kernel = (
            gelu_none_and_mul_grad_kernel
            if ctx.approximate == "none"
            else gelu_tanh_and_mul_grad_kernel
        )
        dx, dy = grad_kernel(x, y, dgrad)
        return dx, dy, None

116 

117 

def gelu_and_mul(x, y, approximate="none"):
    """Compute gelu(x, approximate=approximate) * y via the fused autograd op."""
    result = GeluAndMul.apply(x, y, approximate)
    return result