Coverage for src/flag_gems/runtime/backend/_metax/ops/mm.py: 0%

116 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger("flag_gems." + __name__) 

14 

15 

# Tiled GEMM kernel: C = A @ B.
# Each program computes one BLOCK_M x BLOCK_N tile of C. Program ids are
# re-ordered in groups of GROUP_M tile-rows for better L2 reuse, and when
# SPLIT_K > 1 the K reduction is split across grid axis 1 and combined with
# tl.atomic_add. The UPGRADE* heuristics switch ids/offsets to 64-bit
# arithmetic whenever the corresponding index space can exceed int32.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K"],
)
@triton.heuristics(runtime.get_heuristic_config("mm"))
@triton.heuristics(
    {
        # True when the flattened tile count (grid_m * grid_n) overflows a
        # signed 32-bit integer, so program ids need 64-bit handling.
        "UPGRADE": lambda args: math.ceil(
            (args["M"] * args["N"]) / (args["BLOCK_M"] * args["BLOCK_N"])
        ).bit_length()
        > 31,
    }
)
@triton.heuristics(
    {
        # A holds M*K elements; widen A offsets to int64 past int32 range.
        "UPGRADE_A_OFFS": lambda args: math.ceil(args["M"] * args["K"]).bit_length()
        > 31,
    }
)
@triton.heuristics(
    {
        # B holds K*N elements.
        "UPGRADE_B_OFFS": lambda args: math.ceil(args["K"] * args["N"]).bit_length()
        > 31,
    }
)
@triton.heuristics(
    {
        # C holds M*N elements.
        "UPGRADE_C_OFFS": lambda args: math.ceil(args["M"] * args["N"]).bit_length()
        > 31,
    }
)
@triton.jit
def mm_kernel(
    A,  # pointer to the (M, K) left operand
    B,  # pointer to the (K, N) right operand
    C,  # pointer to the (M, N) output
    M,
    N,
    K,
    stride_am,  # element strides of A along M and K
    stride_ak,
    stride_bk,  # element strides of B along K and N
    stride_bn,
    stride_cm,  # element strides of C along M and N
    stride_cn,
    dot_out_dtype: tl.constexpr,  # accumulator dtype for tl.dot
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # tile-rows per L2-friendly group
    SPLIT_K: tl.constexpr,  # K-dimension split factor (grid axis 1)
    EVEN_K: tl.constexpr,  # K % (BLOCK_K * SPLIT_K) == 0 -> unmasked loads
    UPGRADE: tl.constexpr,  # 64-bit program ids (see heuristics above)
    UPGRADE_A_OFFS: tl.constexpr,  # int64 offsets into A
    UPGRADE_B_OFFS: tl.constexpr,  # int64 offsets into B
    UPGRADE_C_OFFS: tl.constexpr,  # int64 offsets into C
):
    # matrix multiplication
    if UPGRADE:
        # NOTE(review): tle.program_id is presumably the 64-bit-safe variant —
        # the UPGRADE heuristic selects it when ids exceed int32; confirm in
        # flag_gems.utils.triton_lang_extension.
        pid = tle.program_id(0)
        pid_z = tle.program_id(1)
    else:
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)  # split-K slice index
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    # the last group may hold fewer than GROUP_M tile-rows
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    if UPGRADE_A_OFFS:
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        ram = (tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)).to(tl.int64)
    else:
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        # max_contiguous/multiple_of hint alignment of the wrapped indices
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    if UPGRADE_B_OFFS:
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        rbn = (tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)).to(tl.int64)
    else:
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)

    # this program's slice of the K dimension
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            # mask off the ragged tail of the K dimension, padding with zeros
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if a.dtype != b.dtype:
            # mixed-precision inputs: promote both operands to C's dtype
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        # advance both operand pointers by one full K step
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    if UPGRADE_C_OFFS:
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn).to(tl.int64)
    else:
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # each split-K slice adds its partial sum into C
        tl.atomic_add(C, acc, mask=mask)

139 

140 

# Floating dtypes ordered from lowest to highest promotion precedence.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the higher-precedence of two dtypes per ``_ordered_datatypes``.

    Both arguments must be members of ``_ordered_datatypes``; identical
    inputs are returned unchanged.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # Whichever dtype appears later in the precedence list wins.
    return max((a, b), key=_ordered_datatypes.index)

156 

157 

def mm(a, b):
    """Compute the matrix product ``a @ b`` with the Metax mm Triton kernel.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).

    Returns:
        A freshly allocated (M, N) tensor whose dtype is the higher of the
        two input dtypes (see ``get_higher_dtype``).
    """
    logger.debug("METAX GEMS MM")
    device = a.device
    # Only force contiguity when neither stride is 1, i.e. the tensor is
    # neither row- nor column-major; the kernel handles both layouts.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    # NOTE(review): the output is torch.empty; if any tuned config uses
    # SPLIT_K > 1 the kernel accumulates via tl.atomic_add onto this
    # uninitialized buffer — confirm tuned configs keep SPLIT_K == 1.
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    # accumulate in fp32 regardless of input dtype
    dot_out_dtype = tl.float32
    logger.debug(
        "METAX GEMS MM, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    # 1D grid of output tiles on axis 0; axis 1 carries the split-K factor.
    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            META["SPLIT_K"],
        )

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            *a.stride(),
            *b.stride(),
            *c.stride(),
            dot_out_dtype=dot_out_dtype,
            GROUP_M=8,
        )
    return c

206 

207 

def mm_out(a, b, *, out):
    """Compute the matrix product ``a @ b`` into a caller-provided tensor.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).
        out: preallocated (M, N) tensor that receives the product.

    Returns:
        ``out``, written in place by the Triton kernel.
    """
    logger.debug("METAX GEMS MM_OUT")
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # The caller-supplied buffer must match the product's shape; otherwise
    # the kernel would write to the wrong elements or out of bounds.
    assert out.shape == (M, N), "output tensor has incompatible shape"
    c = out
    # accumulate in fp32 regardless of input dtype for accuracy
    dot_out_dtype = tl.float32
    # Fixed: this log line previously said "METAX GEMS MM" (copy-paste from
    # mm); tag it with the actual entry point.
    logger.debug(
        "METAX GEMS MM_OUT, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # launch kernel: 1D grid of output tiles on axis 0, split-K on axis 1
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        META["SPLIT_K"],
    )
    # NOTE(review): with SPLIT_K > 1 the kernel accumulates via tl.atomic_add
    # onto whatever `out` already holds — confirm tuned configs keep
    # SPLIT_K == 1 or that callers pass a zeroed buffer.
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=dot_out_dtype,
            GROUP_M=8,
        )
    return c