Coverage for src/flag_gems/runtime/backend/_metax/ops/bmm.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13logger = logging.getLogger("flag_gems." + __name__) 

14 

15 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.heuristics(
    {
        # Switch to 64-bit-safe program-id/num-programs helpers (tle.*) when
        # the approximate CTA count could exceed 32 bits.  The tile count is
        # ~ M * N * batch / (TILE_M * TILE_N).
        # FIX: the denominator was TILE_M * TILE_M (copy-paste typo), which
        # misestimates the count whenever TILE_N != TILE_M.
        "UPGRADE": lambda args: math.ceil(
            (args["M"] * args["N"] * args["batch"]) / (args["TILE_M"] * args["TILE_N"])
        ).bit_length()
        > 32,
    }
)
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    batch,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    UPGRADE: tl.constexpr,
):
    # Batched matmul: O[b] = A[b] @ B[b] for contiguous row-major operands
    # A:(batch, M, K), B:(batch, K, N), O:(batch, M, N).
    # Launched on a 3D grid: (cdiv(M, TILE_M), cdiv(N, TILE_N), batch).
    # DIVISIBLE_* flags let the compiler drop boundary masks when a dimension
    # is an exact multiple of its tile size.

    # batch offsets: advance base pointers to this batch's matrices
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    # tle.* variants are safe for grids whose linear index exceeds 32 bits
    if UPGRADE:
        pidx = tle.program_id(0)
        pidy = tle.program_id(1)
    else:
        pidx = tl.program_id(0)
        pidy = tl.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs into groups of GROUP_M rows to improve L2 reuse
        if UPGRADE:
            gridx = tle.num_programs(0)
            gridy = tle.num_programs(1)
        else:
            gridx = tl.num_programs(0)
            gridy = tl.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # the last group may be narrower than GROUP_M when gridx % GROUP_M != 0
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    # row-major addressing within this batch's matrices
    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # accumulate in fp32 regardless of input dtype for precision
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            # K-boundary mask must be recomputed each iteration (offs_k advances)
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        a = tl.load(a_ptrs, mask_a)
        b = tl.load(b_ptrs, mask_b)

        # advance along K
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    # output boundary mask
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

137 

138 

def bmm(A, B):
    """Batched matrix multiply: (batch, M, K) @ (batch, K, N) -> (batch, M, N).

    Both operands are made contiguous before launch, so the kernel can
    assume densely packed row-major matrices.  Returns a new tensor with
    A's dtype on A's device.
    """
    batch, M, K = A.shape
    _, _, N = B.shape
    # FIX: this message previously read "METAX GEMS ADDMM_OUT" (copy-paste
    # from another op) and duplicated a preceding plain "METAX GEMS BMM"
    # debug line; merged into a single, correctly-labelled message.
    # FIX: per-matrix column-major means unit stride along the row dimension
    # (dim 1), not along the batch dimension (dim 0).
    logger.debug(
        "METAX GEMS BMM, [shape info]: [%s, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        batch,
        M,
        N,
        K,
        A.stride(1) == 1,
        B.stride(1) == 1,
    )
    A = A.contiguous()
    B = B.contiguous()
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    # 3D grid: one CTA per (TILE_M x TILE_N) output tile per batch
    grid_fn = lambda meta: (
        triton.cdiv(meta["M"], meta["TILE_M"]),
        triton.cdiv(meta["N"], meta["TILE_N"]),
        batch,
    )
    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](A, B, out, M, N, K, batch)
    return out