Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/mm.py: 0%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Child logger under the shared "flag_gems" root; lstrip(".") guards against a
# leading dot in __name__ (only relevant for relative-style module names).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b strictly below a (equals a - b when b divides a).
    # Algebraically the same as tl.cdiv(a, b) * b - b.
    return (tl.cdiv(a, b) - 1) * b

19 

20 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general(
    A,  # pointer to the (M, K) left operand
    B,  # pointer to the (K, N) right operand
    C,  # pointer to the (M, N) output
    M,
    N,
    K,
    stride_am,  # element strides of A along M and K
    stride_ak,
    stride_bk,  # element strides of B along K and N
    stride_bn,
    stride_cm,  # element strides of C along M and N
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Tiled GEMM: each program computes one (BLOCK_M, BLOCK_N) tile of C.

    Accumulates in float32 regardless of input dtypes (allow_tf32=False),
    casting to C's element type only for the final store.
    """
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: programs are grouped in
    # super-rows of GROUP_M tile-rows so neighboring programs share B tiles.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap row/col indices modulo M/N so main-loop loads need no mask; the
    # out-of-range lanes are dropped by the store mask at the end. int64
    # avoids 32-bit overflow in the pointer arithmetic for large tensors.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    # Last BLOCK_K-aligned offset below K; the tail [prev_multiple, K) is
    # handled by the peeled, masked iteration after the loop.
    # NOTE(review): for K == 0 this is negative and the peeled loads would be
    # out of bounds — callers are expected to avoid K == 0. TODO confirm.
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        # Mixed-dtype inputs are promoted to the output dtype before the dot.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: one masked iteration covering rk in [prev_multiple, K)
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = rk < K
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    tl.store(C, acc, mask=mask)

104 

105 

# Promotion order, lowest to highest precision.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return the higher-precision of two dtypes.

    Both ``a`` and ``b`` must be members of ``_ordered_datatypes``; the one
    appearing later in that list (higher precision) wins.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The dtype with the larger index in the promotion order is the wider one.
    return max(a, b, key=_ordered_datatypes.index)

121 

122 

def general_mm(a, b, c, M, N, K):
    """Launch the general matmul kernel, computing ``c = a @ b`` in place.

    Args:
        a: (M, K) input tensor.
        b: (K, N) input tensor.
        c: (M, N) pre-allocated output tensor, written in place.
        M, N, K: GEMM problem sizes.

    Returns:
        ``c`` (the same tensor object, filled with the product).
    """
    logger.debug(
        "GEMS MM, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # An empty reduction dimension means the product is all zeros (torch.mm
    # semantics). The kernel cannot handle K == 0: prev_multiple_of(0, BLOCK_K)
    # is negative and its peeled tail would load out of bounds, so short-circuit.
    if K == 0:
        return c.zero_()
    # 1D launch grid: one program per (BLOCK_M, BLOCK_N) output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel_general[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
        )
    return c

153 

154 

def mm(a, b):
    """Compute ``a @ b`` for 2-D tensors, allocating the output tensor.

    The output dtype is the higher-precision of the two input dtypes.
    """
    logger.debug("GEMS_TSINGMICRO mm")
    device = a.device

    def _ensure_supported_layout(t):
        # The kernel handles row- or column-major directly (one unit stride);
        # any other layout is materialized contiguously first.
        return t.contiguous() if t.stride(0) > 1 and t.stride(1) > 1 else t

    a = _ensure_supported_layout(a)
    b = _ensure_supported_layout(b)
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    out = torch.empty((M, N), device=device, dtype=out_dtype)
    return general_mm(a, b, out, M, N, K)

171 

172 

def mm_out(a, b, *, out):
    """Compute ``a @ b`` and write the result into the caller-provided ``out``."""
    logger.debug("GEMS_TSINGMICRO mm_out")
    operands = []
    for t in (a, b):
        # The kernel handles row- or column-major layouts (one unit stride);
        # anything else is copied to a contiguous buffer first.
        needs_copy = t.stride(0) > 1 and t.stride(1) > 1
        operands.append(t.contiguous() if needs_copy else t)
    a, b = operands
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    return general_mm(a, b, out, M, N, K)