Coverage for src/flag_gems/runtime/backend/_mthreads/ops/bmm.py: 0%

137 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13from .utils import create_tma_device_descriptor, should_enable_sqmma 

14 

# Module logger scoped under the mthreads backend ops namespace,
# named after this op file (e.g. "...ops.bmm").
_op_name = __name__.split(".")[-1]
logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops." + _op_name)

18 

19 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
    strategy=["log", "log", "log"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    """Tiled batched matmul: O[b] = A[b] @ B[b].

    Each program computes one TILE_M x TILE_N output tile for one batch
    (program axis 2 is the batch index).  Pointer arithmetic assumes
    contiguous row-major (batch, M, K) / (batch, K, N) / (batch, M, N)
    layouts, as enforced by the bmm_fma wrapper.  The DIVISIBLE_* flags
    (filled in by the heuristic config) elide bounds masks on dimensions
    known to be multiples of their tile size.
    """
    # batch offsets: advance base pointers to this batch's matrices
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: linearize the 2D grid, then remap so GROUP_M
        # row-tiles are grouped together (improves data reuse locality)
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # the last group is ragged when gridx is not a multiple of GROUP_M
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # row/column masks are only materialized when the dimension can be ragged
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # accumulate in fp32 regardless of input dtype
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            # K has a tail tile: mask the out-of-range K slice each iteration
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        # advance to the next K tile (A moves along rows, B along columns)
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    # store mask mirrors the load masks for the M/N edges
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

124 

125 

def bmm_fma(A, B):
    """Batched matmul via the tuned Triton FMA kernel; returns A @ B."""
    logger.debug("GEMS_MTHREADS BMM(FMA)")
    batch, M, K = A.shape
    N = B.shape[-1]
    # the kernel uses plain row-major pointer arithmetic, so force layout
    A = A.contiguous()
    B = B.contiguous()
    output = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def grid_fn(meta):
        # one program per (M-tile, N-tile, batch) triple; tile sizes come
        # from the autotuner's chosen config
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](A, B, output, M, N, K)
    return output

142 

143 

@triton.jit
def bmm_sqmma_kernel(
    a_desc_ptr,
    b_desc_ptr,
    c_desc_ptr,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    ab_type: tl.constexpr,
    d_type: tl.constexpr,
):
    """Batched matmul through TMA device descriptors (SQMMA path).

    The descriptors view the batched tensors as 2D: A as (batch*M, K),
    B as (batch*K, N), C as (batch*M, N) — see bmm_sqmma's reshape calls.
    Program axis 0 is linearized over output tiles (M-major); axis 1 is
    the batch index, applied as a row offset into the 2D views.
    """
    pid = tl.program_id(axis=0)
    batch_index = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # row offset into the (batch*M, K) view of A (and (batch*M, N) of C)
    offs_am = pid_m * BLOCK_SIZE_M + batch_index * M
    offs_bn = pid_n * BLOCK_SIZE_N
    offs_ak = 0
    # row offset into the (batch*K, N) view of B
    offs_bk = batch_index * K
    tme_load_type = ab_type
    # fp32 accumulator; converted to d_type only at the final store
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl._experimental_descriptor_load(
            a_desc_ptr, [offs_am, offs_ak], [BLOCK_SIZE_M, BLOCK_SIZE_K], tme_load_type
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr, [offs_bk, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tme_load_type
        )
        accumulator = tl.dot(a, b, acc=accumulator)
        # advance both operands one K tile
        offs_ak += BLOCK_SIZE_K
        offs_bk += BLOCK_SIZE_K
    accumulator = accumulator.to(d_type)
    tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])

181 

182 

def get_triton_type(elem_type):
    """Map a torch dtype to its triton.language equivalent.

    Returns None when the dtype has no SQMMA mapping.
    """
    if elem_type == torch.float16:
        return tl.float16
    if elem_type == torch.bfloat16:
        return tl.bfloat16
    if elem_type == torch.float8_e4m3fn:
        return tl.float8e4nv
    return None

190 

191 

def bmm_sqmma(
    A, B, elem_type, batch, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages
):
    """Run the TMA/SQMMA batched-matmul kernel and return C = A @ B.

    A and B are viewed as 2D — (batch*M, K) and (batch*K, N) — so a single
    TMA descriptor covers every batch; the kernel applies the batch offset
    as a row offset.  Launch grid: (M-tiles * N-tiles, batch, 1).
    """
    device = "musa"
    ab_type = elem_type
    # SQMMA cannot emit bf16 results on this path; fall back to fp16 output
    c_type = elem_type if (elem_type != torch.bfloat16) else torch.float16
    # Allocate directly in c_type.  The previous code allocated fp16 and
    # then called .to(c_type), which performed an extra allocation + copy
    # of uninitialized data when c_type was fp8.
    # NOTE(review): assumes torch.empty supports c_type (incl. fp8) on the
    # musa device — confirm against the backend's dtype support.
    C = torch.empty((batch, M, N), dtype=c_type, device=device)
    desc_a = create_tma_device_descriptor(
        A.reshape(batch * M, K), BLOCK_M, BLOCK_K, device
    )
    desc_b = create_tma_device_descriptor(
        B.reshape(batch * K, N), BLOCK_K, BLOCK_N, device
    )
    desc_c = create_tma_device_descriptor(
        C.reshape(batch * M, N), BLOCK_M, BLOCK_N, device
    )
    bmm_sqmma_kernel[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), batch, 1)](
        desc_a,
        desc_b,
        desc_c,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        get_triton_type(ab_type),
        get_triton_type(c_type),
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return C

224 

225 

def bmm(a, b):
    """Batched matrix multiply dispatcher: returns a @ b.

    Routes to the TMA/SQMMA kernel when should_enable_sqmma approves the
    dtype/shape combination; otherwise runs the tuned FMA kernel with the
    MUSA_ENABLE_SQMMA environment variable temporarily cleared.
    """
    a_dtype = a.dtype
    b_dtype = b.dtype
    batch, M, K = a.shape
    _, _, N = b.shape
    use_sqmma = should_enable_sqmma(a_dtype, b_dtype, M, N, K)
    if use_sqmma:
        BLOCK_M = 128
        BLOCK_N = BLOCK_M
        BLOCK_K = 64
        # 256-wide tiles would need 16 warps; the 128 tile runs with 4
        num_warps = 16 if BLOCK_M == 256 else 4
        num_stages = 1
        return bmm_sqmma(
            a,
            b,
            a_dtype,
            batch,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_warps,
            num_stages,
        )
    # Clear MUSA_ENABLE_SQMMA for the FMA path, then restore it.
    # Fixes vs. the original: (1) `if enable_sqmma:` dropped an env var
    # that was set to the empty string (falsy); (2) without try/finally a
    # raise inside bmm_fma lost the setting entirely.
    enable_sqmma = os.environ.pop("MUSA_ENABLE_SQMMA", None)
    try:
        return bmm_fma(a, b)
    finally:
        if enable_sqmma is not None:
            os.environ["MUSA_ENABLE_SQMMA"] = enable_sqmma