Coverage for src/flag_gems/runtime/backend/_cambricon/ops/gelu.py: 0%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import tl_extra_shim 

7 

8from ..utils.pointwise_dynamic import pointwise_dynamic 

9 

# Module logger, namespaced as a child of the "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)

# Device math helpers re-exported from tl_extra_shim under short local names
# so the Triton kernels in this module can call them directly.
fast_erf, exp, fast_tanh = (
    tl_extra_shim.fast_erf,
    tl_extra_shim.exp,
    tl_extra_shim.fast_tanh,
)

14 

15 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x, inplace):
    # Exact (erf-based) GELU evaluated in float32:
    #   y = 0.5 * x + 0.5 * x * erf(x / sqrt(2))
    # `inplace` is part of the pointwise_dynamic call signature (non-tensor
    # argument) but is not read inside this kernel body.
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    xf = x.to(tl.float32)
    half_x = 0.5 * xf
    return half_x + half_x * fast_erf(xf * inv_sqrt2)

23 

24 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x, inplace):
    # Tanh-approximated GELU evaluated in float32:
    #   y = 0.5 * x + 0.5 * x * tanh(c1 * x + c1 * c2 * x**3)
    # with c1 = sqrt(2 / pi) and c2 = 0.044715.
    # `inplace` is part of the pointwise_dynamic call signature (non-tensor
    # argument) but is not read inside this kernel body.
    xf = x.to(tl.float32)
    c1_x = xf * 0.79788456  # math.sqrt(2 / math.pi) * x
    inner = c1_x + c1_x * 0.044715 * xf * xf
    half_x = 0.5 * xf
    return half_x + half_x * fast_tanh(inner)

33 

34 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    # Gradient of the exact GELU y = x * Phi(x):
    #   dy/dx = x * phi(x) + Phi(x)
    #         = x * exp(-x**2 / 2) / sqrt(2 * pi) + 0.5 * erf(x / sqrt(2)) + 0.5
    # The local derivative is formed in float32, then scaled by the incoming
    # gradient `dy`.
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt_2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    xf = x.to(tl.float32)
    u = inv_sqrt2 * xf  # u * u == x**2 / 2
    pdf_term = inv_sqrt_2pi * xf * exp(-u * u)
    local_grad = pdf_term + 0.5 * fast_erf(u) + 0.5
    return local_grad * dy

45 

46 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    # Gradient of the tanh-approximated GELU
    #   y = 0.5 * x * (1 + tanh(z)),  z = c1 * (x + c2 * x**3)
    # which is
    #   dy/dx = 0.5 * (1 + tanh(z)) + 0.5 * x * (1 - tanh(z)**2) * dz/dx
    #   dz/dx = c1 * (1 + 3 * c2 * x**2)
    # Every term uses the float32 copy of x. This keeps the (1 - tanh^2) term
    # in the same precision as the rest of the expression instead of relying
    # on implicit promotion of the raw-dtype input.
    x_fp32 = x.to(tl.float32)
    c1 = 0.79788456  # math.sqrt(2 / math.pi)
    c2 = 0.044715
    # z = c1 * (x + c2 * x**3), expanded as c1*x + c1*x*c2*x*x
    tanh_out = fast_tanh(c1 * x_fp32 + c1 * x_fp32 * c2 * x_fp32 * x_fp32)
    # dz/dx = c1 * (1 + 3 * c2 * x * x) = c1 + 0.1070322243 * x * x,
    # where 0.1070322243 = c1 * 3 * c2
    dydx = (
        0.5
        * (
            (x_fp32 - x_fp32 * tanh_out * tanh_out)
            * (c1 + 0.1070322243 * x_fp32 * x_fp32)
        )
        + 0.5
        + 0.5 * tanh_out
    )
    dx = dydx * dy
    return dx

64 

65 

def gelu(self, *, approximate="none"):
    """Apply GELU to ``self`` out of place.

    Uses the tanh-approximation kernel when ``approximate == "tanh"``.
    Any other value, including the default ``"none"``, selects the exact
    erf-based kernel.

    Args:
        self: input tensor.
        approximate: ``"tanh"`` or ``"none"`` (keyword-only).

    Returns:
        The result produced by the selected pointwise kernel.
    """
    logger.debug("GEMS_CAMBRICON GELU FORWARD")
    kernel = gelu_tanh if approximate == "tanh" else gelu_none
    # The second argument is the kernel's non-tensor `inplace` flag.
    return kernel(self, False)

73 

74 

def gelu_backward(grad_output, self, *, approximate="none"):
    """Compute the gradient of GELU with respect to its input.

    Uses the tanh-approximation gradient kernel when
    ``approximate == "tanh"``. Any other value, including the default
    ``"none"``, selects the exact erf-based gradient kernel.

    Args:
        grad_output: upstream gradient.
        self: the original forward input.
        approximate: ``"tanh"`` or ``"none"`` (keyword-only).

    Returns:
        The input gradient produced by the selected pointwise kernel.
    """
    logger.debug("GEMS_CAMBRICON GELU BACKWARD")
    if approximate == "tanh":
        backward_kernel = gelu_backward_tanh
    else:
        backward_kernel = gelu_backward_none
    # The kernels take (forward input, upstream gradient) in that order.
    return backward_kernel(self, grad_output)

82 

83 

def gelu_(A, *, approximate="none"):
    """Apply GELU to ``A``, passing ``A`` as the kernel's output buffer.

    Uses the tanh-approximation kernel when ``approximate == "tanh"``.
    Any other value, including the default ``"none"``, selects the exact
    erf-based kernel. ``A`` is passed as ``out0``, which presumably makes the
    pointwise_dynamic wrapper write the result into ``A`` in place (confirm
    against pointwise_dynamic).

    Args:
        A: tensor to transform.
        approximate: ``"tanh"`` or ``"none"`` (keyword-only).

    Returns:
        Whatever the selected kernel returns when called with ``out0=A``.
    """
    logger.debug("GEMS_CAMBRICON GELU_ FORWARD")
    kernel = gelu_tanh if approximate == "tanh" else gelu_none
    return kernel(A, True, out0=A)