Coverage for src/flag_gems/fused/swiglu.py: 53%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2from typing import Any, Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import tl_extra_shim 

9 

# Triton math functions: sigmoid comes straight from triton.language, while
# exp/pow are routed through the flag_gems shim (presumably so vendor
# backends can substitute their own implementations — confirm in tl_extra_shim).
sigmoid = tl.sigmoid
exp = tl_extra_shim.exp
pow = tl_extra_shim.pow  # NOTE: shadows the builtin `pow` within this module

logger = logging.getLogger(__name__)

15 

16 

@triton.jit
def swiglu_kernel(
    input_ptr,
    output_ptr,
    M,
    H,
    stride_in_m,
    stride_in_h,
    stride_out_m,
    stride_out_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    # SwiGLU forward over an (M, 2*H) input whose last dimension packs the
    # two halves [a | b]:
    #     out[m, h] = silu(a[m, h]) * b[m, h],  where silu(x) = x * sigmoid(x)
    # Each program instance computes one BLOCK_SIZE_M x BLOCK_SIZE_H tile.
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)

    # Row/column offsets of this program's tile.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    # Mask off out-of-range rows/columns for edge tiles.
    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

    # First half of the last dimension: the gate input `a`.
    input_a_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h
    )
    # Second half (columns offset by H): the linear input `b`.
    input_b_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h
    )
    output_ptr = (
        output_ptr + offs_m[:, None] * stride_out_m + offs_h[None, :] * stride_out_h
    )

    # Upcast to float32 so the activation is computed at full precision
    # regardless of the storage dtype.
    x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32)
    x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32)

    silu_x_a = x_a * sigmoid(x_a)
    out = silu_x_a * x_b

    # NOTE(review): x_a is float32 here, so `out.to(x_a.dtype)` is a no-op;
    # the store itself converts to the output pointer's element type.
    tl.store(output_ptr, out.to(x_a.dtype), mask=mask)

55 

56 

@triton.jit
def dswiglu_kernel(
    grad_out_ptr,
    input_ptr,
    grad_in_ptr,
    M,
    H,
    stride_grad_out_m,
    stride_grad_out_h,
    stride_in_m,
    stride_in_h,
    stride_grad_in_m,
    stride_grad_in_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    # SwiGLU backward. With out = silu(a) * b and [a | b] packed in the last
    # dimension of the (M, 2*H) input, this writes the packed gradient
    # [grad_a | grad_b] into grad_in:
    #     grad_a = grad_out * b * silu'(a)
    #     grad_b = grad_out * silu(a)
    # Each program instance handles one BLOCK_SIZE_M x BLOCK_SIZE_H tile.
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    # Mask off out-of-range rows/columns for edge tiles.
    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

    grad_out_ptr = (
        grad_out_ptr
        + offs_m[:, None] * stride_grad_out_m
        + offs_h[None, :] * stride_grad_out_h
    )
    # First half of the input's last dimension: `a`.
    input_a_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h
    )
    # Second half (columns offset by H): `b`.
    input_b_ptr = (
        input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h
    )
    # Gradient destinations mirror the [a | b] layout of the input.
    grad_a_ptr = (
        grad_in_ptr
        + offs_m[:, None] * stride_grad_in_m
        + offs_h[None, :] * stride_grad_in_h
    )
    grad_b_ptr = (
        grad_in_ptr
        + offs_m[:, None] * stride_grad_in_m
        + (offs_h[None, :] + H) * stride_grad_in_h
    )

    # Upcast to float32 so the derivative is computed at full precision.
    grad_out = tl.load(grad_out_ptr, mask=mask, other=0.0).to(tl.float32)
    x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32)
    x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32)

    sig = sigmoid(x_a)
    silu = x_a * sig
    # d/dx silu(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
    d_silu = sig + x_a * sig * (1 - sig)

    grad_a = grad_out * x_b * d_silu
    grad_b = grad_out * silu

    # x_a is float32, so the `.to(x_a.dtype)` casts are no-ops; the stores
    # convert to the destination pointer's element type.
    tl.store(grad_a_ptr, grad_a.to(x_a.dtype), mask=mask)
    tl.store(grad_b_ptr, grad_b.to(x_a.dtype), mask=mask)

116 

117 

def swiglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    """SwiGLU forward: ``out = silu(a) * b`` with ``[a, b]`` packed in the last dim.

    The input's last dimension (of size ``2*H``) is split in half: the first
    half is the gate input ``a``, the second half is ``b``.

    Args:
        input_tensor: CUDA tensor whose last dimension is even.
        quantizer: Unused; kept for interface compatibility with callers.

    Returns:
        Tensor with the same leading shape as ``input_tensor`` and last
        dimension ``H``, same dtype and device.

    Raises:
        ValueError: If the last dimension is odd, or the tensor is not on CUDA.
    """
    logger.debug("GEMS SWIGLU")
    if input_tensor.shape[-1] % 2 != 0:
        # Message fixed to name the offending argument, matching dswiglu.
        raise ValueError(
            f"The last dimension of input_tensor must be an even number, "
            f"got {input_tensor.shape[-1]}"
        )
    if not input_tensor.is_cuda:
        raise ValueError("Only CUDA tensor is supported by SwiGLU")

    shape = input_tensor.shape
    H = shape[-1] // 2
    # Collapse all leading dimensions into M rows so the kernel sees a 2D view.
    M = input_tensor.numel() // (2 * H)
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    output_2d = torch.empty(M, H, device=input_tensor.device, dtype=input_tensor.dtype)

    # 2D launch grid: one program per (BLOCK_SIZE_M, BLOCK_SIZE_H) output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    swiglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        H,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )

    # Restore the original leading shape with a halved last dimension.
    return output_2d.view(*shape[:-1], H)

152 

153 

def dswiglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    """SwiGLU backward: gradient w.r.t. the packed ``[a, b]`` forward input.

    Given ``out = silu(a) * b``, computes
        ``grad_a = grad_output * b * silu'(a)``
        ``grad_b = grad_output * silu(a)``
    and returns them packed as ``[grad_a, grad_b]`` along the last dimension.

    Args:
        grad_output: Upstream gradient whose last dimension is ``H``.
        input_tensor: Forward input whose last dimension is ``2*H``.
        quantizer: Unused; kept for interface compatibility with callers.

    Returns:
        Gradient tensor with the same shape as ``input_tensor``.

    Raises:
        ValueError: If the last dimension of ``input_tensor`` is odd, or the
            tensor is not on CUDA.
    """
    logger.debug("GEMS DSWIGLU")
    shape = input_tensor.shape
    # Explicit raise (not `assert`) so validation survives `python -O` and the
    # error type matches the forward pass.
    if shape[-1] % 2 != 0:
        raise ValueError(
            f"The last dimension of input_tensor must be an even number, got {shape[-1]}"
        )
    if not input_tensor.is_cuda:
        # Same device guard as the forward pass, for a clear early failure.
        raise ValueError("Only CUDA tensor is supported by SwiGLU")
    H = shape[-1] // 2
    # Collapse all leading dimensions into M rows so the kernel sees 2D views.
    M = input_tensor.numel() // (2 * H)
    grad_out_2d = grad_output.contiguous().view(M, H)
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    grad_in_2d = torch.empty_like(input_2d)

    # 2D launch grid: one program per (BLOCK_SIZE_M, BLOCK_SIZE_H) tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    dswiglu_kernel[grid](
        grad_out_2d,
        input_2d,
        grad_in_2d,
        M,
        H,
        grad_out_2d.stride(0),
        grad_out_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_in_2d.stride(0),
        grad_in_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )

    return grad_in_2d.view_as(input_tensor)

192 

193 

# Public API of this module.
__all__ = ["swiglu", "dswiglu"]