Coverage for src/flag_gems/ops/kron.py: 62%

171 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

def prepare_tensor_for_kron(tensor_a, tensor_b):
    """Normalize both operands for kron and compute the output shape.

    Pads both shapes (with leading 1s) to a common rank — and, for non-empty
    inputs, to at least rank 2 so the tiled kernel always sees a matrix —
    then returns the reshaped operands together with the elementwise product
    of their shapes, which is exactly ``torch.kron``'s output shape.

    Args:
        tensor_a: first operand (any rank, may be empty).
        tensor_b: second operand (any rank, may be empty).

    Returns:
        Tuple ``(a_reshaped, b_reshaped, out_shape)``.
    """

    def _pad_to_common_rank(shape_a, shape_b):
        # Left-pad the shorter shape with 1s so both ranks match.
        if len(shape_a) > len(shape_b):
            shape_b = [1] * (len(shape_a) - len(shape_b)) + shape_b
        elif len(shape_b) > len(shape_a):
            shape_a = [1] * (len(shape_b) - len(shape_a)) + shape_a
        return shape_a, shape_b

    a_shape = list(tensor_a.shape)
    b_shape = list(tensor_b.shape)

    if tensor_a.numel() == 0 or tensor_b.numel() == 0:
        # An empty shape here means that operand is a 0-dim scalar (the
        # *other* operand is the empty one). Pad it with 1, not 0: a scalar
        # has one element, so reshaping it to (0,) would raise, and kron
        # with a scalar preserves the other operand's shape.
        if not a_shape:
            a_shape = [1]
        if not b_shape:
            b_shape = [1]

        a_shape, b_shape = _pad_to_common_rank(a_shape, b_shape)
        out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
        return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape

    # Non-empty path: guarantee at least 2-D before rank matching.
    if len(a_shape) < 2:
        a_shape = [1] * (2 - len(a_shape)) + a_shape
    if len(b_shape) < 2:
        b_shape = [1] * (2 - len(b_shape)) + b_shape

    a_shape, b_shape = _pad_to_common_rank(a_shape, b_shape)
    out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
    return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape

45 

46 

def calculate_indices(batch_idx, shape_a, shape_b):
    """Map a flat output-batch index to the flat batch indices of A and B.

    The output batch dimensions are the elementwise products of A's and B's
    batch dimensions (everything before the trailing two matrix dims); a
    missing batch is treated as a single batch of size 1.
    """
    dims_a = shape_a[:-2] or (1,)
    dims_b = shape_b[:-2] or (1,)
    paired = tuple(zip(dims_a, dims_b))

    # Peel batch_idx into per-dimension output coordinates, innermost first,
    # then flip so coords line up with the (outer -> inner) dimension order.
    coords = []
    rest = batch_idx
    for da, db in reversed(paired):
        size = da * db
        coords.append(rest % size)
        rest //= size
    coords.reverse()

    # Fold the coordinates back into flat row-major indices for A and B:
    # within each output dim, the A coordinate is coord // db and the B
    # coordinate is coord % db.
    idx_a = 0
    idx_b = 0
    for coord, (da, db) in zip(coords, paired):
        idx_a = idx_a * da + coord // db
        idx_b = idx_b * db + coord % db

    return idx_a, idx_b

64 

65 

@triton.autotune(configs=runtime.get_tuned_config("kron"), key=["M", "N"])
@triton.jit
def kron_kernel_for_batch_size_1(
    a_ptr,  # A matrix, logically (M1, N1)
    b_ptr,  # B matrix, logically (M2, N2)
    c_ptr,  # output C, logically (M, N) = (M1*M2, N1*N2)
    batch_size: tl.int64,  # unused here; kept for launch-signature parity with kron_kernel
    M: tl.int64,  # output rows = M1 * M2
    N: tl.int64,  # output cols = N1 * N2
    M1: tl.int64,  # rows of A (unused by the index math below)
    M2: tl.int64,  # rows of B
    N1: tl.int64,  # cols of A (unused)
    N2: tl.int64,  # cols of B
    a_stride_0: tl.int64,  # row stride of A (in elements)
    a_stride_1: tl.int64,  # col stride of A
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # One program computes one BLOCK_M x BLOCK_N tile of the single output
    # matrix. C[m, n] = A[m // M2, n // N2] * B[m % M2, n % N2].
    pid = tle.program_id(0)
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n

    # With a single batch, local_pid == pid; the modulo keeps the tile
    # decomposition identical to kron_kernel's.
    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n

    offs_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # In-bounds lanes of this tile (tail tiles are partially masked).
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    # Kronecker index map: quotient selects the A element, remainder the B element.
    a_row = offs_m[:, None] // M2
    a_col = offs_n[None, :] // N2
    b_row = offs_m[:, None] % M2
    b_col = offs_n[None, :] % N2

    a_idx = a_row * a_stride_0 + a_col * a_stride_1
    b_idx = b_row * b_stride_0 + b_col * b_stride_1

    a = tl.load(a_ptr + a_idx, mask=mask)
    b = tl.load(b_ptr + b_idx, mask=mask)
    c = a * b

    c_idx = offs_m[:, None] * c_stride_0 + offs_n[None, :] * c_stride_1
    tl.store(c_ptr + c_idx, c, mask=mask)

116 

117 

@triton.autotune(configs=runtime.get_tuned_config("kron"), key=["M", "N"])
@triton.jit
def kron_kernel(
    a_ptr,  # A batches, each logically (M1, N1)
    b_ptr,  # B batches, each logically (M2, N2)
    c_ptr,  # output batches, each logically (M, N)
    map_ptr,  # int64 table of interleaved (a_batch_idx, b_batch_idx), one pair per output batch
    batch_size: tl.int64,  # number of output batches
    M: tl.int64,  # output rows per batch = M1 * M2
    N: tl.int64,  # output cols per batch = N1 * N2
    M1: tl.int64,  # rows of A (unused by the index math below)
    M2: tl.int64,  # rows of B
    N1: tl.int64,  # cols of A (unused)
    N2: tl.int64,  # cols of B
    a_stride_0: tl.int64,  # row stride of A within a batch (in elements)
    a_stride_1: tl.int64,
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    a_batch_stride: tl.int64,  # elements between consecutive A batches
    b_batch_stride: tl.int64,
    c_batch_stride: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # One program computes one BLOCK_M x BLOCK_N tile of one output batch.
    pid = tle.program_id(0)
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n

    # Decompose the linear program id into (batch, tile-row, tile-col).
    batch_id = pid // num_blocks_per_batch
    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n

    offs_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Mask covers both the tile tail and any out-of-range batch id.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) & (batch_id < batch_size)

    # Look up which A batch and which B batch feed this output batch.
    offset = batch_id * 2
    is_valid = batch_id < batch_size
    a_batch_idx = tl.load(map_ptr + offset, mask=is_valid)
    b_batch_idx = tl.load(map_ptr + offset + 1, mask=is_valid)

    # Kronecker index map within a batch:
    # C[m, n] = A[m // M2, n // N2] * B[m % M2, n % N2].
    a_row = offs_m[:, None] // M2
    a_col = offs_n[None, :] // N2
    b_row = offs_m[:, None] % M2
    b_col = offs_n[None, :] % N2

    a_idx = a_batch_idx * a_batch_stride + a_row * a_stride_0 + a_col * a_stride_1
    b_idx = b_batch_idx * b_batch_stride + b_row * b_stride_0 + b_col * b_stride_1

    a = tl.load(a_ptr + a_idx, mask=mask)
    b = tl.load(b_ptr + b_idx, mask=mask)
    c = a * b

    c_idx = (
        batch_id * c_batch_stride
        + offs_m[:, None] * c_stride_0
        + offs_n[None, :] * c_stride_1
    )
    tl.store(c_ptr + c_idx, c, mask=mask)

182 

183 

@triton.jit
def calculate_batch_indices_kernel(
    batch_indices_ptr,  # int64 output buffer of out_batch0 * out_batch1 interleaved (a_idx, b_idx) pairs
    a_batch0: tl.int64,  # A's outer batch dim (unused: a_idx never wraps on it)
    a_batch1: tl.int64,  # A's inner batch dim
    b_batch0: tl.int64,  # B's outer batch dim
    b_batch1: tl.int64,  # B's inner batch dim
    out_batch0: tl.int64,  # = a_batch0 * b_batch0
    out_batch1: tl.int64,  # = a_batch1 * b_batch1
    BLOCK_SIZE: tl.constexpr,
):
    # Each program converts BLOCK_SIZE flat output-batch indices into the
    # (A batch, B batch) index pair each one draws from, storing them
    # interleaved: [a0, b0, a1, b1, ...].
    pid = tl.program_id(axis=0)

    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Guard the tail block: the buffer holds exactly out_batch0 * out_batch1
    # pairs, so without this mask the last program writes out of bounds
    # whenever the batch count is not a multiple of BLOCK_SIZE.
    valid = offset < out_batch0 * out_batch1

    # Decompose the flat index into (outer, inner) output-batch coordinates.
    out_indice1 = offset % out_batch1
    remaining = offset // out_batch1
    out_indice0 = remaining % out_batch0

    # Per coordinate, quotient by B's dim selects A's index, remainder B's.
    a_idx = out_indice0 // b_batch0
    a_idx = a_idx * a_batch1 + (out_indice1 // b_batch1)
    b_idx = out_indice0 % b_batch0
    b_idx = b_idx * b_batch1 + (out_indice1 % b_batch1)

    a_store_offset = 2 * offset
    b_store_offset = 2 * offset + 1
    tl.store(batch_indices_ptr + a_store_offset, a_idx, mask=valid)
    tl.store(batch_indices_ptr + b_store_offset, b_idx, mask=valid)

211 

212 

def kron(A, B):
    """Kronecker product of A and B (``torch.kron`` equivalent) via Triton.

    Scalar operands reduce to plain multiplication; if either operand is
    empty the result is an empty tensor of the correct broadcast shape.
    Otherwise both operands are normalized to a common rank of at least 2,
    batch dimensions are flattened, and a tiled kernel fills the output.

    Args:
        A: first operand tensor.
        B: second operand tensor.

    Returns:
        Tensor of shape ``tuple(sa * sb for sa, sb in zip(A', B'))`` after
        rank normalization, with dtype ``torch.promote_types(A.dtype, B.dtype)``.
    """
    logger.debug("GEMS KRON")
    if A.dim() == 0 and B.dim() == 0:
        return A * B

    if A.numel() == 0 or B.numel() == 0:
        # No elements to compute; only the output shape matters.
        A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
        output_dtype = torch.promote_types(A.dtype, B.dtype)
        return torch.empty(out_shape, device=A.device, dtype=output_dtype)

    if A.dim() == 0:
        return A.unsqueeze(0) * B
    if B.dim() == 0:
        return A * B.unsqueeze(0)

    A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
    M1, N1 = A_prepared.shape[-2:]
    M2, N2 = B_prepared.shape[-2:]
    M, N = M1 * M2, N1 * N2
    batch_size = math.prod(out_shape[:-2]) if out_shape[:-2] else 1

    output_dtype = torch.promote_types(A.dtype, B.dtype)
    C = torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # Flatten all batch dims so the kernels see 3-D (batch, rows, cols).
    C_reshaped = C.view(-1, M, N)
    A_view = A_prepared.reshape(-1, M1, N1)
    B_view = B_prepared.reshape(-1, M2, N2)

    if not A_view.is_contiguous():
        A_view = A_view.contiguous()
    if not B_view.is_contiguous():
        B_view = B_view.contiguous()

    a_batch_stride = M1 * N1
    b_batch_stride = M2 * N2
    c_batch_stride = M * N

    # One program per (batch, M-tile, N-tile); shared by every launch below.
    grid = lambda meta: (
        batch_size
        * triton.cdiv(M, meta["BLOCK_M"])
        * triton.cdiv(N, meta["BLOCK_N"]),
    )

    def _launch_batched(batch_indices):
        # Launch the batched kernel with the (a_idx, b_idx) lookup table.
        kron_kernel[grid](
            A_view,
            B_view,
            C_reshaped,
            batch_indices,
            batch_size,
            M,
            N,
            M1,
            M2,
            N1,
            N2,
            A_view.stride(1),
            A_view.stride(2),
            B_view.stride(1),
            B_view.stride(2),
            C_reshaped.stride(1),
            C_reshaped.stride(2),
            a_batch_stride,
            b_batch_stride,
            c_batch_stride,
        )

    if A_prepared.dim() == 4 and B_prepared.dim() == 4:
        # Two batch dims: build the batch-index table on device.
        batch_indices = torch.empty(batch_size * 2, device=A.device, dtype=torch.int64)
        a_batch0, a_batch1 = A_prepared.shape[:-2]
        b_batch0, b_batch1 = B_prepared.shape[:-2]
        out_batch0 = a_batch0 * b_batch0
        out_batch1 = a_batch1 * b_batch1
        indice_tile_size = 256
        grid_for_indice = (triton.cdiv(batch_size, indice_tile_size),)
        # Keep both launches inside the device context (the original code
        # launched kron_kernel outside it, unlike every other path).
        with torch_device_fn.device(A.device):
            calculate_batch_indices_kernel[grid_for_indice](
                batch_indices,
                a_batch0,
                a_batch1,
                b_batch0,
                b_batch1,
                out_batch0,
                out_batch1,
                BLOCK_SIZE=indice_tile_size,
            )
            _launch_batched(batch_indices)
    elif batch_size != 1:
        # Build the (a_idx, b_idx) table on the host in one pass and copy it
        # to the device once, instead of batch_size individual element
        # assignments (each of which is a separate host->device write).
        pairs = [
            calculate_indices(i, A_prepared.shape, B_prepared.shape)
            for i in range(batch_size)
        ]
        batch_indices = (
            torch.tensor(pairs, dtype=torch.int64, device=A.device).reshape(-1)
        )
        with torch_device_fn.device(A.device):
            _launch_batched(batch_indices)
    else:
        with torch_device_fn.device(A.device):
            kron_kernel_for_batch_size_1[grid](
                A_view,
                B_view,
                C_reshaped,
                batch_size,
                M,
                N,
                M1,
                M2,
                N1,
                N2,
                A_view.stride(1),
                A_view.stride(2),
                B_view.stride(1),
                B_view.stride(2),
                C_reshaped.stride(1),
                C_reshaped.stride(2),
            )

    # Two vector (or lower) inputs yield a vector, matching torch.kron.
    if A.dim() <= 1 and B.dim() <= 1:
        return C.reshape(-1)

    return C