Coverage for src/flag_gems/fused/swiglu.py: 51%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2from typing import Any, Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import tl_extra_shim 

9 

# Device-side math helpers bound once at import time so the Triton kernels
# below can reference them as plain names inside @triton.jit code.
sigmoid = tl.sigmoid
# NOTE(review): `exp` and `pow` are unused in this file's visible code, and
# `pow` shadows the builtin — presumably re-exported for other backends via
# tl_extra_shim; confirm before removing.
exp = tl_extra_shim.exp
pow = tl_extra_shim.pow

# Module-level logger, one per module as per logging convention.
logger = logging.getLogger(__name__)

15 

16 

@triton.jit
def swiglu_kernel(
    input_ptr,
    output_ptr,
    M,
    H,
    stride_in_m,
    stride_in_h,
    stride_out_m,
    stride_out_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """SwiGLU forward tile: out[m, h] = silu(a[m, h]) * b[m, h].

    Each input row of width 2*H holds the two gate halves concatenated as
    [a | b]; the output row has width H. One program instance computes a
    BLOCK_SIZE_M x BLOCK_SIZE_H tile, with out-of-range elements masked.
    """
    row_block = tl.program_id(0)
    col_block = tl.program_id(1)

    rows = row_block * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = col_block * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    # Guard against partial tiles at the right/bottom edges.
    in_bounds = (rows[:, None] < M) & (cols[None, :] < H)

    # Base address of each row; the `b` half lives H columns past the `a` half.
    row_base = input_ptr + rows[:, None] * stride_in_m
    a_ptrs = row_base + cols[None, :] * stride_in_h
    b_ptrs = row_base + (cols[None, :] + H) * stride_in_h
    out_ptrs = (
        output_ptr + rows[:, None] * stride_out_m + cols[None, :] * stride_out_h
    )

    # Compute in float32 regardless of storage dtype for numerical accuracy.
    a = tl.load(a_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(b_ptrs, mask=in_bounds, other=0.0).to(tl.float32)

    # silu(a) = a * sigmoid(a), gated by b.
    gated = a * sigmoid(a) * b

    tl.store(out_ptrs, gated.to(a.dtype), mask=in_bounds)

55 

56 

@triton.jit
def dswiglu_kernel(
    grad_out_ptr,
    input_ptr,
    grad_in_ptr,
    M,
    H,
    stride_grad_out_m,
    stride_grad_out_h,
    stride_in_m,
    stride_in_h,
    stride_grad_in_m,
    stride_grad_in_h,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """SwiGLU backward tile.

    With forward out = silu(a) * b where the input row is [a | b]:
        grad_a = grad_out * b * silu'(a)
        grad_b = grad_out * silu(a)
    Both gradient halves are written into grad_in laid out as [grad_a | grad_b],
    mirroring the input layout. Out-of-range elements are masked.
    """
    rows = tl.program_id(0) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = tl.program_id(1) * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

    in_bounds = (rows[:, None] < M) & (cols[None, :] < H)

    go_ptrs = (
        grad_out_ptr
        + rows[:, None] * stride_grad_out_m
        + cols[None, :] * stride_grad_out_h
    )
    # Input row base; the `b` half sits H columns past the `a` half.
    in_row = input_ptr + rows[:, None] * stride_in_m
    a_ptrs = in_row + cols[None, :] * stride_in_h
    b_ptrs = in_row + (cols[None, :] + H) * stride_in_h
    # Gradient output mirrors the [a | b] layout of the input.
    gi_row = grad_in_ptr + rows[:, None] * stride_grad_in_m
    ga_ptrs = gi_row + cols[None, :] * stride_grad_in_h
    gb_ptrs = gi_row + (cols[None, :] + H) * stride_grad_in_h

    # Compute in float32 regardless of storage dtype for numerical accuracy.
    go = tl.load(go_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
    a = tl.load(a_ptrs, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(b_ptrs, mask=in_bounds, other=0.0).to(tl.float32)

    s = sigmoid(a)
    silu_a = a * s
    # silu'(a) = sigmoid(a) + a * sigmoid(a) * (1 - sigmoid(a))
    dsilu = s + a * s * (1 - s)

    tl.store(ga_ptrs, (go * b * dsilu).to(a.dtype), mask=in_bounds)
    tl.store(gb_ptrs, (go * silu_a).to(a.dtype), mask=in_bounds)

116 

117 

def swiglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    """Apply the SwiGLU activation along the last dimension.

    The last dimension is split into two halves ``[a | b]``; the result is
    ``silu(a) * b``, so the output's last dimension is half the input's.

    Args:
        input_tensor: CUDA tensor whose last dimension is an even number.
        quantizer: Unused; accepted for interface compatibility.

    Returns:
        Tensor of shape ``(*input_tensor.shape[:-1], input_tensor.shape[-1] // 2)``
        with the same dtype and device as the input.

    Raises:
        ValueError: If the last dimension is odd, or the tensor is not on CUDA.
    """
    if input_tensor.shape[-1] % 2 != 0:
        # Fixed message: original read "The last dimension of must be even
        # number"; now grammatical and consistent with dswiglu's wording.
        raise ValueError(
            "The last dimension of input_tensor must be an even number, "
            f"got {input_tensor.shape[-1]}"
        )
    if not input_tensor.is_cuda:
        raise ValueError("Only CUDA tensor is supported by SwiGLU")

    shape = input_tensor.shape
    H = shape[-1] // 2
    # Collapse all leading dimensions into M rows of width 2*H.
    M = input_tensor.numel() // (2 * H)
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    output_2d = torch.empty(M, H, device=input_tensor.device, dtype=input_tensor.dtype)

    # 2D launch grid: one program per (BLOCK_SIZE_M, BLOCK_SIZE_H) tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    swiglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        H,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )

    # Restore the original leading dimensions with a halved last dimension.
    return output_2d.view(*shape[:-1], H)

151 

152 

def dswiglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    """Backward pass of SwiGLU.

    Given the forward input laid out as ``[a | b]`` along the last dimension
    and the gradient of the loss w.r.t. the forward output, computes the
    gradient w.r.t. the input, laid out as ``[grad_a | grad_b]``.

    Args:
        grad_output: Gradient of the forward output; its last dimension is
            half of ``input_tensor``'s.
        input_tensor: The forward input; last dimension must be even.
        quantizer: Unused; accepted for interface compatibility.

    Returns:
        Tensor with the same shape, dtype and device as ``input_tensor``.

    Raises:
        ValueError: If the last dimension of ``input_tensor`` is odd.
    """
    shape = input_tensor.shape
    if shape[-1] % 2 != 0:
        # Raise ValueError instead of the original `assert`: asserts are
        # stripped under `python -O`, and this matches swiglu's validation.
        raise ValueError(
            f"The last dimension of input_tensor must be an even number, got {shape[-1]}"
        )
    H = shape[-1] // 2
    # Collapse all leading dimensions into M rows.
    M = input_tensor.numel() // (2 * H)
    grad_out_2d = grad_output.contiguous().view(M, H)
    input_2d = input_tensor.contiguous().view(M, 2 * H)
    grad_in_2d = torch.empty_like(input_2d)

    # 2D launch grid: one program per (BLOCK_SIZE_M, BLOCK_SIZE_H) tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(H, META["BLOCK_SIZE_H"]),
    )

    dswiglu_kernel[grid](
        grad_out_2d,
        input_2d,
        grad_in_2d,
        M,
        H,
        grad_out_2d.stride(0),
        grad_out_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_in_2d.stride(0),
        grad_in_2d.stride(1),
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_H=64,
    )

    return grad_in_2d.view_as(input_tensor)

190 

191 

# Explicit public API of this module.
__all__ = ["swiglu", "dswiglu"]