Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/kron.py: 0%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

def prepare_tensor_for_kron(tensor_a, tensor_b):
    """Rank-align both operands for the Kronecker product.

    Left-pads the shorter shape with 1s so both operands have the same
    rank.  For non-empty inputs, both shapes are first promoted to at
    least rank 2 so the matrix kernel always sees (rows, cols).  Empty
    inputs skip the 2-D promotion: the caller short-circuits to an empty
    output and never launches the kernel.

    Args:
        tensor_a: first operand tensor.
        tensor_b: second operand tensor.

    Returns:
        (a_reshaped, b_reshaped, out_shape) where ``out_shape`` is the
        tuple of per-dimension products ``a_dim * b_dim``.
    """
    a_shape = list(tensor_a.shape)
    b_shape = list(tensor_b.shape)

    if tensor_a.numel() == 0 or tensor_b.numel() == 0:
        # A 0-dim tensor cannot be empty, but guard anyway: give a shapeless
        # operand an explicit zero-length dimension.
        if not a_shape:
            a_shape = [0]
        if not b_shape:
            b_shape = [0]
    else:
        # Non-empty inputs feed the matrix kernel: promote to at least 2-D.
        if len(a_shape) < 2:
            a_shape = [1] * (2 - len(a_shape)) + a_shape
        if len(b_shape) < 2:
            b_shape = [1] * (2 - len(b_shape)) + b_shape

    # Left-pad the shorter shape with 1s so both ranks match.
    # (Previously duplicated in both branches; factored into one path.)
    if len(a_shape) > len(b_shape):
        b_shape = [1] * (len(a_shape) - len(b_shape)) + b_shape
    elif len(b_shape) > len(a_shape):
        a_shape = [1] * (len(b_shape) - len(a_shape)) + a_shape

    out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
    return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape

43 

44 

def calculate_indices(batch_idx, shape_a, shape_b):
    """Map a flat output batch index to the flat (A, B) batch indices.

    Each leading (batch) dimension of the output has size
    ``a_dim * b_dim``; along every such dimension the output coordinate
    decomposes into an A coordinate (quotient by ``b_dim``) and a B
    coordinate (remainder).  The per-dimension coordinates are then
    re-flattened in row-major order.
    """
    dims_a = shape_a[:-2] or (1,)
    dims_b = shape_b[:-2] or (1,)

    # Split the flat index into per-dimension output coordinates,
    # working from the innermost (fastest-varying) dimension outward.
    coords = []
    rem = batch_idx
    for size_a, size_b in zip(reversed(dims_a), reversed(dims_b)):
        combined = size_a * size_b
        coords.append(rem % combined)
        rem //= combined
    coords.reverse()

    # Fold the quotient / remainder parts back into flat indices.
    flat_a = 0
    flat_b = 0
    for coord, size_a, size_b in zip(coords, dims_a, dims_b):
        flat_a = flat_a * size_a + coord // size_b
        flat_b = flat_b * size_b + coord % size_b

    return flat_a, flat_b

62 

63 

# Autotuning: one active config; the commented-out alternatives were tuning
# candidates and are kept for reference.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 512, "BLOCK_N": 512, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        ),
        # triton.Config(
        #     {"BLOCK_M": 512, "BLOCK_N": 2048, "BLOCK_TILE_M": 64, "BLOCK_TILE_N": 64}
        # ),
        # triton.Config(
        #     {"BLOCK_M": 512, "BLOCK_N": 8192, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        # ),
        # triton.Config(
        #     {"BLOCK_M": 512, "BLOCK_N": 32768, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        # ),
    ],
    key=[
        "M",
        "N",
    ],
    warmup=1,
    rep=1,
)
@triton.jit
# Batched Kronecker product: each program computes one BLOCK_M x BLOCK_N
# block of one output batch, using C[m, n] = A[m // M2, n // N2] * B[m % M2, n % N2].
def kron_kernel(
    a_ptr,  # A data: batches of (M1, N1) matrices
    b_ptr,  # B data: batches of (M2, N2) matrices
    c_ptr,  # output data: batches of (M, N) = (M1*M2, N1*N2) matrices
    map_ptr,  # int64 pairs (a_batch_idx, b_batch_idx), interleaved per output batch
    batch_size: tl.int64,
    M: tl.int64,
    N: tl.int64,
    M1: tl.int64,
    M2: tl.int64,
    N1: tl.int64,
    N2: tl.int64,
    a_stride_0: tl.int64,
    a_stride_1: tl.int64,
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    a_batch_stride: tl.int64,
    b_batch_stride: tl.int64,
    c_batch_stride: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_TILE_M: tl.constexpr,
    BLOCK_TILE_N: tl.constexpr,
):
    # Decompose the flat program id into (batch, block row, block column).
    pid = tl.program_id(0)
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n

    batch_id = pid // num_blocks_per_batch
    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n

    # Look up which A/B batch slices feed this output batch (the map stores
    # the pair interleaved at offsets 2*batch_id and 2*batch_id + 1).
    offset = batch_id * 2
    is_valid_batch = batch_id < batch_size
    a_batch_idx = tl.load(map_ptr + offset, mask=is_valid_batch)
    b_batch_idx = tl.load(map_ptr + offset + 1, mask=is_valid_batch)

    # The BLOCK_M x BLOCK_N block is swept in BLOCK_TILE_M x BLOCK_TILE_N tiles.
    num_tiles_m = tl.cdiv(BLOCK_M, BLOCK_TILE_M)
    num_tiles_n = tl.cdiv(BLOCK_N, BLOCK_TILE_N)

    for tile_m in range(num_tiles_m):
        for tile_n in range(num_tiles_n):
            tile_offset_m = tile_m * BLOCK_TILE_M
            tile_offset_n = tile_n * BLOCK_TILE_N

            # Global output coordinates covered by this tile.
            current_offs_m = (
                block_m * BLOCK_M + tile_offset_m + tl.arange(0, BLOCK_TILE_M)
            )
            current_offs_n = (
                block_n * BLOCK_N + tile_offset_n + tl.arange(0, BLOCK_TILE_N)
            )

            # Mask off out-of-range rows/columns and out-of-range programs.
            tile_mask = (
                (current_offs_m[:, None] < M)
                & (current_offs_n[None, :] < N)
                & is_valid_batch
            )

            # Kronecker index split: quotient addresses A, remainder addresses B.
            a_row = current_offs_m[:, None] // M2
            a_col = current_offs_n[None, :] // N2
            b_row = current_offs_m[:, None] % M2
            b_col = current_offs_n[None, :] % N2

            a_idx = (
                a_batch_idx * a_batch_stride + a_row * a_stride_0 + a_col * a_stride_1
            )
            b_idx = (
                b_batch_idx * b_batch_stride + b_row * b_stride_0 + b_col * b_stride_1
            )

            a = tl.load(a_ptr + a_idx, mask=tile_mask)
            b = tl.load(b_ptr + b_idx, mask=tile_mask)
            c = a * b

            c_idx = (
                batch_id * c_batch_stride
                + current_offs_m[:, None] * c_stride_0
                + current_offs_n[None, :] * c_stride_1
            )

            tl.store(c_ptr + c_idx, c, mask=tile_mask)

172 

173 

def kron(A, B):
    """Kronecker product of two tensors, matching ``torch.kron`` semantics.

    Scalar and empty operands are resolved on the host; everything else is
    rank-aligned by ``prepare_tensor_for_kron``, flattened to batched
    matrices, and computed by ``kron_kernel``.

    Args:
        A: first operand tensor.
        B: second operand tensor.

    Returns:
        Tensor whose shape is the per-dimension product of the aligned
        operand shapes, with dtype ``torch.promote_types(A.dtype, B.dtype)``.
    """
    logger.debug("GEMS_TSINGMICRO KRON")
    if A.dim() == 0 and B.dim() == 0:
        # scalar x scalar is a plain product
        return A * B

    if A.numel() == 0 or B.numel() == 0:
        # Any empty operand yields an empty result of the aligned shape;
        # no kernel launch is needed.
        _, _, out_shape = prepare_tensor_for_kron(A, B)
        output_dtype = torch.promote_types(A.dtype, B.dtype)
        return torch.empty(out_shape, device=A.device, dtype=output_dtype)

    if A.dim() == 0:
        # scalar x tensor reduces to a broadcast multiplication
        return A.unsqueeze(0) * B
    if B.dim() == 0:
        return A * B.unsqueeze(0)

    A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
    M1, N1 = A_prepared.shape[-2:]
    M2, N2 = B_prepared.shape[-2:]
    M, N = M1 * M2, N1 * N2

    # Number of leading (batch) matrices in the output.
    batch_size = math.prod(out_shape[:-2]) if out_shape[:-2] else 1

    output_dtype = torch.promote_types(A.dtype, B.dtype)
    C = torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # Flatten all batch dims so the kernel sees (batch, rows, cols).
    C_reshaped = C.view(-1, M, N)
    A_view = A_prepared.reshape(-1, M1, N1)
    B_view = B_prepared.reshape(-1, M2, N2)

    if not A_view.is_contiguous():
        A_view = A_view.contiguous()
    if not B_view.is_contiguous():
        B_view = B_view.contiguous()

    # Build the per-batch (a_idx, b_idx) map on the host and transfer it in
    # a single copy.  The previous code assigned each scalar directly into a
    # device tensor, issuing one synchronous host-to-device write per batch
    # element.
    index_pairs = []
    for i in range(batch_size):
        a_idx, b_idx = calculate_indices(i, A_prepared.shape, B_prepared.shape)
        index_pairs.append(a_idx)
        index_pairs.append(b_idx)
    batch_indices = torch.tensor(index_pairs, device=A.device, dtype=torch.int64)

    a_batch_stride = M1 * N1
    b_batch_stride = M2 * N2
    c_batch_stride = M * N

    with torch_device_fn.device(A.device):
        # One program per (batch, BLOCK_M, BLOCK_N) output block.
        grid = lambda meta: (
            batch_size
            * triton.cdiv(M, meta["BLOCK_M"])
            * triton.cdiv(N, meta["BLOCK_N"]),
        )

        kron_kernel[grid](
            A_view,
            B_view,
            C_reshaped,
            batch_indices,
            batch_size,
            M,
            N,
            M1,
            M2,
            N1,
            N2,
            A_view.stride(1),
            A_view.stride(2),
            B_view.stride(1),
            B_view.stride(2),
            C_reshaped.stride(1),
            C_reshaped.stride(2),
            a_batch_stride,
            b_batch_stride,
            c_batch_stride,
        )

    # 1-D x 1-D must return a 1-D result (the kernel computed a 1 x N matrix).
    if A.dim() <= 1 and B.dim() <= 1:
        return C.reshape(-1)

    return C