Coverage for src/flag_gems/experimental_ops/gelu_.py: 0%

74 statements

coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import torch # noqa: F401 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def gelu_(
    x_ptr,  # *Pointer* to the input/output tensor (updated in place).
    n_elements,  # Total number of elements to process.
    USE_TANH: tl.constexpr,  # Whether to use the tanh approximation.
    BLOCK_SIZE: tl.constexpr,  # Elements handled per program instance.
):
    # In-place GELU over a flat buffer of n_elements values.
    # All math is performed in float32 and cast back to the input dtype on store.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    x_f32 = x.to(tl.float32)

    # Compute GELU either exact (via erf approximation) or tanh approximation.
    if USE_TANH:
        # tanh approximation:
        # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
        c0 = 0.7978845608028654  # sqrt(2/pi)
        c1 = 0.044715
        x3 = x_f32 * x_f32 * x_f32
        z = c0 * (x_f32 + c1 * x3)
        # Numerically stable tanh: tanh(|z|) = (1 - e^{-2|z|}) / (1 + e^{-2|z|}),
        # then restore the sign. Computing e^{-2z} directly overflows to +inf
        # for large negative z, turning (1 - inf)/(1 + inf) into NaN; using |z|
        # keeps the exponent non-positive so t is always in (0, 1].
        az = tl.abs(z)
        t = tl.exp(-2.0 * az)
        tanh_abs = (1.0 - t) / (1.0 + t)
        tanh_z = tl.where(z >= 0, tanh_abs, -tanh_abs)
        y = 0.5 * x_f32 * (1.0 + tanh_z)
    else:
        # exact (erf-based) GELU:
        # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        inv_sqrt2 = 0.7071067811865476
        z = x_f32 * inv_sqrt2

        # Abramowitz and Stegun formula 7.1.26 for erf approximation
        # erf(x) ≈ sign(x) * (1 - (((((a5*t + a4)*t + a3)*t + a2)*t + a1)*t) * e^{-x^2})
        # where t = 1 / (1 + p*|x|)
        p = 0.3275911
        a1 = 0.254829592
        a2 = -0.284496736
        a3 = 1.421413741
        a4 = -1.453152027
        a5 = 1.061405429

        az = tl.abs(z)
        t = 1.0 / (1.0 + p * az)
        # Horner evaluation of the degree-5 polynomial in t.
        poly = a5
        poly = poly * t + a4
        poly = poly * t + a3
        poly = poly * t + a2
        poly = poly * t + a1
        poly = poly * t
        erf_abs = 1.0 - poly * tl.exp(-az * az)
        erf_z = tl.where(z >= 0, erf_abs, -erf_abs)

        y = 0.5 * x_f32 * (1.0 + erf_z)

    y_cast = y.to(x.dtype)
    tl.store(x_ptr + offsets, y_cast, mask=mask)

64 

65 

# Preserve a handle to the Triton kernel object before the Python wrapper
# defined below rebinds the module-level name `gelu_`; the wrapper launches
# the kernel through this alias.
gelu__kernel = gelu_

68 

69 

def gelu_(*args, **kwargs):
    """Apply GELU in place on a CUDA tensor via the Triton kernel.

    Mirrors ``torch.nn.functional.gelu``'s calling convention: the tensor is
    the first positional argument (or ``input``/``self``/``x`` keyword) and
    the approximation mode is the second positional argument or the
    ``approximate`` keyword (``"none"`` or ``"tanh"``; booleans also accepted).

    Returns:
        The same tensor, modified in place.

    Raises:
        ValueError: if no tensor is supplied or the approximation mode is
            unrecognized.
        AssertionError: if the tensor is not CUDA, not contiguous, or not
            floating point.
    """
    # Resolve input tensor: first positional arg, else common keyword names.
    x = None
    if len(args) >= 1:
        x = args[0]
    else:
        x = kwargs.get("input", None)
        if x is None:
            x = kwargs.get("self", None)
        if x is None:
            x = kwargs.get("x", None)
    if x is None:
        raise ValueError("gelu_ expects a tensor as the first argument.")

    # Determine approximation mode. BUG FIX: a positional second argument
    # (e.g. gelu_(x, "tanh"), matching torch's signature) was previously
    # ignored and silently fell back to the "none" default.
    if len(args) >= 2:
        approx = args[1]
    else:
        approx = kwargs.get("approximate", "none")
    if isinstance(approx, bool):
        use_tanh = bool(approx)
    else:
        approx_str = str(approx).lower()
        if approx_str in ("tanh", "true"):
            use_tanh = True
        elif approx_str in ("none", "false"):
            use_tanh = False
        else:
            raise ValueError(
                f"Unsupported approximate mode: {approx}. Use 'none' or 'tanh'."
            )

    # Validate the tensor; the kernel writes through a flat contiguous view.
    if not x.is_cuda:
        raise AssertionError("Input tensor must be on CUDA device for Triton kernel.")
    if not x.is_contiguous():
        raise AssertionError("Input tensor must be contiguous.")
    if not x.is_floating_point():
        raise AssertionError("gelu_ expects a floating point tensor.")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to do; avoid launching an empty grid.
        return x

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    gelu__kernel[grid](x, n_elements, USE_TANH=use_tanh, BLOCK_SIZE=BLOCK_SIZE)
    return x