Coverage for src/flag_gems/runtime/backend/_cambricon/fused/gelu_and_mul.py: 0%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import tl_extra_shim 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

11logger = logging.getLogger(__name__) 

12fast_erf = tl_extra_shim.fast_erf 

13fast_tanh = tl_extra_shim.fast_tanh 

14 

15 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_none_and_mul_kernel(x, y):
    """Compute gelu(x) * y using the exact (erf-based) GELU formulation."""
    # Promote to fp32 so the erf evaluation keeps full precision.
    inp = x.to(tl.float32)
    # gelu(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    gated = 0.5 * inp * (1 + fast_erf(inp * 0.7071067811))
    return gated * y

22 

23 

@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_none_and_mul_grad_kernel(x, y, dgrad):
    """Backward of gelu(x, approximate="none") * y.

    Returns (dx, dy) where
        dx = dgrad * y * d/dx[gelu(x)]
        dy = dgrad * gelu(x)
    with gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) and
    d/dx gelu(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2*pi).
    """
    RCP_SQRT_2: tl.constexpr = 0.7071067811  # 1 / sqrt(2)
    # sqrt(2 / pi); the surrounding 0.5 factor turns this into 1 / sqrt(2*pi)
    COEFF: tl.constexpr = 0.7978845608028654

    x_fp32 = x.to(tl.float32)
    # Hoist erf(x / sqrt(2)): the original computed it twice, once for the
    # forward value and once inside the derivative. Reusing it is free and
    # bit-identical.
    erf_term = fast_erf(x_fp32 * RCP_SQRT_2)
    x_gelu = 0.5 * x_fp32 * (1 + erf_term)

    d_gelu = dgrad * y
    dx = (
        d_gelu
        * 0.5
        * (
            1.0
            + erf_term
            + x_fp32 * COEFF * tl.exp(-0.5 * x_fp32 * x_fp32)
        )
    )

    dy = dgrad * x_gelu

    return dx, dy

49 

50 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_tanh_and_mul_kernel(x, y):
    """Compute gelu(x, approximate="tanh") * y."""
    # fp32 accumulation for the transcendental evaluation
    inp = x.to(tl.float32)
    # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
    inner = inp * 0.79788456 * (1 + 0.044715 * inp * inp)
    gelu_val = 0.5 * inp * (1 + fast_tanh(inner))
    return gelu_val * y

61 

62 

@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_tanh_and_mul_grad_kernel(x, y, dgrad):
    """Backward of gelu(x, approximate="tanh") * y.

    Returns (dx, dy) with dy = gelu(x) * dgrad and dx the chain-rule gradient
    of the tanh-approximated GELU times y.
    """
    a = x.to(tl.float32)
    b = y.to(tl.float32)

    # Forward value: gelu(a) = 0.5 * a * (1 + tanh(k * (a + 0.044715 * a^3)))
    k = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a * a * a
    t = fast_tanh(k * (a + 0.044715 * a_cubed))
    gelu_a = 0.5 * a * (1 + t)
    dy = gelu_a * dgrad

    # dgelu/da = 0.5*(1 + t) + 0.5*a*(1 - t^2)*k*(1 + 3*0.044715*a^2)
    outer = 0.5 * (1 + t)
    sech_sq = 1 - t * t  # d tanh(u)/du = 1 - tanh(u)^2
    inner = (
        0.5
        * a
        * sech_sq
        * (k * (1 + 3 * 0.044715 * a * a))
    )
    dx = dgrad * b * (outer + inner)

    return dx, dy

89 

90 

class GeluAndMul(torch.autograd.Function):
    """Autograd wrapper fusing gelu(x) with an elementwise multiply by y."""

    @staticmethod
    def forward(ctx, x, y, approximate="none"):
        logger.debug("GEMS_CAMBRICON GELU AND MUL FORWARD")
        ctx.save_for_backward(x, y)
        ctx.approximate = approximate
        # Dispatch on the GELU variant; reject anything else up front.
        forward_kernels = {
            "none": gelu_none_and_mul_kernel,
            "tanh": gelu_tanh_and_mul_kernel,
        }
        if approximate not in forward_kernels:
            raise ValueError(f"Invalid approximate value: {approximate}")
        return forward_kernels[approximate](x, y)

    @staticmethod
    def backward(ctx, dgrad):
        logger.debug("GEMS_CAMBRICON GELU AND MUL BACKWARD")
        x, y = ctx.saved_tensors
        # forward() already validated ctx.approximate, so anything that is
        # not "none" must be "tanh".
        grad_kernel = (
            gelu_none_and_mul_grad_kernel
            if ctx.approximate == "none"
            else gelu_tanh_and_mul_grad_kernel
        )
        dx, dy = grad_kernel(x, y, dgrad)
        # Third slot is None: `approximate` is a non-tensor argument.
        return dx, dy, None

113 

114 

def gelu_and_mul(A, B, approximate="none"):
    """Functional entry point: return gelu(A, approximate) * B with autograd."""
    return GeluAndMul.apply(A, B, approximate)