Coverage for src/flag_gems/runtime/backend/_hygon/ops/mm.py: 0%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-level logger; used by the mm/mm_out wrappers to trace dispatch.
logger = logging.getLogger(__name__)

13 

14 

@triton.jit
def prev_multiple_of(a, b):
    # Largest x < a with x % b == 0 (assumes a > 0):
    # cdiv(a, b) * b - b rewritten as (cdiv(a, b) - 1) * b.
    return (tl.cdiv(a, b) - 1) * b

19 

20 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K"],
    strategy=["log", "log", "log"],
)
@triton.jit
def mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # C[M, N] = A[M, K] @ B[K, N], accumulated in fp32 and cast to C's dtype.
    # Launched on a 1-D grid; GROUP_M consecutive rows of output tiles are
    # grouped together so neighbouring programs reuse the same A rows (L2).
    pid = tle.program_id(0)

    # Number of tile-programs along each output dimension.
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)

    # Re-order program ids for L2 locality; the last group may hold fewer
    # than GROUP_M rows of tiles.
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    group_size_m = min(num_pid_m - group_id * GROUP_M, GROUP_M)

    pid_m = group_id * GROUP_M + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Element offsets of this tile in A (rows), B (cols), and along K.
    offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Contiguity/alignment hints for the compiler; the modulo keeps the
    # pointers in-bounds so the main loop can load without masks.
    offs_am_cont = tl.max_contiguous(tl.multiple_of(offs_am % M, BLOCK_M), BLOCK_M)
    offs_bn_cont = tl.max_contiguous(tl.multiple_of(offs_bn % N, BLOCK_N), BLOCK_N)

    # Start of the last (possibly partial) K block, handled by the peeled
    # epilogue below; equals prev_multiple_of(K, BLOCK_K), inlined here.
    prev_k_mult = tl.cdiv(K, BLOCK_K) * BLOCK_K - BLOCK_K

    # fp32 accumulator regardless of input dtype.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # --------------------------
    # main K loop: all-but-last K block, unmasked loads
    # --------------------------
    for start_k in range(0, prev_k_mult, BLOCK_K):
        rk = start_k + offs_k

        a = tl.load(
            a_ptr + (offs_am_cont[:, None] * stride_am + rk[None, :] * stride_ak)
        )
        b = tl.load(
            b_ptr + (rk[:, None] * stride_bk + offs_bn_cont[None, :] * stride_bn)
        )

        # Mixed-input matmul: unify operand dtypes before tl.dot.
        if a.dtype != b.dtype:
            a = a.to(c_ptr.dtype.element_ty)
            b = b.to(c_ptr.dtype.element_ty)

        accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # --------------------------
    # loop peel: last K block with masking
    # --------------------------
    # BUGFIX: `other=0.0` is required here — tl.load leaves masked-off
    # elements *undefined* when `other` is omitted, so when K is not a
    # multiple of BLOCK_K the out-of-range lanes could feed garbage into
    # tl.dot. Zero-filling makes their contribution exactly zero.
    rk = prev_k_mult + offs_k
    mask_k = rk < K

    a = tl.load(
        a_ptr + (offs_am_cont[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        b_ptr + (rk[:, None] * stride_bk + offs_bn_cont[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )

    if a.dtype != b.dtype:
        a = a.to(c_ptr.dtype.element_ty)
        b = b.to(c_ptr.dtype.element_ty)

    accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # Cast once to the output dtype.
    accumulator = accumulator.to(c_ptr.dtype.element_ty)

    # --------------------------
    # rematerialize offsets for the store; mask the tile edges
    # --------------------------
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    c_ptr = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn)
    mask_store = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]

    tl.store(c_ptr, accumulator, mask=mask_store)

134 

135 

# Promotion order, lowest to highest; later entries win in get_higher_dtype.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the higher-precedence of the two dtypes.

    Precedence is the position in ``_ordered_datatypes`` (later == higher);
    both ``a`` and ``b`` must appear in that list. Identical inputs are
    returned unchanged.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # A later index in the precedence list means a "higher" dtype.
    if _ordered_datatypes.index(a) < _ordered_datatypes.index(b):
        return b
    return a

151 

152 

def mm(a, b):
    """Matrix product of two 2-D tensors via the tuned Triton kernel.

    Allocates the output with the higher-precedence dtype of the inputs
    and returns it.
    """
    logger.debug("GEMS MM")
    device = a.device
    # When neither stride is 1, fall back to a contiguous copy.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # Shape constraint: inner dimensions must agree.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # Allocate the output at the promoted dtype.
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)

    # 1-D launch grid: one program per output tile.
    def grid(meta):
        n_tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        n_tiles_n = triton.cdiv(N, meta["BLOCK_N"])
        return (n_tiles_m * n_tiles_n,)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
        )
    return c

189 

190 

def mm_out(a, b, *, out):
    """Matrix-multiply ``a @ b`` into the preallocated tensor ``out``.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).
        out: preallocated 2-D tensor of shape (M, N) that receives the result.

    Returns:
        ``out``, filled with the product.
    """
    logger.debug("GEMS MM_OUT")
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # BUGFIX: validate the output shape — a mis-shaped `out` would make the
    # kernel store through wrong strides and silently corrupt memory.
    assert out.shape == (M, N), "incompatible output shape"
    c = out
    # 1-D launch grid: one program per output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
        )
    return c