Coverage for src/flag_gems/ops/kron.py: 62%
171 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
def prepare_tensor_for_kron(tensor_a, tensor_b):
    """Normalize two operands for a Kronecker product.

    Returns ``(a_reshaped, b_reshaped, out_shape)`` where both operands have
    been reshaped to the same rank (at least 2-D for non-empty inputs) and
    ``out_shape`` is the dimension-wise product of the padded shapes.
    """
    shape_a = list(tensor_a.shape)
    shape_b = list(tensor_b.shape)

    def _pad_to_common_rank(lhs, rhs):
        # Left-pad the shorter shape with 1s so both shapes share a rank.
        if len(lhs) > len(rhs):
            rhs = [1] * (len(lhs) - len(rhs)) + rhs
        elif len(rhs) > len(lhs):
            lhs = [1] * (len(rhs) - len(lhs)) + lhs
        return lhs, rhs

    if tensor_a.numel() == 0 or tensor_b.numel() == 0:
        # Empty operand: a 0-d shape becomes [0]; no 2-D minimum is imposed.
        shape_a = shape_a or [0]
        shape_b = shape_b or [0]
        shape_a, shape_b = _pad_to_common_rank(shape_a, shape_b)
        out_shape = tuple(x * y for x, y in zip(shape_a, shape_b))
        return tensor_a.reshape(*shape_a), tensor_b.reshape(*shape_b), out_shape

    # Non-empty operands are promoted to at least matrix (2-D) form first.
    if len(shape_a) < 2:
        shape_a = [1] * (2 - len(shape_a)) + shape_a
    if len(shape_b) < 2:
        shape_b = [1] * (2 - len(shape_b)) + shape_b

    shape_a, shape_b = _pad_to_common_rank(shape_a, shape_b)
    out_shape = tuple(x * y for x, y in zip(shape_a, shape_b))
    return tensor_a.reshape(*shape_a), tensor_b.reshape(*shape_b), out_shape
47def calculate_indices(batch_idx, shape_a, shape_b):
48 a_batch_dims = shape_a[:-2] or (1,)
49 b_batch_dims = shape_b[:-2] or (1,)
50 out_batch_dims = tuple(a * b for a, b in zip(a_batch_dims, b_batch_dims))
52 out_indices = []
53 remaining = batch_idx
54 for dim_size in out_batch_dims[::-1]:
55 out_indices.insert(0, remaining % dim_size)
56 remaining //= dim_size
58 a_idx = b_idx = 0
59 for out_idx, (a_dim, b_dim) in zip(out_indices, zip(a_batch_dims, b_batch_dims)):
60 a_idx = a_idx * a_dim + (out_idx // b_dim)
61 b_idx = b_idx * b_dim + (out_idx % b_dim)
63 return a_idx, b_idx
@triton.autotune(configs=runtime.get_tuned_config("kron"), key=["M", "N"])
@triton.jit
def kron_kernel_for_batch_size_1(
    a_ptr,
    b_ptr,
    c_ptr,
    batch_size: tl.int64,  # unused here; kept so the launch mirrors kron_kernel's argument order
    M: tl.int64,  # output rows, M = M1 * M2
    N: tl.int64,  # output cols, N = N1 * N2
    M1: tl.int64,  # rows of A (not used in the index math below)
    M2: tl.int64,  # rows of B
    N1: tl.int64,  # cols of A (not used in the index math below)
    N2: tl.int64,  # cols of B
    a_stride_0: tl.int64,
    a_stride_1: tl.int64,
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Single-batch Kronecker product tile: C[i, j] = A[i // M2, j // N2] * B[i % M2, j % N2]."""
    pid = tle.program_id(0)
    # The launch grid is 1-D; recover the (row-tile, col-tile) coordinates.
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n

    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n

    offs_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Keep partial tiles at the right/bottom edges in bounds.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    # Kronecker indexing: output (i, j) reads A at (i // M2, j // N2)
    # and B at (i % M2, j % N2).
    a_row = offs_m[:, None] // M2
    a_col = offs_n[None, :] // N2
    b_row = offs_m[:, None] % M2
    b_col = offs_n[None, :] % N2

    a_idx = a_row * a_stride_0 + a_col * a_stride_1
    b_idx = b_row * b_stride_0 + b_col * b_stride_1

    a = tl.load(a_ptr + a_idx, mask=mask)
    b = tl.load(b_ptr + b_idx, mask=mask)
    c = a * b

    c_idx = offs_m[:, None] * c_stride_0 + offs_n[None, :] * c_stride_1
    tl.store(c_ptr + c_idx, c, mask=mask)
@triton.autotune(configs=runtime.get_tuned_config("kron"), key=["M", "N"])
@triton.jit
def kron_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    map_ptr,  # int64 buffer of interleaved (a_batch_idx, b_batch_idx) pairs, one pair per output batch
    batch_size: tl.int64,
    M: tl.int64,  # output rows, M = M1 * M2
    N: tl.int64,  # output cols, N = N1 * N2
    M1: tl.int64,  # rows of A (not used in the index math below)
    M2: tl.int64,  # rows of B
    N1: tl.int64,  # cols of A (not used in the index math below)
    N2: tl.int64,  # cols of B
    a_stride_0: tl.int64,
    a_stride_1: tl.int64,
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    a_batch_stride: tl.int64,
    b_batch_stride: tl.int64,
    c_batch_stride: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batched Kronecker product tile; map_ptr translates output batches to operand batches."""
    pid = tle.program_id(0)
    # 1-D grid over batch * row-tiles * col-tiles; unpack the coordinates.
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n

    batch_id = pid // num_blocks_per_batch
    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n

    offs_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Guard partial tiles and any surplus programs past the last batch.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) & (batch_id < batch_size)

    # Look up which operand batches feed this output batch.
    offset = batch_id * 2
    is_valid = batch_id < batch_size
    a_batch_idx = tl.load(map_ptr + offset, mask=is_valid)
    b_batch_idx = tl.load(map_ptr + offset + 1, mask=is_valid)

    # Kronecker indexing within the matrix: output (i, j) reads
    # A at (i // M2, j // N2) and B at (i % M2, j % N2).
    a_row = offs_m[:, None] // M2
    a_col = offs_n[None, :] // N2
    b_row = offs_m[:, None] % M2
    b_col = offs_n[None, :] % N2

    a_idx = a_batch_idx * a_batch_stride + a_row * a_stride_0 + a_col * a_stride_1
    b_idx = b_batch_idx * b_batch_stride + b_row * b_stride_0 + b_col * b_stride_1

    a = tl.load(a_ptr + a_idx, mask=mask)
    b = tl.load(b_ptr + b_idx, mask=mask)
    c = a * b

    c_idx = (
        batch_id * c_batch_stride
        + offs_m[:, None] * c_stride_0
        + offs_n[None, :] * c_stride_1
    )
    tl.store(c_ptr + c_idx, c, mask=mask)
@triton.jit
def calculate_batch_indices_kernel(
    batch_indices_ptr,
    a_batch0: tl.int64,
    a_batch1: tl.int64,
    b_batch0: tl.int64,
    b_batch1: tl.int64,
    out_batch0: tl.int64,
    out_batch1: tl.int64,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill an interleaved (a_idx, b_idx) batch map for a two-batch-dim kron.

    ``batch_indices_ptr`` points at a buffer of ``out_batch0 * out_batch1``
    pairs: entry ``2*k`` is the flat A batch index for output batch ``k`` and
    entry ``2*k + 1`` the flat B batch index (device-side counterpart of
    ``calculate_indices``).
    """
    pid = tl.program_id(axis=0)

    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # BUGFIX: the grid is cdiv(batch, BLOCK_SIZE) programs, so the tail
    # block's offsets can run past the buffer; the stores below were
    # previously unmasked and wrote out of bounds whenever the batch count
    # was not a multiple of BLOCK_SIZE.
    valid = offset < out_batch0 * out_batch1

    # Decompose the flat output batch index into (out_indice0, out_indice1),
    # then split each coordinate between A (// b_batch) and B (% b_batch).
    out_indice1 = offset % out_batch1
    remaining = offset // out_batch1
    out_indice0 = remaining % out_batch0
    a_idx = out_indice0 // b_batch0
    a_idx = a_idx * a_batch1 + (out_indice1 // b_batch1)
    b_idx = out_indice0 % b_batch0
    b_idx = b_idx * b_batch1 + (out_indice1 % b_batch1)

    # Interleaved layout: [a_0, b_0, a_1, b_1, ...].
    a_store_offset = 2 * offset
    b_store_offset = 2 * offset + 1
    tl.store(batch_indices_ptr + a_store_offset, a_idx, mask=valid)
    tl.store(batch_indices_ptr + b_store_offset, b_idx, mask=valid)
def _kron_grid(batch_size, M, N):
    # 1-D launch grid: one program per (batch, M-tile, N-tile).
    return lambda meta: (
        batch_size
        * triton.cdiv(M, meta["BLOCK_M"])
        * triton.cdiv(N, meta["BLOCK_N"]),
    )


def _launch_batched_kron(
    A_view,
    B_view,
    C_reshaped,
    batch_indices,
    batch_size,
    M,
    N,
    M1,
    M2,
    N1,
    N2,
    a_batch_stride,
    b_batch_stride,
    c_batch_stride,
):
    # Single launch site for the batched kernel (previously duplicated for
    # the 4-D and the generic batched paths).
    kron_kernel[_kron_grid(batch_size, M, N)](
        A_view,
        B_view,
        C_reshaped,
        batch_indices,
        batch_size,
        M,
        N,
        M1,
        M2,
        N1,
        N2,
        A_view.stride(1),
        A_view.stride(2),
        B_view.stride(1),
        B_view.stride(2),
        C_reshaped.stride(1),
        C_reshaped.stride(2),
        a_batch_stride,
        b_batch_stride,
        c_batch_stride,
    )


def kron(A, B):
    """Kronecker product of A and B with torch.kron semantics.

    Scalars multiply directly; empty operands yield an empty result of the
    combined shape; otherwise operands are normalized to a common rank and
    the product is computed by a Triton kernel over (batch, M, N) tiles.
    """
    logger.debug("GEMS KRON")
    if A.dim() == 0 and B.dim() == 0:
        # Two scalars: plain multiply preserves the 0-d result.
        return A * B

    if A.numel() == 0 or B.numel() == 0:
        # Empty operand: only the combined shape is needed; nothing to launch.
        _, _, out_shape = prepare_tensor_for_kron(A, B)
        output_dtype = torch.promote_types(A.dtype, B.dtype)
        return torch.empty(out_shape, device=A.device, dtype=output_dtype)

    if A.dim() == 0:
        return A.unsqueeze(0) * B
    if B.dim() == 0:
        return A * B.unsqueeze(0)

    A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
    M1, N1 = A_prepared.shape[-2:]
    M2, N2 = B_prepared.shape[-2:]
    M, N = M1 * M2, N1 * N2
    batch_size = math.prod(out_shape[:-2]) if out_shape[:-2] else 1

    output_dtype = torch.promote_types(A.dtype, B.dtype)
    C = torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # Flatten every batch dim so the kernels see 3-D (batch, rows, cols).
    C_reshaped = C.view(-1, M, N)
    A_view = A_prepared.reshape(-1, M1, N1)
    B_view = B_prepared.reshape(-1, M2, N2)

    # The kernels use simple row/col strides plus a dense batch stride, so
    # the operand views must be contiguous.
    if not A_view.is_contiguous():
        A_view = A_view.contiguous()
    if not B_view.is_contiguous():
        B_view = B_view.contiguous()
    a_batch_stride = M1 * N1
    b_batch_stride = M2 * N2
    c_batch_stride = M * N

    if A_prepared.dim() == 4 and B_prepared.dim() == 4:
        # Two batch dims on both operands: build the (a_idx, b_idx) map on
        # device instead of looping on the host.
        batch_indices = torch.empty(batch_size * 2, device=A.device, dtype=torch.int64)
        a_batch0, a_batch1 = A_prepared.shape[:-2]
        b_batch0, b_batch1 = B_prepared.shape[:-2]
        out_batch0 = a_batch0 * b_batch0
        out_batch1 = a_batch1 * b_batch1
        indice_tile_size = 256
        grid_for_indice = (triton.cdiv(batch_size, indice_tile_size),)
        with torch_device_fn.device(A.device):
            calculate_batch_indices_kernel[grid_for_indice](
                batch_indices,
                a_batch0,
                a_batch1,
                b_batch0,
                b_batch1,
                out_batch0,
                out_batch1,
                BLOCK_SIZE=indice_tile_size,
            )
            _launch_batched_kron(
                A_view,
                B_view,
                C_reshaped,
                batch_indices,
                batch_size,
                M,
                N,
                M1,
                M2,
                N1,
                N2,
                a_batch_stride,
                b_batch_stride,
                c_batch_stride,
            )
    elif batch_size != 1:
        # Generic batched case: compute the index map on the host, then
        # transfer it in ONE copy. (Previously each entry was assigned into
        # the device tensor individually — 2 * batch_size tiny transfers.)
        pairs = []
        for i in range(batch_size):
            a_idx, b_idx = calculate_indices(i, A_prepared.shape, B_prepared.shape)
            pairs.append(a_idx)
            pairs.append(b_idx)
        batch_indices = torch.tensor(pairs, device=A.device, dtype=torch.int64)
        with torch_device_fn.device(A.device):
            _launch_batched_kron(
                A_view,
                B_view,
                C_reshaped,
                batch_indices,
                batch_size,
                M,
                N,
                M1,
                M2,
                N1,
                N2,
                a_batch_stride,
                b_batch_stride,
                c_batch_stride,
            )
    else:
        # Single batch: dedicated kernel, no index map needed.
        with torch_device_fn.device(A.device):
            kron_kernel_for_batch_size_1[_kron_grid(batch_size, M, N)](
                A_view,
                B_view,
                C_reshaped,
                batch_size,
                M,
                N,
                M1,
                M2,
                N1,
                N2,
                A_view.stride(1),
                A_view.stride(2),
                B_view.stride(1),
                B_view.stride(2),
                C_reshaped.stride(1),
                C_reshaped.stride(2),
            )

    if A.dim() <= 1 and B.dim() <= 1:
        # torch.kron of two vectors (or scalars promoted above) is a vector.
        return C.reshape(-1)

    return C