Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mm.py: 0%

97 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10 

# Module logger nested under the package-root "flag_gems" logger; lstrip(".")
# guards against a leading dot in __name__ (relative-import edge case).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

# Tiled GEMM kernel (C = A @ B) with grouped program ordering and optional
# split-K reduction. Autotuned over the configs registered for "mm"; the
# tuning key includes the strides so transposed layouts retune separately
# (one "align32" strategy entry per key field).
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk", "stride_ak", "stride_bn"],
    strategy=[
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
        "align32",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("mm"))
@triton.jit
def mm_kernel(
    A,  # pointer to (M, K) input
    B,  # pointer to (K, N) input
    C,  # pointer to (M, N) output
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,  # accumulator dtype for tl.dot (fp32 from hosts)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # tiles per swizzle group along M
    SPLIT_K: tl.constexpr,  # K-split factor; grid axis 1 has SPLIT_K programs
    EVEN_K: tl.constexpr,  # set by heuristics; presumably K % (BLOCK_K*SPLIT_K) == 0 — confirm
    UPCAST: tl.constexpr,  # do index math in int64 to avoid 32-bit overflow
):
    # matrix multiplication
    if UPCAST:
        # Widening the program ids makes all derived offset arithmetic int64.
        pid = tl.program_id(0).to(tl.int64)
        pid_z = tl.program_id(1).to(tl.int64)
    else:
        pid = tl.program_id(0)
        pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    # Last group may be ragged when grid_m is not a multiple of GROUP_M.
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Each split-K slice (pid_z) starts at a different offset along K.
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            # No K-boundary mask needed on this path.
            a = tl.load(A, mask=(rm < M)[:, None], other=0.0)
            b = tl.load(B, mask=(rn < N)[None, :], other=0.0)
        else:
            # Mask off the K tail; out-of-range elements contribute 0 to the dot.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(
                A, mask=(rk[None, :] < k_remaining) & (rm < M)[:, None], other=0.0
            )
            b = tl.load(
                B, mask=(rk[:, None] < k_remaining) & (rn < N)[None, :], other=0.0
            )

        if a.dtype != b.dtype:
            # Mixed-precision inputs: promote both operands to the output dtype.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        # Advance along K past all SPLIT_K slices' blocks for this iteration.
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # NOTE(review): this accumulates into C, yet the host wrappers allocate
        # C with torch.empty (uninitialized). Confirm the tuned configs keep
        # SPLIT_K == 1, or that C is zeroed before launch when SPLIT_K > 1.
        tl.atomic_add(C, acc, mask=mask)

105 

106 

# Floating-point dtypes ranked from narrowest to widest for type promotion.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the wider of two dtypes per the ``_ordered_datatypes`` ranking.

    Both ``a`` and ``b`` must appear in ``_ordered_datatypes``; when they are
    the same dtype it is returned unchanged.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # Distinct dtypes have distinct ranks, so the max by rank is the winner.
    return max(a, b, key=_ordered_datatypes.index)

122 

123 

def mm(a, b):
    """Compute ``a @ b`` on the Cambricon backend via ``mm_kernel``.

    ``a`` is (M, K), ``b`` is (K, N); returns a freshly allocated (M, N)
    tensor whose dtype is the higher of the two input dtypes.
    """
    logger.debug("GEMS_CAMBRICON MM")
    device = a.device
    # A matrix with both strides > 1 is neither row- nor column-major; the
    # kernel needs one unit-stride axis, so densify such inputs first.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # Output takes the wider of the two input dtypes.
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=out_dtype)
    # Accumulate in fp32 regardless of input/output precision.
    acc_dtype = tl.float32
    # Switch the kernel to int64 indexing when any extent * stride product
    # could overflow a signed 32-bit integer.
    int32_limit = 1 << 31
    upcast = (
        M * max(a.stride(0), c.stride(0)) >= int32_limit
        or N * max(b.stride(1), c.stride(1)) >= int32_limit
        or K * max(a.stride(1), b.stride(0)) >= int32_limit
    )

    # One program per (BLOCK_M, BLOCK_N) tile on axis 0; SPLIT_K on axis 1.
    def grid(meta):
        tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
        return (tiles, meta["SPLIT_K"])

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=acc_dtype,
            GROUP_M=8,
            UPCAST=upcast,
        )
    return c

169 

170 

def mm_out(a, b, *, out):
    """Compute ``a @ b`` into a caller-provided ``out`` tensor.

    ``a`` is (M, K), ``b`` is (K, N); ``out`` must already be (M, N).
    Returns ``out``.
    """
    logger.debug("GEMS_CAMBRICON MM_OUT")
    # A matrix with both strides > 1 is neither row- nor column-major; the
    # kernel needs one unit-stride axis, so densify such inputs first.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c = out
    # Fix: the kernel writes through c's strides with no bounds tied to c's
    # own shape, so a mismatched `out` silently corrupted memory or produced
    # garbage. Validate it up front (kernel casts to c's dtype, so any
    # floating dtype is acceptable).
    assert c.shape == (M, N), "incompatible output shape"
    # Accumulate in fp32 regardless of input/output precision.
    dot_out_dtype = tl.float32
    # Switch the kernel to int64 indexing when any extent * stride product
    # could overflow a signed 32-bit integer.
    UPCAST = (
        M * max(a.stride(0), c.stride(0)) >= 1 << 31
        or N * max(b.stride(1), c.stride(1)) >= 1 << 31
        or K * max(a.stride(1), b.stride(0)) >= 1 << 31
    )
    # One program per (BLOCK_M, BLOCK_N) tile on axis 0; SPLIT_K on axis 1.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        META["SPLIT_K"],
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=dot_out_dtype,
            GROUP_M=8,
            UPCAST=UPCAST,
        )
    return c