Coverage for src/flag_gems/ops/mm.py: 43%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.ops.mm_streamk import streamk_mm 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.device_info import get_device_capability, get_sm_count 

13 

# Fraction of cache capacity treated as usable by sizing heuristics.
# NOTE(review): not referenced anywhere in this module's visible code
# (the only related call, get_l2_cache_size, is commented out) — confirm
# whether it is still needed.
CACHE_USAGE_THRESHOLD = 0.8

logger = logging.getLogger(__name__)

17 

18 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a (assumes a > 0):
    # ceil(a / b) * b - b == (ceil(a / b) - 1) * b.
    return (tl.cdiv(a, b) - 1) * b

23 

24 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,  # pointer to A, logically (M, K), addressed via stride_am/stride_ak
    B,  # pointer to B, logically (K, N), addressed via stride_bk/stride_bn
    C,  # pointer to C, logically (M, N); also supplies the output dtype
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # General tiled matmul: each program computes one (BLOCK_M, BLOCK_N) tile
    # of C = A @ B, accumulating in float32 and casting to C's dtype on store.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: programs are grouped so
    # that GROUP_M consecutive row-tiles share the same column-tile range,
    # improving reuse of B tiles in cache.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap indices modulo M/N so loads in the main loop need no bounds mask;
    # int64 avoids overflow in the pointer arithmetic for large tensors.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    # Largest multiple of BLOCK_K strictly below K: the main loop runs over
    # full blocks only; the remainder (< BLOCK_K) is handled by the peeled
    # iteration below with an explicit mask.
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        # Mixed-dtype inputs are promoted to the output dtype before the dot.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: final (possibly partial) K-block, masked at the K edge.
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = rk < K
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # Masked write-back: only in-bounds elements of the output tile are stored.
    tl.store(C, acc, mask=mask)

108 

109 

# Promotion order, narrowest first: a dtype later in this list wins.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the wider of two dtypes according to ``_ordered_datatypes``.

    Both ``a`` and ``b`` must appear in ``_ordered_datatypes``; the one
    positioned later in that list (i.e. the higher-precedence dtype) is
    returned. If they are the same dtype, it is returned directly.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The dtype with the larger position in the ordered list wins.
    return max(a, b, key=_ordered_datatypes.index)

125 

126 

def general_mm(a, b, c, M, N, K):
    """Launch the general tiled matmul kernel, writing ``a @ b`` into ``c``.

    ``a`` is (M, K), ``b`` is (K, N), ``c`` is the preallocated (M, N)
    output. Returns ``c``.
    """
    logger.debug(
        "GEMS MM, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    # One program per (BLOCK_M, BLOCK_N) output tile, flattened to a 1D grid.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel_general[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            # Row/column strides of each 2D operand, in declaration order.
            *a.stride(),
            *b.stride(),
            *c.stride(),
            GROUP_M=8,
        )
    return c

157 

158 

def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-k matmul kernel should be used.

    TODO: this may change according to real benchmark results.
    The best configuration for stream-k has so far only been validated on
    A100 (compute capability major == 8); optimal settings for other
    devices still need to be determined through real testing.
    """
    capability = get_device_capability()
    if capability[0] != 8:
        return False
    # Stream-k is only worthwhile for half-precision, contiguous operands
    # with a reduction dimension much larger than either output dimension.
    half_types = [torch.float16, torch.bfloat16]
    return (
        a.dtype in half_types
        and b.dtype in half_types
        and a.is_contiguous()
        and b.is_contiguous()
        and K > 5 * M
        and K > 5 * N
    )

173 

174 

def mm(a, b):
    """Matrix-multiply two 2D tensors: return ``a @ b`` as a new tensor.

    The output dtype is the higher of the two input dtypes. Dispatches to
    the stream-k kernel when the shapes/device qualify, otherwise to the
    general tiled kernel.
    """
    out_device = a.device
    # handle non-contiguous inputs if necessary: fall back to a contiguous
    # copy when neither stride is unit-length.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    # allocates output
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=out_device, dtype=out_dtype)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    return general_mm(a, b, c, M, N, K)

195 

196 

def mm_out(a, b, *, out):
    """Matrix-multiply ``a @ b``, writing the result into preallocated ``out``.

    Returns ``out``. Mirrors ``mm`` but skips output allocation; ``out`` is
    expected to be an (M, N) tensor — not validated here.
    """
    # handle non-contiguous inputs if necessary: fall back to a contiguous
    # copy when neither stride is unit-length.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    return general_mm(a, b, out, M, N, K)