Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/reglu.py: 0%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import logging 

2from typing import Any, Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry, libtuner 

10 

# Module logger, created as a child of the package-level "flag_gems" logger
# so it inherits the package's handlers/level configuration.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

def heur_tile_m(args):
    """Heuristic BLOCK_M: split the M rows evenly across the 12 clusters.

    Equivalent to ``triton.cdiv(args["M"], 12)`` — ceiling division written
    in plain integer arithmetic.
    """
    cluster_num = 12
    return (args["M"] + cluster_num - 1) // cluster_num

16 

17 

def heru_tile_n(args):
    """Heuristic BLOCK_N: cover the whole row, capped at 8192 columns.

    NOTE: the name keeps the original "heru" spelling because the
    (currently commented-out) @triton.heuristics hook refers to it by
    this exact name.
    """
    # builtins.min is used explicitly so the lookup cannot be shadowed
    # in the heuristics-evaluation namespace.
    import builtins

    row_width = args["N"]
    cap = 8192
    return builtins.min(row_width, cap)

22 

23 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_M": 1, "BLOCK_N": 1024}),
        triton.Config({"BLOCK_M": 2, "BLOCK_N": 1024}),
        triton.Config({"BLOCK_M": 4, "BLOCK_N": 1024}),
        triton.Config({"BLOCK_M": 8, "BLOCK_N": 1024}),
        triton.Config({"BLOCK_M": 6, "BLOCK_N": 32}),
        triton.Config({"BLOCK_M": 342, "BLOCK_N": 2048}),
        triton.Config({"BLOCK_M": 2731, "BLOCK_N": 256}),
    ],
    key=["M", "N"],
)
# @triton.heuristics(
#     values={
#         "BLOCK_M": heur_tile_m,
#         "BLOCK_N": heru_tile_n,
#     },
# )
@triton.jit
def dreglu_kernel(
    grad_output_ptr,  # upstream gradient, logical shape (M, N)
    input_ptr,  # forward input, logical shape (M, 2*N): halves [a | b] along N
    grad_input_ptr,  # gradient w.r.t. input, logical shape (M, 2*N)
    M,  # number of rows
    N,  # columns in grad_output; half the columns of input/grad_input
    stride_grad_out_m,
    stride_grad_out_n,
    stride_in_m,
    stride_in_n,
    stride_grad_in_m,
    stride_grad_in_n,
    BLOCK_M: tl.constexpr,  # tile height (rows per program)
    BLOCK_N: tl.constexpr,  # tile width (columns per program)
):
    # ReGLU backward.  Forward is y = relu(a) * b with x = [a | b] split
    # along the last axis, so:
    #   dL/da = dL/dy * relu'(a) * b
    #   dL/db = dL/dy * relu(a)
    # Each program instance handles one (BLOCK_M, BLOCK_N) tile of the
    # N-wide half; it writes BOTH halves of grad_input for that tile.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    grad_output_ptr += (
        offs_m[:, None] * stride_grad_out_m + offs_n[None, :] * stride_grad_out_n
    )
    # First half of each input row: the gate `a`.
    input_ptr_a = (
        input_ptr + offs_m[:, None] * stride_in_m + offs_n[None, :] * stride_in_n
    )
    # Second half of each input row: the value `b`, offset by N columns.
    input_ptr_b = (
        input_ptr + offs_m[:, None] * stride_in_m + (offs_n[None, :] + N) * stride_in_n
    )
    grad_input_ptr_a = (
        grad_input_ptr
        + offs_m[:, None] * stride_grad_in_m
        + offs_n[None, :] * stride_grad_in_n
    )
    grad_input_ptr_b = (
        grad_input_ptr
        + offs_m[:, None] * stride_grad_in_m
        + (offs_n[None, :] + N) * stride_grad_in_n
    )
    # Single mask bounds both halves: column index is < N for `a` and
    # < N (+N offset) for `b`.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Compute in float32 for accuracy regardless of the storage dtype.
    grad_out = tl.load(grad_output_ptr, mask=mask, other=0.0).to(tl.float32)
    block_a = tl.load(input_ptr_a, mask=mask, other=0.0).to(tl.float32)
    block_b = tl.load(input_ptr_b, mask=mask, other=0.0).to(tl.float32)
    relu_a = tl.maximum(block_a, 0.0)
    # relu'(a): 1 where a > 0, else 0 (subgradient 0 chosen at a == 0).
    d_relu_a = tl.where(block_a > 0, 1.0, 0.0)
    grad_a = grad_out * d_relu_a * block_b
    grad_b = grad_out * relu_a
    tl.store(grad_input_ptr_a, grad_a, mask=mask)
    tl.store(grad_input_ptr_b, grad_b, mask=mask)

92 

93 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("gated_activation"),
    key=["M", "N_OUT"],
)
@triton.jit
def reglu_kernel(
    x_ptr,  # input, logical shape (M, 2*N_OUT): halves [a | b] along N
    y_ptr,  # output, logical shape (M, N_OUT)
    M,  # number of rows
    N_OUT,  # output columns; half the input's last dimension
    stride_x_m,
    stride_x_n,
    stride_y_m,
    stride_y_n,
    BLOCK_M: tl.constexpr,  # tile height (rows per program)
    BLOCK_N: tl.constexpr,  # tile width (columns per program)
):
    # ReGLU forward: y = relu(a) * b, where each input row is split into
    # a = x[:, :N_OUT] (gate) and b = x[:, N_OUT:] (value).
    # Each program instance computes one (BLOCK_M, BLOCK_N) output tile.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Pointers into the gate half `a` and, offset by N_OUT columns, the
    # value half `b`.
    x_ptr_a = x_ptr + offs_m[:, None] * stride_x_m + offs_n[None, :] * stride_x_n
    x_ptr_b = (
        x_ptr + offs_m[:, None] * stride_x_m + (offs_n[None, :] + N_OUT) * stride_x_n
    )
    y_ptr = y_ptr + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N_OUT)
    block_a = tl.load(x_ptr_a, mask=mask, other=0.0)
    block_b = tl.load(x_ptr_b, mask=mask, other=0.0)
    # relu(a), computed in the input dtype (no float32 upcast here).
    gate = tl.where(block_a > 0, block_a, 0.0)
    output = gate * block_b
    tl.store(y_ptr, output, mask=mask)

127 

128 

def reglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    """Apply the ReGLU activation: ``relu(a) * b``.

    The last dimension of *input_tensor* is split in half into a gate
    ``a`` (first half) and a value ``b`` (second half).

    Args:
        input_tensor: Tensor of shape ``(..., 2 * N_OUT)``; must have at
            least one dimension and an even last dimension.
        quantizer: Unused; kept for interface compatibility.

    Returns:
        Tensor of shape ``(..., N_OUT)`` with the input's dtype and device.

    Raises:
        ValueError: If the input is 0-dimensional or its last dimension is odd.
    """
    shape = input_tensor.shape
    if input_tensor.dim() < 1:
        raise ValueError("Input tensor must have at least 1 dimension.")
    last_dim = shape[-1]
    if last_dim % 2 != 0:
        raise ValueError(
            f"The last dimension of the input tensor must be even, but got {last_dim}."
        )
    N_OUT = last_dim // 2
    if input_tensor.numel() == 0:
        # Early return for empty tensors.  This must happen BEFORE the
        # M computation below: when last_dim == 0 (which is even and
        # passes the parity check), numel() // last_dim would raise
        # ZeroDivisionError.
        output_shape = (*shape[:-1], N_OUT)
        return torch.empty(
            output_shape, device=input_tensor.device, dtype=input_tensor.dtype
        )
    # Collapse all leading dimensions into M rows so the kernel works on
    # a contiguous 2D view.
    M = input_tensor.numel() // last_dim
    input_2d = input_tensor.contiguous().view(M, last_dim)
    output_2d = torch.empty(
        (M, N_OUT), device=input_tensor.device, dtype=input_tensor.dtype
    )
    # 2D launch grid: one program per (BLOCK_M, BLOCK_N) output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]),
        triton.cdiv(N_OUT, META["BLOCK_N"]),
    )
    reglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        N_OUT,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
    )
    output_shape = (*shape[:-1], N_OUT)
    return output_2d.view(output_shape)

165 

166 

def dreglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    """Backward pass of ReGLU.

    For forward ``y = relu(a) * b`` with ``input = [a | b]`` split along
    the last dimension, computes the gradient w.r.t. the full input:
    ``d a = grad_output * relu'(a) * b`` and ``d b = grad_output * relu(a)``.

    Args:
        grad_output: Upstream gradient, shape ``(..., N)``.
        input_tensor: Forward input, shape ``(..., 2 * N)``.
        quantizer: Unused; kept for interface compatibility.

    Returns:
        Gradient tensor with the same shape as *input_tensor*.

    Raises:
        ValueError: If either tensor is 0-dimensional, or the shapes are
            inconsistent (leading dims differ, or input's last dim is not
            twice grad_output's last dim).
    """
    shape = input_tensor.shape
    # Guard 0-dim tensors explicitly so callers get a ValueError instead
    # of an IndexError from shape[-1] below.
    if grad_output.dim() < 1 or input_tensor.dim() < 1:
        raise ValueError(
            f"Shape mismatch: input {shape} vs grad_output {grad_output.shape}"
        )
    if shape[:-1] != grad_output.shape[:-1] or shape[-1] != 2 * grad_output.shape[-1]:
        raise ValueError(
            f"Shape mismatch: input {shape} vs grad_output {grad_output.shape}"
        )
    if input_tensor.numel() == 0:
        # Early return for empty tensors, mirroring reglu().  This must
        # happen BEFORE the M computation: when grad_output's last dim
        # is 0, numel() // shape[-1] would raise ZeroDivisionError.
        return torch.empty_like(input_tensor)
    # Collapse leading dimensions into M rows for the 2D kernel.
    M = grad_output.numel() // grad_output.shape[-1]
    N = grad_output.shape[-1]
    grad_output_2d = grad_output.contiguous().view(M, N)
    input_2d = input_tensor.contiguous().view(M, 2 * N)
    grad_input = torch.empty_like(input_2d)
    # 2D launch grid over the N-wide half; the kernel writes both halves.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]),
        triton.cdiv(N, META["BLOCK_N"]),
    )
    dreglu_kernel[grid](
        grad_output_2d,
        input_2d,
        grad_input,
        M,
        N,
        grad_output_2d.stride(0),
        grad_output_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_input.stride(0),
        grad_input.stride(1),
    )
    return grad_input.view(shape)