Coverage for src/flag_gems/fused/reglu.py: 53%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2from typing import Any, Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry, libtuner 

10 

11logger = logging.getLogger(__name__) 

12 

13 

# Backward kernel for ReGLU. The forward splits the input's last dimension
# into two halves [a | b] and computes relu(a) * b; this kernel produces the
# gradients w.r.t. both halves and writes them into the matching halves of
# grad_input's last dimension.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("gated_activation"),
    key=["M", "N"],
)
@triton.jit
def dreglu_kernel(
    grad_output_ptr,  # (M, N) upstream gradient
    input_ptr,  # (M, 2*N) forward input, halves [a | b] along dim 1
    grad_input_ptr,  # (M, 2*N) output buffer: [grad_a | grad_b]
    M,  # number of rows
    N,  # columns per half (half of input's last dim)
    stride_grad_out_m,
    stride_grad_out_n,
    stride_in_m,
    stride_in_n,
    stride_grad_in_m,
    stride_grad_in_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # 2D launch grid: one program per (BLOCK_M, BLOCK_N) tile of the N-wide view.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    grad_output_ptr += (
        offs_m[:, None] * stride_grad_out_m + offs_n[None, :] * stride_grad_out_n
    )
    # Half a lives at columns [0, N); half b at [N, 2N) (note the +N offset).
    input_ptr_a = (
        input_ptr + offs_m[:, None] * stride_in_m + offs_n[None, :] * stride_in_n
    )
    input_ptr_b = (
        input_ptr + offs_m[:, None] * stride_in_m + (offs_n[None, :] + N) * stride_in_n
    )
    grad_input_ptr_a = (
        grad_input_ptr
        + offs_m[:, None] * stride_grad_in_m
        + offs_n[None, :] * stride_grad_in_n
    )
    grad_input_ptr_b = (
        grad_input_ptr
        + offs_m[:, None] * stride_grad_in_m
        + (offs_n[None, :] + N) * stride_grad_in_n
    )
    # Mask off out-of-bounds rows/columns for partial edge tiles.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    grad_out = tl.load(grad_output_ptr, mask=mask, other=0.0)
    block_a = tl.load(input_ptr_a, mask=mask, other=0.0)
    block_b = tl.load(input_ptr_b, mask=mask, other=0.0)
    relu_a = tl.maximum(block_a, 0.0)
    # Subgradient of relu at exactly 0 is taken as 0 (strict a > 0 test).
    d_relu_a = tl.where(block_a > 0, 1.0, 0.0)
    # d/da [relu(a) * b] = relu'(a) * b ;  d/db [relu(a) * b] = relu(a)
    grad_a = grad_out * d_relu_a * block_b
    grad_b = grad_out * relu_a
    tl.store(grad_input_ptr_a, grad_a, mask=mask)
    tl.store(grad_input_ptr_b, grad_b, mask=mask)

68 

69 

# Forward kernel for ReGLU: split x's last dimension into halves [a | b]
# and compute y = relu(a) * b.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("gated_activation"),
    key=["M", "N_OUT"],
)
@triton.jit
def reglu_kernel(
    x_ptr,  # (M, 2*N_OUT) input, halves [a | b] along dim 1
    y_ptr,  # (M, N_OUT) output
    M,  # number of rows
    N_OUT,  # output columns (half of x's last dim)
    stride_x_m,
    stride_x_n,
    stride_y_m,
    stride_y_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # 2D launch grid: one program per (BLOCK_M, BLOCK_N) output tile.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Half a is at columns [0, N_OUT); half b at [N_OUT, 2*N_OUT) (+N_OUT offset).
    x_ptr_a = x_ptr + offs_m[:, None] * stride_x_m + offs_n[None, :] * stride_x_n
    x_ptr_b = (
        x_ptr + offs_m[:, None] * stride_x_m + (offs_n[None, :] + N_OUT) * stride_x_n
    )
    y_ptr = y_ptr + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
    # Mask off out-of-bounds elements for partial edge tiles.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N_OUT)
    block_a = tl.load(x_ptr_a, mask=mask, other=0.0)
    block_b = tl.load(x_ptr_b, mask=mask, other=0.0)
    # relu(a) via where; equivalent to tl.maximum(block_a, 0.0).
    gate = tl.where(block_a > 0, block_a, 0.0)
    output = gate * block_b
    tl.store(y_ptr, output, mask=mask)

103 

104 

def reglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    """Apply ReGLU along the last dimension.

    The last dimension is split in half as [a | b] and the result is
    relu(a) * b, so the output's last dimension is half the input's.

    Args:
        input_tensor: Tensor with an even last dimension (2 * N_OUT).
        quantizer: Unused; kept for interface compatibility.

    Returns:
        Tensor of shape (*input_tensor.shape[:-1], input_tensor.shape[-1] // 2)
        with the input's dtype and device.

    Raises:
        ValueError: If the input is 0-dimensional or its last dimension is odd.
    """
    shape = input_tensor.shape
    if input_tensor.dim() < 1:
        raise ValueError("Input tensor must have at least 1 dimension.")
    last_dim = shape[-1]
    if last_dim % 2 != 0:
        raise ValueError(
            f"The last dimension of the input tensor must be even, but got {last_dim}."
        )
    N_OUT = last_dim // 2
    output_shape = (*shape[:-1], N_OUT)
    # Short-circuit empty tensors BEFORE computing M: when last_dim == 0 the
    # old `M = numel() // last_dim` raised ZeroDivisionError, and for any
    # empty tensor there is nothing for the kernel to do.
    if input_tensor.numel() == 0:
        return torch.empty(
            output_shape, device=input_tensor.device, dtype=input_tensor.dtype
        )
    M = input_tensor.numel() // last_dim
    # Kernel assumes a dense 2D layout; flatten all leading dims into M rows.
    input_2d = input_tensor.contiguous().view(M, last_dim)
    output_2d = torch.empty(
        (M, N_OUT), device=input_tensor.device, dtype=input_tensor.dtype
    )
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]),
        triton.cdiv(N_OUT, META["BLOCK_N"]),
    )
    reglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        N_OUT,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
    )
    return output_2d.view(output_shape)

141 

142 

def dreglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    """Backward pass of ReGLU.

    Given the forward input (last dim split as [a | b]) and the upstream
    gradient, returns grad_input of the same shape as the input, with
    grad_a = grad_output * relu'(a) * b and grad_b = grad_output * relu(a).

    Args:
        grad_output: Upstream gradient; its last dimension is half the
            input's, leading dimensions match.
        input_tensor: The forward input tensor.
        quantizer: Unused; kept for interface compatibility.

    Returns:
        Tensor with the same shape, dtype and device as ``input_tensor``.

    Raises:
        ValueError: If either tensor is 0-dimensional or the shapes are
            inconsistent.
    """
    # Consistent with reglu(): reject 0-dim tensors with a clear error
    # instead of an IndexError from shape[-1] below.
    if input_tensor.dim() < 1 or grad_output.dim() < 1:
        raise ValueError("Input tensors must have at least 1 dimension.")
    shape = input_tensor.shape
    if shape[:-1] != grad_output.shape[:-1] or shape[-1] != 2 * grad_output.shape[-1]:
        raise ValueError(
            f"Shape mismatch: input {shape} vs grad_output {grad_output.shape}"
        )
    # Short-circuit empty tensors BEFORE computing M: when the last dim is 0
    # the old `M = numel() // shape[-1]` raised ZeroDivisionError, and for any
    # empty tensor there is nothing for the kernel to do.
    if input_tensor.numel() == 0:
        return torch.empty_like(input_tensor)
    M = grad_output.numel() // grad_output.shape[-1]
    N = grad_output.shape[-1]
    # Kernel assumes dense 2D layouts; flatten leading dims into M rows.
    grad_output_2d = grad_output.contiguous().view(M, N)
    input_2d = input_tensor.contiguous().view(M, 2 * N)
    grad_input = torch.empty_like(input_2d)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]),
        triton.cdiv(N, META["BLOCK_N"]),
    )
    dreglu_kernel[grid](
        grad_output_2d,
        input_2d,
        grad_input,
        M,
        N,
        grad_output_2d.stride(0),
        grad_output_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_input.stride(0),
        grad_input.stride(1),
    )
    return grad_input.view(shape)