Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/gelu_and_mul.py: 0%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import tl_extra_shim 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12erf = tl_extra_shim.erf 

13pow = tl_extra_shim.pow 

14tanh = tl_extra_shim.tanh 

15 

16 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_none_and_mul_kernel(x, y):
    """Fused exact ("none") GELU-and-mul: returns gelu(x) * y.

    x is promoted to fp32 before the erf evaluation; the output dtype is
    chosen by pointwise_dynamic's DEFAULT promotion over (x, y).
    """
    # 1/sqrt(2) at full precision; the original truncated literal
    # 0.7071067811 rounds to the same fp32 value, so results are unchanged.
    RCP_SQRT_2: tl.constexpr = 0.7071067811865476
    x_fp32 = x.to(tl.float32)
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))
    return x_gelu * y

23 

24 

@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_none_and_mul_grad_kernel(x, y, dgrad):
    """Backward of gelu(x, approximate="none") * y.

    Given the upstream gradient ``dgrad``, returns the pair
        dx = dgrad * y * gelu'(x)
        dy = dgrad * gelu(x)
    using the exact (erf-based) GELU derivative
        gelu'(x) = 0.5 * (1 + erf(x/sqrt(2))) + x * exp(-x^2/2) / sqrt(2*pi).
    """
    RCP_SQRT_2: tl.constexpr = 0.7071067811  # 1/sqrt(2), truncated literal
    COEFF: tl.constexpr = 0.7978845608028654  # sqrt(2/pi)

    x_fp32 = x.to(tl.float32)
    # Recompute the exact-GELU forward value; needed for the dy term.
    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))

    d_gelu = dgrad * y
    # The common 0.5 is factored out, so the Gaussian-pdf term appears as
    # x * COEFF * exp(-x^2/2): note 0.5 * COEFF == 1/sqrt(2*pi).
    dx = (
        d_gelu
        * 0.5
        * (
            1.0
            + erf(x_fp32 * RCP_SQRT_2)
            + x_fp32 * COEFF * tl.exp(-0.5 * x_fp32 * x_fp32)
        )
    )

    dy = dgrad * x_gelu

    return dx, dy

50 

51 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_tanh_and_mul_kernel(x, y):
    """Fused tanh-approximated GELU-and-mul: returns gelu_tanh(x) * y.

    gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
    evaluated here in the equivalent factored form
    0.5 * x * (1 + tanh(x * sqrt(2/pi) * (1 + 0.044715 * x^2))).
    """
    # sqrt(2/pi) at full precision, matching gelu_tanh_and_mul_grad_kernel;
    # the original 0.79788456 literal rounds to the same fp32 value.
    SQRT_2_OVER_PI: tl.constexpr = 0.7978845608028654
    CUBIC_COEFF: tl.constexpr = 0.044715
    x_fp32 = x.to(tl.float32)
    # NOTE: the original applied a redundant `.to(tl.float32)` inside pow();
    # x_fp32 is already fp32, so that no-op cast has been dropped.
    x_gelu = (
        0.5
        * x_fp32
        * (
            1
            + tanh(
                x_fp32 * SQRT_2_OVER_PI * (1 + CUBIC_COEFF * pow(x_fp32, 2.0))
            )
        )
    )
    return x_gelu * y

67 

68 

@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_tanh_and_mul_grad_kernel(x, y, dgrad):
    """Backward of gelu(x, approximate="tanh") * y.

    Given the upstream gradient ``dgrad``, returns the pair
        dx = dgrad * y * d/dx gelu_tanh(x)
        dy = dgrad * gelu_tanh(x)
    where gelu_tanh(x) = 0.5 * x * (1 + tanh(t)) with
    t = sqrt(2/pi) * (x + 0.044715 * x^3), and
        d/dx gelu_tanh(x) = 0.5*(1 + tanh(t))
                          + 0.5*x*(1 - tanh(t)^2) * sqrt(2/pi)*(1 + 3*0.044715*x^2).
    """
    x_fp32 = x.to(tl.float32)
    y_fp32 = y.to(tl.float32)

    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = x_fp32 * x_fp32 * x_fp32
    tanh_arg = sqrt_2_over_pi * (x_fp32 + 0.044715 * a_cubed)
    tanh_result = tanh(tanh_arg)
    # Forward value, reused for the dy term.
    geglu_a = 0.5 * x_fp32 * (1 + tanh_result)
    dy = geglu_a * dgrad

    # Product-rule split: term1 from d/dx of the leading x,
    # term2 from d/dx of tanh(t) via sech^2(t) = 1 - tanh(t)^2.
    term1 = 0.5 * (1 + tanh_result)
    tanh_sq = tanh_result * tanh_result
    term2 = (
        0.5
        * x_fp32
        * (1 - tanh_sq)
        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_fp32 * x_fp32))
    )
    dx = dgrad * y_fp32 * (term1 + term2)

    return dx, dy

95 

96 

class GeluAndMul(torch.autograd.Function):
    """Autograd wrapper for the fused gelu(x, approximate) * y kernels."""

    @staticmethod
    def forward(ctx, x, y, approximate="none"):
        """Return gelu(x) * y, stashing inputs and the approximation mode."""
        logger.debug("GEMS GELU AND MUL FORWARD")
        ctx.save_for_backward(x, y)
        ctx.approximate = approximate
        # Dispatch table keeps the two supported modes in one place.
        forward_kernels = {
            "none": gelu_none_and_mul_kernel,
            "tanh": gelu_tanh_and_mul_kernel,
        }
        kernel = forward_kernels.get(approximate)
        if kernel is None:
            raise ValueError(f"Invalid approximate value: {approximate}")
        return kernel(x, y)

    @staticmethod
    def backward(ctx, dgrad):
        """Return (dx, dy, None); the trailing None matches `approximate`."""
        logger.debug("GEMS GELU AND MUL BACKWARD")
        x, y = ctx.saved_tensors
        grad_kernel = (
            gelu_none_and_mul_grad_kernel
            if ctx.approximate == "none"
            else gelu_tanh_and_mul_grad_kernel
        )
        dx, dy = grad_kernel(x, y, dgrad)
        return dx, dy, None

119 

120 

def gelu_and_mul(x, y, approximate="none"):
    """Functional entry point for fused gelu(x, approximate) * y.

    `approximate` is "none" (exact, erf-based) or "tanh"; any other value
    raises ValueError inside GeluAndMul.forward.
    """
    return GeluAndMul.apply(x, y, approximate)