Coverage for src/flag_gems/runtime/backend/_cambricon/ops/bmm.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner

# Child logger under the shared "flag_gems" root logger; lstrip(".") strips any
# leading dots from __name__ (a no-op for normal absolute module paths).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["log", "log", "log", "align32", "align32"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,  # pointer to left operand, laid out (batch, M, K) via the strides below
    B,  # pointer to right operand, laid out (batch, K, N)
    O,  # pointer to output, laid out (batch, M, N)
    M,
    N,
    K,
    stride_ab,  # batch stride of A
    stride_am,  # M-dimension stride of A
    stride_ak,  # K-dimension stride of A
    stride_bb,  # batch stride of B
    stride_bk,  # K-dimension stride of B
    stride_bn,  # N-dimension stride of B
    stride_ob,  # batch stride of O
    stride_om,  # M-dimension stride of O
    stride_on,  # N-dimension stride of O
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # M-tiles per CTA group when reordering (1 = no reorder)
    DIVISIBLE_M: tl.constexpr,  # when set, M-boundary masking is skipped entirely
    DIVISIBLE_N: tl.constexpr,  # when set, N-boundary masking is skipped entirely
    DIVISIBLE_K: tl.constexpr,  # when set, K-tail masking is skipped entirely
):
    """Tiled batched matmul: O[b] = A[b] @ B[b], accumulated in float32.

    Each program instance computes one TILE_M x TILE_N tile of one batch
    element; program_id(2) selects the batch. tf32 is disabled in the dot
    so accumulation stays full-precision float32.
    """
    # batch offsets: rebase all three pointers to this program's batch slice
    pid_b = tl.program_id(2)
    A += pid_b * stride_ab
    B += pid_b * stride_bb
    O += pid_b * stride_ob

    pidx = tl.program_id(0)
    pidy = tl.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: walk GROUP_M tiles along M before advancing along N
        # (a swizzle, presumably for cache locality of B tiles — see tuner config)
        gridx = tl.num_programs(0)
        gridy = tl.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group along M may be partial: only gridx % GROUP_M tiles wide.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Edge masks are only materialized when the dimension is ragged.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    o_ptrs = O + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on

    num_iters = tl.cdiv(K, TILE_K)
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Combine the (optional) K-tail mask with the (optional) M/N edge
        # masks for this iteration's loads; None means an unmasked load.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            mask_k = offs_k < K  # offs_k advances by TILE_K each iteration
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        # Advance along K for the next iteration.
        offs_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

        o += tl.dot(a, b, allow_tf32=False)

    # Store the accumulated tile, masking ragged M/N edges if present.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

127 

128 

def bmm(A, B):
    """Batched matrix product of two 3-D tensors.

    Computes A @ B per batch element: A is (batch, M, K), B is (batch, K, N).
    A fresh output of shape (batch, M, N) is allocated with A's dtype and
    device and returned.
    """
    logger.debug("GEMS_CAMBRICON BMM")
    assert A.shape[0] == B.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def grid_fn(meta):
        # One program per (M-tile, N-tile, batch) triple.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](
            A,
            B,
            out,
            M,
            N,
            K,
            # Each operand is 3-D, so .stride() unpacks to exactly the
            # (batch, dim0, dim1) stride triple the kernel expects.
            *A.stride(),
            *B.stride(),
            *out.stride(),
        )
    return out

161 

162 

def bmm_out(A, B, out):
    """Batched matrix product written into a caller-provided tensor.

    Args:
        A: left operand, shape (batch, M, K).
        B: right operand, shape (batch, K, N).
        out: destination tensor, shape (batch, M, N); overwritten in place.

    Returns:
        `out`.

    Raises:
        AssertionError: if the batch, K, or output dimensions disagree.
    """
    logger.debug("GEMS_CAMBRICON BMM_OUT")
    assert A.shape[0] == B.shape[0] == out.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    # The kernel masks its stores against M/N taken from the *inputs*, so an
    # undersized `out` would be written out of bounds — validate it up front.
    assert out.shape[1] == M and out.shape[2] == N, "Output shape mismatch"

    grid_fn = lambda meta: (
        triton.cdiv(meta["M"], meta["TILE_M"]),
        triton.cdiv(meta["N"], meta["TILE_N"]),
        batch,
    )
    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](
            A,
            B,
            out,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
        )
    return out