Coverage for src/flag_gems/fused/geglu.py: 50%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2from typing import Any, Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import tl_extra_shim 

9 

# Math functions taken from tl_extra_shim and bound to short module-level
# names so the @triton.jit kernels in this module can call them directly
# (the kernels below call `tanh` and `pow`).
# NOTE(review): `pow` shadows the Python builtin at module scope, and `erf` /
# `exp` are not referenced in the code visible here - confirm against the
# rest of the package before renaming or removing any of them.
erf = tl_extra_shim.erf
exp = tl_extra_shim.exp
pow = tl_extra_shim.pow
tanh = tl_extra_shim.tanh

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

16 

17 

@triton.jit
def geglu_kernel(
    input_ptr,
    output_ptr,
    M,
    H,
    stride_in_m,
    stride_in_h,
    stride_out_m,
    stride_out_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_H) tile of the GEGLU forward pass.

    For row ``m`` and column ``h < H`` the kernel reads ``a = input[m, h]`` and
    ``b = input[m, h + H]`` and writes ``gelu_tanh(a) * b`` to ``output[m, h]``,
    where ``gelu_tanh`` is the tanh approximation of GELU.

    Args:
        input_ptr: Base pointer of the input, addressed as M rows by 2 * H columns.
        output_ptr: Base pointer of the output, addressed as M rows by H columns.
        M: Number of rows.
        H: Number of output columns (half of the input column count).
        stride_in_m, stride_in_h: Input strides (in elements) for row / column.
        stride_out_m, stride_out_h: Output strides (in elements) for row / column.
        BLOCK_SIZE_M, BLOCK_SIZE_H: Compile-time tile extents.
    """
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    # Guard the ragged edge tiles along both axes.
    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

    # Keep the pointer tiles under their own names instead of rebinding the
    # kernel arguments, so the base pointers stay available below.
    in_row_ptrs = input_ptr + offs_m[:, None] * stride_in_m
    a_ptrs = in_row_ptrs + offs_h[None, :] * stride_in_h
    b_ptrs = in_row_ptrs + (offs_h[None, :] + H) * stride_in_h
    out_ptrs = (
        output_ptr + offs_m[:, None] * stride_out_m + offs_h[None, :] * stride_out_h
    )

    # All arithmetic is carried out in fp32 regardless of the storage dtype.
    x_a = tl.load(a_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_b = tl.load(b_ptrs, mask=mask, other=0.0).to(tl.float32)

    # 0.79788456 ~= sqrt(2 / pi); 0.044715 is the cubic coefficient of the
    # tanh GELU approximation.
    inner = 0.79788456 * x_a * (1 + 0.044715 * pow(x_a, 2))
    gelu_out = 0.5 * x_a * (1 + tanh(inner))
    out = gelu_out * x_b

    # Cast explicitly to the output buffer's element type. The previous
    # `.to(tl.float32)` was a no-op on an fp32 value and hid the real
    # conversion performed at the store.
    tl.store(out_ptrs, out.to(output_ptr.dtype.element_ty), mask=mask)

56 

57 

@triton.jit
def dgeglu_kernel(
    grad_out_ptr,
    input_ptr,
    grad_in_ptr,
    M,
    H,
    stride_grad_out_m,
    stride_grad_out_h,
    stride_in_m,
    stride_in_h,
    stride_grad_in_m,
    stride_grad_in_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_H) tile of the GEGLU backward pass.

    With ``a = input[m, h]``, ``b = input[m, h + H]`` and ``g = grad_out[m, h]``
    the kernel writes

    * ``grad_in[m, h]     = g * b * gelu_tanh'(a)``
    * ``grad_in[m, h + H] = g * gelu_tanh(a)``

    where ``gelu_tanh`` is the tanh approximation of GELU.

    Args:
        grad_out_ptr: Base pointer of the upstream gradient (M rows, H columns).
        input_ptr: Base pointer of the forward input (M rows, 2 * H columns).
        grad_in_ptr: Base pointer of the result gradient (M rows, 2 * H columns).
        M: Number of rows.
        H: Number of columns in each half of the input.
        stride_grad_out_m, stride_grad_out_h: Strides of ``grad_out``.
        stride_in_m, stride_in_h: Strides of the forward input.
        stride_grad_in_m, stride_grad_in_h: Strides of ``grad_in``.
        BLOCK_SIZE_M, BLOCK_SIZE_H: Compile-time tile extents.
    """
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    # Guard the ragged edge tiles along both axes.
    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

    # Pointer tiles get their own names so the base pointers passed as kernel
    # arguments are not rebound.
    g_ptrs = (
        grad_out_ptr
        + offs_m[:, None] * stride_grad_out_m
        + offs_h[None, :] * stride_grad_out_h
    )
    in_row_ptrs = input_ptr + offs_m[:, None] * stride_in_m
    a_ptrs = in_row_ptrs + offs_h[None, :] * stride_in_h
    b_ptrs = in_row_ptrs + (offs_h[None, :] + H) * stride_in_h
    grad_row_ptrs = grad_in_ptr + offs_m[:, None] * stride_grad_in_m
    grad_a_ptrs = grad_row_ptrs + offs_h[None, :] * stride_grad_in_h
    grad_b_ptrs = grad_row_ptrs + (offs_h[None, :] + H) * stride_grad_in_h

    # All arithmetic is carried out in fp32 regardless of the storage dtype.
    grad_out = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_a = tl.load(a_ptrs, mask=mask, other=0.0).to(tl.float32)
    x_b = tl.load(b_ptrs, mask=mask, other=0.0).to(tl.float32)

    # 0.79788456 ~= sqrt(2 / pi); 0.044715 is the cubic coefficient of the
    # tanh GELU approximation.
    tanh_out = tanh(0.79788456 * x_a * (1 + 0.044715 * pow(x_a, 2)))
    gelu_out = 0.5 * x_a * (1 + tanh_out)

    # d gelu / d a via the product and chain rules:
    #   0.5 * (1 + tanh(u)) + 0.5 * a * sech^2(u) * du/da
    # with u = k * (a + c * a^3) and du/da = k * (1 + 3 * c * a^2).
    sech2 = 1 - pow(tanh_out, 2)
    dgelu = 0.5 * (1 + tanh_out) + 0.5 * x_a * sech2 * 0.79788456 * (
        1 + 3 * 0.044715 * pow(x_a, 2)
    )

    grad_a = grad_out * x_b * dgelu
    grad_b = grad_out * gelu_out

    # Cast to the gradient buffer's element type. The previous
    # `.to(x_a.dtype)` was a no-op because x_a had already been promoted to
    # fp32, so it never expressed the intended storage dtype.
    out_ty = grad_in_ptr.dtype.element_ty
    tl.store(grad_a_ptrs, grad_a.to(out_ty), mask=mask)
    tl.store(grad_b_ptrs, grad_b.to(out_ty), mask=mask)

122 

123 

def geglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    """Apply GEGLU along the last dimension of ``input_tensor``.

    The last dimension is split into two halves ``a`` and ``b`` of width ``H``;
    the result is ``gelu_tanh(a) * b`` computed by ``geglu_kernel``.

    Args:
        input_tensor: Tensor whose last dimension has even length ``2 * H``.
        quantizer: Accepted for interface compatibility; not used here.

    Returns:
        A tensor of shape ``(*input_tensor.shape[:-1], H)`` with the same
        dtype and device as ``input_tensor``.

    Raises:
        ValueError: If the last dimension of ``input_tensor`` is odd.
    """
    shape = input_tensor.shape
    last_dim = shape[-1]
    if last_dim % 2 != 0:
        # An odd width cannot be split into two equal halves; previously this
        # either divided by zero or failed late in the final view().
        raise ValueError(
            f"geglu expects an even last dimension, got shape {tuple(shape)}"
        )
    H = last_dim // 2
    out_shape = (*shape[:-1], H)

    if input_tensor.numel() == 0:
        # Nothing to compute: skip the launch (the grid would be empty) and
        # avoid dividing by a zero-width last dimension below.
        return torch.empty(
            out_shape, device=input_tensor.device, dtype=input_tensor.dtype
        )

    M = input_tensor.numel() // last_dim

    # Flatten all leading dimensions so the kernel sees a 2-D (M, 2H) problem.
    input_2d = input_tensor.contiguous().view(M, last_dim)
    output_2d = torch.empty(M, H, device=input_tensor.device, dtype=input_tensor.dtype)

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_H) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(H, meta["BLOCK_SIZE_H"]),
        )

    geglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        H,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )
    return output_2d.view(out_shape)

151 

152 

def dgeglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    """Compute the gradient of GEGLU with respect to its input.

    Args:
        grad_output: Upstream gradient; it is viewed as ``(M, H)`` where
            ``H`` is half of ``input_tensor``'s last dimension and ``M`` is
            the product of the remaining dimensions.
        input_tensor: The tensor that was passed to the forward pass; its
            last dimension has even length ``2 * H``.
        quantizer: Accepted for interface compatibility; not used here.

    Returns:
        A contiguous tensor with the same shape, dtype and device as
        ``input_tensor`` holding the gradients of both halves.

    Raises:
        ValueError: If the last dimension of ``input_tensor`` is odd.
    """
    shape = input_tensor.shape
    last_dim = shape[-1]
    if last_dim % 2 != 0:
        # An odd width cannot be split into two equal halves; previously this
        # either divided by zero or silently mis-indexed the second half.
        raise ValueError(
            f"dgeglu expects an even last dimension, got shape {tuple(shape)}"
        )
    H = last_dim // 2

    if input_tensor.numel() == 0:
        # Nothing to compute: skip the launch (the grid would be empty) and
        # avoid dividing by a zero-width last dimension below.
        return torch.empty_like(input_tensor, memory_format=torch.contiguous_format)

    M = input_tensor.numel() // last_dim

    # Flatten all leading dimensions so the kernel sees 2-D operands.
    grad_out_2d = grad_output.contiguous().view(M, H)
    input_2d = input_tensor.contiguous().view(M, last_dim)
    grad_in_2d = torch.empty_like(input_2d)

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_H) tile of one half.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(H, meta["BLOCK_SIZE_H"]),
        )

    dgeglu_kernel[grid](
        grad_out_2d,
        input_2d,
        grad_in_2d,
        M,
        H,
        grad_out_2d.stride(0),
        grad_out_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_in_2d.stride(0),
        grad_in_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )
    return grad_in_2d.view_as(input_tensor)