Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mm.py: 0%

108 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Per-module child of the package-wide "flag_gems" logger; lstrip(".") drops any
# leading dot from __name__ so the child name is always a valid logger suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

def heur_split_k(args):
    """Heuristic for the SPLIT_K meta-parameter.

    Split-K reduction is disabled on this backend, so a single K-slice
    (SPLIT_K == 1) is always used. ``args`` is accepted only to satisfy the
    ``triton.heuristics`` callback signature.
    """
    del args  # unused
    return 1

18 

19 

def heur_even_k(args):
    """Heuristic for EVEN_K: True when K divides evenly into reduction chunks.

    Each main-loop iteration consumes BLOCK_K * SPLIT_K elements along K; when
    K is an exact multiple, the kernel can skip its K-boundary masking.
    """
    chunk = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % chunk == 0

22 

23 

def heur_group_m(args):
    """Heuristic for GROUP_M, the number of M-tiles grouped per super-row.

    Tall tiles (BLOCK_M > BLOCK_N) use no grouping (GROUP_M == 1); otherwise
    all ceil(M / BLOCK_M) row-tiles are placed in a single group.
    """
    m, bm, bn = args["M"], args["BLOCK_M"], args["BLOCK_N"]
    if bm > bn:
        return 1
    # Ceiling division via negated floor division.
    return -(m // -bm)

29 

30 

# Default autotuning setup for the mm kernel, keyed on the problem shape.
# NOTE(review): `generate_configs="mm"` is not a parameter of upstream
# triton.autotune — presumably a Kunlunxin Triton-fork extension that
# generates the mm search space internally; confirm against the fork.
autotune_decorator = triton.autotune(
    configs=[],
    generate_configs="mm",
    key=["M", "N", "K"],
)


# KLX_USE_AUTOTUNE (default "1") toggles autotuning; any other value pins a
# single fixed tile configuration instead of searching.
KLX_USE_AUTOTUNE = os.environ.get("KLX_USE_AUTOTUNE", "1") == "1"

if not KLX_USE_AUTOTUNE:
    # Single hard-coded 256x256x256 tiling when autotuning is disabled.
    autotune_decorator = triton.autotune(
        configs=[
            triton.Config({"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 256}),
        ],
        key=["M", "N", "K"],
    )

47 

48 

@libentry()
@autotune_decorator
@triton.heuristics(
    {
        "SPLIT_K": heur_split_k,
        "EVEN_K": heur_even_k,
        "GROUP_M": heur_group_m,
    }
)
@triton.jit
def mm_kernel(
    A,  # pointer to the left operand, logical shape (M, K)
    B,  # pointer to the right operand, logical shape (K, N)
    C,  # pointer to the output, logical shape (M, N)
    M,
    N,
    K,
    stride_am,  # element strides of A along M and K
    stride_ak,
    stride_bk,  # element strides of B along K and N
    stride_bn,
    stride_cm,  # element strides of C along M and N
    stride_cn,
    dot_out_dtype: tl.constexpr,  # accumulator dtype for tl.dot
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # M-tiles per swizzle group (see heur_group_m)
    SPLIT_K: tl.constexpr,  # number of K-slices; >1 uses atomic accumulation
    EVEN_K: tl.constexpr,  # True when K % (BLOCK_K * SPLIT_K) == 0 (no K mask)
):
    """Tiled GEMM: C = A @ B, one (BLOCK_M, BLOCK_N) output tile per program.

    Grid axis 0 enumerates output tiles (swizzled for locality); grid axis 1
    enumerates SPLIT_K reduction slices.
    """
    # matrix multiplication
    pid = tle.program_id(0)
    pid_z = tle.program_id(1)  # split-K slice index
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: walk GROUP_M rows of
    # tiles at a time so nearby programs share A/B tiles.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap indices into range and declare contiguity so loads can coalesce.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)  # this slice's K offsets
    # pointers to the first A (BLOCK_M, BLOCK_K) and B (BLOCK_K, BLOCK_N) tiles
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            # K divides evenly; no boundary masking needed.
            a = tl.load(A)
            b = tl.load(B)
        else:
            # Mask out-of-range K elements on the final iteration, padding
            # with zeros so they contribute nothing to the dot product.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if a.dtype != b.dtype:
            # Mixed-dtype operands: promote both to the output element type
            # so tl.dot sees matching inputs.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        # Advance both tiles by one full reduction chunk along K.
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting: a single K-slice can store
    # directly; multiple slices must accumulate atomically into C.
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)

127 

128 

# Dtype precedence for output promotion, lowest precedence first.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the higher-precedence of two dtypes per ``_ordered_datatypes``.

    Both arguments must be members of ``_ordered_datatypes`` (asserted);
    identical dtypes are returned as-is.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The dtype appearing later in the precedence list wins.
    return b if _ordered_datatypes.index(a) < _ordered_datatypes.index(b) else a

144 

145 

def mm(a, b):
    """Matrix-multiply two 2-D tensors with the Triton mm kernel.

    Args:
        a: left operand, shape (M, K).
        b: right operand, shape (K, N); b.shape[0] must equal a.shape[1].

    Returns:
        A freshly allocated (M, N) tensor on a's device whose dtype is the
        higher-precedence of the two operand dtypes.
    """
    logger.debug("GEMS MM")
    device = a.device
    # Copy to contiguous storage when neither stride is 1 (fully
    # non-contiguous operand); strided-but-dense layouts pass through.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    out = torch.empty(
        (M, N), device=device, dtype=get_higher_dtype(a.dtype, b.dtype)
    )

    def grid(meta):
        # Axis 0: one program per output tile; axis 1: split-K slices.
        tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
        return (tiles, meta["SPLIT_K"])

    with torch_device_fn.device(device):
        mm_kernel[grid](
            a,
            b,
            out,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            out.stride(0),
            out.stride(1),
            dot_out_dtype=tl.float32,
        )
    return out

184 

185 

def mm_out(a, b, *, out):
    """Matrix-multiply two 2-D tensors, writing the result into ``out``.

    Args:
        a: left operand, shape (M, K).
        b: right operand, shape (K, N); b.shape[0] must equal a.shape[1].
        out: pre-allocated (M, N) destination tensor, also returned.

    Returns:
        ``out``, filled with a @ b.
    """
    logger.debug("GEMS MM_OUT")
    # Copy to contiguous storage when neither stride is 1 (fully
    # non-contiguous operand); strided-but-dense layouts pass through.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]

    def grid(meta):
        # Axis 0: one program per output tile; axis 1: split-K slices.
        return (
            triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
            meta["SPLIT_K"],
        )

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            out,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            out.stride(0),
            out.stride(1),
            dot_out_dtype=tl.float32,
        )
    return out