Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/bmm.py: 0%

102 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

13 

14 

def heur_group_m(args):
    """Pick the M-grouping factor for CTA swizzling.

    Tall tiles (TILE_M > TILE_N) get no grouping; otherwise group across
    the full count of M-tiles.
    """
    tile_m = args["TILE_M"]
    if tile_m > args["TILE_N"]:
        return 1
    # ceil(M / TILE_M) via negative floor division
    return -(-args["M"] // tile_m)

20 

21 

def heur_divisible_m(args):
    """Return True when M splits evenly into TILE_M-sized tiles."""
    return not args["M"] % args["TILE_M"]

24 

25 

def heur_divisible_n(args):
    """Return True when N splits evenly into TILE_N-sized tiles."""
    return not args["N"] % args["TILE_N"]

28 

29 

def heur_divisible_k(args):
    """Return True when K splits evenly into TILE_K-sized tiles."""
    return not args["K"] % args["TILE_K"]

32 

33 

@libentry()
@triton.autotune(
    configs=[],
    # NOTE(review): `generate_configs` is not a stock triton.autotune argument —
    # presumably a kunlunxin-backend extension that synthesizes the config list
    # for "bmm"; confirm against the backend's autotuner.
    generate_configs="bmm",
    key=["M", "N", "K"],
)
@triton.heuristics(
    {
        "GROUP_M": heur_group_m,
        "DIVISIBLE_M": heur_divisible_m,
        "DIVISIBLE_N": heur_divisible_n,
        "DIVISIBLE_K": heur_divisible_k,
    }
)
@triton.jit
def bmm_kernel(
    A,  # pointer to (batch, M, K) input, assumed dense row-major
    B,  # pointer to (batch, K, N) input, assumed dense row-major
    O,  # pointer to (batch, M, N) output, assumed dense row-major
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    # Batched matmul: each program computes one (TILE_M, TILE_N) output tile
    # of one batch. Grid axis 2 indexes the batch; advance the base pointers
    # by one batch's worth of elements (contiguous layout is assumed — the
    # offsets are M*K / K*N / M*N, not runtime strides).
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: swizzle the 2D launch grid so that GROUP_M
        # consecutive M-tiles are visited before moving along N, improving
        # L2 reuse of the B tiles.
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group may be ragged: it only holds gridx % GROUP_M rows.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    # Element offsets of this tile within the current batch's matrices.
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Boundary masks are only materialized when the dimension is ragged;
    # the DIVISIBLE_* constexprs let Triton compile them away otherwise.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    # Row-major addressing: A is (M, K), B is (K, N), O is (M, N).
    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    # March over K in TILE_K steps, accumulating in fp32.
    num_iters = tl.cdiv(K, TILE_K)
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            # K needs no masking; only the M/N edges can be ragged.
            if DIVISIBLE_M:
                mask_a = tl.full([TILE_M, TILE_K], value=1, dtype=tl.int1)
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = tl.full([TILE_K, TILE_N], value=1, dtype=tl.int1)
            else:
                mask_b = mask_n[None, :]
        else:
            # Ragged K: recompute the K mask each step, then advance offs_k.
            mask_k = offs_k < K
            offs_k += TILE_K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        # Advance along K: A moves TILE_K columns, B moves TILE_K rows.
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        # allow_tf32=False keeps full fp32 precision in the accumulate.
        o += tl.dot(a, b, allow_tf32=False)

    # Store the tile, masking whichever output edges are ragged.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = tl.full([TILE_M, TILE_N], value=1, dtype=tl.int1)
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

145 

146 

def bmm(A, B):
    """Batched matrix multiply: (batch, M, K) @ (batch, K, N) -> (batch, M, N).

    Allocates and returns a fresh output tensor with A's dtype and device.
    Raises AssertionError on mismatched batch or inner (K) dimensions.
    """
    logger.debug("GEMS BMM")
    # Validate shapes up front, matching bmm_out's checks; previously a
    # mismatch surfaced only as a garbled result or a downstream kernel error.
    assert A.shape[0] == B.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    # The kernel indexes with dense row-major offsets (no runtime strides),
    # so both inputs must be contiguous.
    A = A.contiguous()
    B = B.contiguous()
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    # Tile counts depend on the autotuned TILE_M/TILE_N, so the grid is a
    # function of the kernel meta-parameters; axis 2 enumerates batches.
    grid_fn = lambda meta: (
        triton.cdiv(meta["M"], meta["TILE_M"]),
        triton.cdiv(meta["N"], meta["TILE_N"]),
        batch,
    )
    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](A, B, out, M, N, K)
    return out

163 

164 

def bmm_out(A, B, out):
    """Batched matrix multiply into a preallocated output tensor.

    Computes (batch, M, K) @ (batch, K, N) and writes the result into
    ``out`` of shape (batch, M, N), which is also returned.
    Raises AssertionError on mismatched batch, K, or output dimensions.
    """
    logger.debug("GEMS BMM_OUT")
    assert A.shape[0] == B.shape[0] == out.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    # Also validate the output's M/N extents, since the kernel writes
    # out[b, m, n] assuming exactly (batch, M, N).
    assert out.shape[1] == A.shape[1] and out.shape[2] == B.shape[2], (
        "Output shape mismatch"
    )
    batch, M, K = A.shape
    _, _, N = B.shape
    # Fix: the kernel assumes dense row-major layout; bmm() makes its inputs
    # contiguous but this path previously skipped that, yielding wrong
    # results for non-contiguous A or B.
    A = A.contiguous()
    B = B.contiguous()
    # NOTE(review): `out` must itself be contiguous for the stores to land in
    # the caller's tensor — calling .contiguous() here could write to a copy.
    # Confirm whether callers can pass non-contiguous outputs.

    grid_fn = lambda meta: (
        triton.cdiv(meta["M"], meta["TILE_M"]),
        triton.cdiv(meta["N"], meta["TILE_N"]),
        batch,
    )
    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](A, B, out, M, N, K)
    return out