Coverage for src/flag_gems/fused/geglu.py: 51%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2from typing import Any, Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import tl_extra_shim 

9 

10erf = tl_extra_shim.erf 

11exp = tl_extra_shim.exp 

12pow = tl_extra_shim.pow 

13tanh = tl_extra_shim.tanh 

14 

15logger = logging.getLogger(__name__) 

16 

17 

18@triton.jit 

19def geglu_kernel( 

20 input_ptr, 

21 output_ptr, 

22 M, 

23 H, 

24 stride_in_m, 

25 stride_in_h, 

26 stride_out_m, 

27 stride_out_h, 

28 BLOCK_SIZE_M: tl.constexpr, 

29 BLOCK_SIZE_H: tl.constexpr, 

30): 

31 pid_m = tl.program_id(0) 

32 pid_h = tl.program_id(1) 

33 

34 offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 

35 offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) 

36 

37 mask = (offs_m[:, None] < M) & (offs_h[None, :] < H) 

38 

39 input_a_ptr = ( 

40 input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h 

41 ) 

42 input_b_ptr = ( 

43 input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h 

44 ) 

45 output_ptr = ( 

46 output_ptr + offs_m[:, None] * stride_out_m + offs_h[None, :] * stride_out_h 

47 ) 

48 

49 x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32) 

50 x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32) 

51 

52 gelu_out = 0.5 * x_a * (1 + tanh(0.79788456 * x_a * (1 + 0.044715 * pow(x_a, 2)))) 

53 out = gelu_out * x_b 

54 

55 tl.store(output_ptr, out.to(tl.float32), mask=mask) 

56 

57 

58@triton.jit 

59def dgeglu_kernel( 

60 grad_out_ptr, 

61 input_ptr, 

62 grad_in_ptr, 

63 M, 

64 H, 

65 stride_grad_out_m, 

66 stride_grad_out_h, 

67 stride_in_m, 

68 stride_in_h, 

69 stride_grad_in_m, 

70 stride_grad_in_h, 

71 BLOCK_SIZE_M: tl.constexpr, 

72 BLOCK_SIZE_H: tl.constexpr, 

73): 

74 pid_m = tl.program_id(0) 

75 pid_h = tl.program_id(1) 

76 

77 offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 

78 offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) 

79 

80 mask = (offs_m[:, None] < M) & (offs_h[None, :] < H) 

81 

82 grad_out_ptr = ( 

83 grad_out_ptr 

84 + offs_m[:, None] * stride_grad_out_m 

85 + offs_h[None, :] * stride_grad_out_h 

86 ) 

87 input_a_ptr = ( 

88 input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h 

89 ) 

90 input_b_ptr = ( 

91 input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h 

92 ) 

93 grad_a_ptr = ( 

94 grad_in_ptr 

95 + offs_m[:, None] * stride_grad_in_m 

96 + offs_h[None, :] * stride_grad_in_h 

97 ) 

98 grad_b_ptr = ( 

99 grad_in_ptr 

100 + offs_m[:, None] * stride_grad_in_m 

101 + (offs_h[None, :] + H) * stride_grad_in_h 

102 ) 

103 

104 grad_out = tl.load(grad_out_ptr, mask=mask, other=0.0).to(tl.float32) 

105 x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32) 

106 x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32) 

107 

108 tanh_out = tanh(0.79788456 * x_a * (1 + 0.044715 * pow(x_a, 2))) 

109 gelu_out = 0.5 * x_a * (1 + tanh_out) 

110 

111 # dgelu/dx 

112 sech2 = 1 - pow(tanh_out, 2) 

113 dgelu = 0.5 * (1 + tanh_out) + 0.5 * x_a * sech2 * 0.79788456 * ( 

114 1 + 3 * 0.044715 * pow(x_a, 2) 

115 ) 

116 

117 grad_a = grad_out * x_b * dgelu 

118 grad_b = grad_out * gelu_out 

119 

120 tl.store(grad_a_ptr, grad_a.to(x_a.dtype), mask=mask) 

121 tl.store(grad_b_ptr, grad_b.to(x_a.dtype), mask=mask) 

122 

123 

124def geglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor: 

125 logger.debug("GEMS GEGLU") 

126 shape = input_tensor.shape 

127 H = shape[-1] // 2 

128 M = input_tensor.numel() // (2 * H) 

129 

130 input_2d = input_tensor.contiguous().view(M, 2 * H) 

131 output_2d = torch.empty(M, H, device=input_tensor.device, dtype=input_tensor.dtype) 

132 

133 grid = lambda META: ( 

134 triton.cdiv(M, META["BLOCK_SIZE_M"]), 

135 triton.cdiv(H, META["BLOCK_SIZE_H"]), 

136 ) 

137 

138 geglu_kernel[grid]( 

139 input_2d, 

140 output_2d, 

141 M, 

142 H, 

143 input_2d.stride(0), 

144 input_2d.stride(1), 

145 output_2d.stride(0), 

146 output_2d.stride(1), 

147 BLOCK_SIZE_M=64, 

148 BLOCK_SIZE_H=64, 

149 ) 

150 # print("geglu") 

151 return output_2d.view(*shape[:-1], H) 

152 

153 

154def dgeglu( 

155 grad_output: torch.Tensor, 

156 input_tensor: torch.Tensor, 

157 quantizer: Optional[Any] = None, 

158) -> torch.Tensor: 

159 logger.debug("GEMS DGEGLU") 

160 shape = input_tensor.shape 

161 H = shape[-1] // 2 

162 M = input_tensor.numel() // (2 * H) 

163 

164 grad_out_2d = grad_output.contiguous().view(M, H) 

165 input_2d = input_tensor.contiguous().view(M, 2 * H) 

166 grad_in_2d = torch.empty_like(input_2d) 

167 

168 grid = lambda META: ( 

169 triton.cdiv(M, META["BLOCK_SIZE_M"]), 

170 triton.cdiv(H, META["BLOCK_SIZE_H"]), 

171 ) 

172 

173 dgeglu_kernel[grid]( 

174 grad_out_2d, 

175 input_2d, 

176 grad_in_2d, 

177 M, 

178 H, 

179 grad_out_2d.stride(0), 

180 grad_out_2d.stride(1), 

181 input_2d.stride(0), 

182 input_2d.stride(1), 

183 grad_in_2d.stride(0), 

184 grad_in_2d.stride(1), 

185 BLOCK_SIZE_M=64, 

186 BLOCK_SIZE_H=64, 

187 ) 

188 # print(dgeglu) 

189 return grad_in_2d.view_as(input_tensor)