Coverage for src/flag_gems/ops/bmm.py: 38%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    # Retune when problem size or the leading strides change; strides are
    # bucketed to 32-element alignment, sizes to log scale.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["log", "log", "log", "align32", "align32"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,  # pointer to the first input batch of matrices, logical shape (B, M, K)
    B,  # pointer to the second input batch of matrices, logical shape (B, K, N)
    O,  # pointer to the output batch, logical shape (B, M, N)
    M,
    N,
    K,
    stride_ab,  # batch stride of A
    stride_am,  # row stride of A
    stride_ak,  # column stride of A
    stride_bb,  # batch stride of B
    stride_bk,  # row stride of B
    stride_bn,  # column stride of B
    stride_ob,  # batch stride of O
    stride_om,  # row stride of O
    stride_on,  # column stride of O
    TILE_M: tl.constexpr,  # tile size along M
    TILE_N: tl.constexpr,  # tile size along N
    TILE_K: tl.constexpr,  # tile size along K (reduction)
    GROUP_M: tl.constexpr,  # CTA-group size along M for launch-order swizzling
    DIVISIBLE_M: tl.constexpr,  # True when M % TILE_M == 0 (no M masking needed)
    DIVISIBLE_N: tl.constexpr,  # True when N % TILE_N == 0 (no N masking needed)
    DIVISIBLE_K: tl.constexpr,  # True when K % TILE_K == 0 (no K masking needed)
):
    """Batched matrix multiply: O[b] = A[b] @ B[b] for each batch b.

    Grid layout: (cdiv(M, TILE_M), cdiv(N, TILE_N), batch). Each program
    computes one TILE_M x TILE_N output tile of one batch, accumulating in
    float32 with tf32 disabled for full fp32 precision in the dot product.
    The DIVISIBLE_* flags are compile-time constants, so all masking code
    is specialized away when the corresponding dimension is tile-aligned.
    """
    # batch offsets: advance all three base pointers to this program's batch
    pid_b = tle.program_id(2)
    A += pid_b * stride_ab
    B += pid_b * stride_bb
    O += pid_b * stride_ob

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: swizzle the launch order so GROUP_M consecutive
        # M-tiles are visited before moving along N — presumably to improve
        # L2 reuse of B tiles (standard grouped-ordering trick).
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group along M may be partial; shrink it to gridx % GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    # element offsets of this program's tile
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # boundary masks only materialize when the dimension is not tile-aligned
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    o_ptrs = O + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on

    num_iters = tl.cdiv(K, TILE_K)
    # fp32 accumulator regardless of input dtype
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Build load masks for this K-step; None means an unmasked load.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            # offs_k advances each iteration, so mask_k must be recomputed here
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        # advance along K
        offs_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

        # allow_tf32=False keeps full fp32 precision in the reduction
        o += tl.dot(a, b, allow_tf32=False)

    # store mask covers whatever boundary dimensions are not tile-aligned
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

128 

129 

def bmm(A, B):
    """Batched matrix multiply: returns A @ B for 3-D tensors.

    A has shape (batch, M, K) and B has shape (batch, K, N); the result
    is a freshly allocated (batch, M, N) tensor with A's dtype and device.
    """
    logger.debug("GEMS BMM")
    assert A.shape[0] == B.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    num_batches, m, k = A.shape
    n = B.shape[2]
    out = torch.empty((num_batches, m, n), dtype=A.dtype, device=A.device)

    def grid(meta):
        # one program per (M-tile, N-tile, batch)
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            num_batches,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid](
            A,
            B,
            out,
            m,
            n,
            k,
            # (batch, row, col) strides of each operand, in kernel order
            *A.stride(),
            *B.stride(),
            *out.stride(),
        )
    return out

162 

163 

def bmm_out(A, B, out):
    """Batched matrix multiply into a caller-provided output tensor.

    Computes out[b] = A[b] @ B[b] for each batch b and returns ``out``.

    Args:
        A: tensor of shape (batch, M, K).
        B: tensor of shape (batch, K, N).
        out: preallocated tensor of shape (batch, M, N); written in place.

    Returns:
        The ``out`` tensor.
    """
    logger.debug("GEMS BMM_OUT")
    assert A.shape[0] == B.shape[0] == out.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    # Fix: validate the output tile dims too — a mis-shaped `out` would
    # otherwise be passed to the kernel with inconsistent shape/strides.
    assert out.shape[1] == M and out.shape[2] == N, "Output shape mismatch"

    grid_fn = lambda meta: (
        triton.cdiv(meta["M"], meta["TILE_M"]),
        triton.cdiv(meta["N"], meta["TILE_N"]),
        batch,
    )
    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](
            A,
            B,
            out,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
        )
    return out