Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/kron.py: 0%
120 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
def prepare_tensor_for_kron(tensor_a, tensor_b):
    """Normalize two operands for a Kronecker product.

    Pads both shapes (with leading 1s) to a common rank — and, for the
    non-empty case, to at least rank 2 so the kernel can treat the last two
    dims as a matrix — then computes the elementwise-product output shape.

    Returns:
        (a_reshaped, b_reshaped, out_shape): the two inputs viewed at the
        padded ranks and the resulting output shape tuple.
    """
    a_shape = list(tensor_a.shape)
    b_shape = list(tensor_b.shape)

    if tensor_a.numel() == 0 or tensor_b.numel() == 0:
        # Empty-result path: only the output shape matters to the caller.
        # A 0-dim tensor always holds exactly one element, so its shape must
        # be padded with [1] — the previous [0] made the reshape below fail
        # (cannot view a 1-element tensor as shape (0,)).
        if not a_shape:
            a_shape = [1]
        if not b_shape:
            b_shape = [1]
    else:
        # Non-empty path: guarantee at least 2 dims (matrix semantics).
        if len(a_shape) < 2:
            a_shape = [1] * (2 - len(a_shape)) + a_shape
        if len(b_shape) < 2:
            b_shape = [1] * (2 - len(b_shape)) + b_shape

    # Left-pad the shorter shape so both operands have the same rank.
    if len(a_shape) > len(b_shape):
        b_shape = [1] * (len(a_shape) - len(b_shape)) + b_shape
    elif len(b_shape) > len(a_shape):
        a_shape = [1] * (len(b_shape) - len(a_shape)) + a_shape

    # Kron output size is the per-dimension product of the two shapes.
    out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
    return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape
def calculate_indices(batch_idx, shape_a, shape_b):
    """Map a flat output-batch index to the flat batch indices of A and B.

    Treats all dims except the trailing two as batch dims (defaulting to a
    single batch of size 1). Each output batch dim of size a_dim * b_dim
    splits into an A coordinate (quotient by b_dim) and a B coordinate
    (remainder), which are then re-flattened in row-major order.
    """
    a_dims = tuple(shape_a[:-2]) or (1,)
    b_dims = tuple(shape_b[:-2]) or (1,)
    out_dims = [a * b for a, b in zip(a_dims, b_dims)]

    # Decompose the flat index into a multi-index over out_dims (row-major).
    coords = []
    remainder = batch_idx
    for size in reversed(out_dims):
        remainder, coord = divmod(remainder, size)
        coords.append(coord)
    coords.reverse()

    # Re-flatten the per-dimension A/B coordinates into flat batch indices.
    a_idx, b_idx = 0, 0
    for coord, a_dim, b_dim in zip(coords, a_dims, b_dims):
        a_idx = a_idx * a_dim + coord // b_dim
        b_idx = b_idx * b_dim + coord % b_dim
    return a_idx, b_idx
# Tiled Kronecker-product kernel. Each program instance owns one
# (batch, BLOCK_M x BLOCK_N) region of the output and iterates over it in
# BLOCK_TILE_M x BLOCK_TILE_N sub-tiles. For an output element (m, n):
#   A element = (m // M2, n // N2), B element = (m % M2, n % N2),
# i.e. the standard kron index decomposition with B's matrix (M2, N2)
# replicated across A's entries.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 512, "BLOCK_N": 512, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        ),
        # triton.Config(
        #     {"BLOCK_M": 512, "BLOCK_N": 2048, "BLOCK_TILE_M": 64, "BLOCK_TILE_N": 64}
        # ),
        # triton.Config(
        #     {"BLOCK_M": 512, "BLOCK_N": 8192, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        # ),
        # triton.Config(
        #     {"BLOCK_M": 512, "BLOCK_N": 32768, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        # ),
    ],
    key=[
        "M",
        "N",
    ],
    warmup=1,
    rep=1,
)
@triton.jit
def kron_kernel(
    a_ptr,  # A viewed as (batch, M1, N1)
    b_ptr,  # B viewed as (batch, M2, N2)
    c_ptr,  # output viewed as (batch, M, N)
    map_ptr,  # int64 pairs: map_ptr[2*i] = A batch idx, map_ptr[2*i+1] = B batch idx
    batch_size: tl.int64,
    M: tl.int64,  # output rows = M1 * M2
    N: tl.int64,  # output cols = N1 * N2
    M1: tl.int64,
    M2: tl.int64,
    N1: tl.int64,
    N2: tl.int64,
    a_stride_0: tl.int64,
    a_stride_1: tl.int64,
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    a_batch_stride: tl.int64,
    b_batch_stride: tl.int64,
    c_batch_stride: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_TILE_M: tl.constexpr,
    BLOCK_TILE_N: tl.constexpr,
):
    # Decode the 1-D launch grid into (batch, block-row, block-col).
    pid = tl.program_id(0)
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n
    batch_id = pid // num_blocks_per_batch
    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n

    # Look up which A/B batch slices feed this output batch; the mask also
    # guards any program ids past the last batch.
    offset = batch_id * 2
    is_valid_batch = batch_id < batch_size
    a_batch_idx = tl.load(map_ptr + offset, mask=is_valid_batch)
    b_batch_idx = tl.load(map_ptr + offset + 1, mask=is_valid_batch)

    # Walk the BLOCK_M x BLOCK_N region in smaller register-sized tiles.
    num_tiles_m = tl.cdiv(BLOCK_M, BLOCK_TILE_M)
    num_tiles_n = tl.cdiv(BLOCK_N, BLOCK_TILE_N)

    for tile_m in range(num_tiles_m):
        for tile_n in range(num_tiles_n):
            tile_offset_m = tile_m * BLOCK_TILE_M
            tile_offset_n = tile_n * BLOCK_TILE_N

            # Global output coordinates covered by this sub-tile.
            current_offs_m = (
                block_m * BLOCK_M + tile_offset_m + tl.arange(0, BLOCK_TILE_M)
            )
            current_offs_n = (
                block_n * BLOCK_N + tile_offset_n + tl.arange(0, BLOCK_TILE_N)
            )

            # Mask off out-of-range rows/cols and invalid batches.
            tile_mask = (
                (current_offs_m[:, None] < M)
                & (current_offs_n[None, :] < N)
                & is_valid_batch
            )

            # kron decomposition: quotient indexes A, remainder indexes B.
            a_row = current_offs_m[:, None] // M2
            a_col = current_offs_n[None, :] // N2
            b_row = current_offs_m[:, None] % M2
            b_col = current_offs_n[None, :] % N2

            a_idx = (
                a_batch_idx * a_batch_stride + a_row * a_stride_0 + a_col * a_stride_1
            )
            b_idx = (
                b_batch_idx * b_batch_stride + b_row * b_stride_0 + b_col * b_stride_1
            )

            # C[m, n] = A[m // M2, n // N2] * B[m % M2, n % N2]
            a = tl.load(a_ptr + a_idx, mask=tile_mask)
            b = tl.load(b_ptr + b_idx, mask=tile_mask)
            c = a * b

            c_idx = (
                batch_id * c_batch_stride
                + current_offs_m[:, None] * c_stride_0
                + current_offs_n[None, :] * c_stride_1
            )

            tl.store(c_ptr + c_idx, c, mask=tile_mask)
def kron(A, B):
    """Kronecker product of A and B, mirroring ``torch.kron`` semantics.

    Scalar and empty operands are resolved on the host; all other cases are
    normalized to batched matrices and dispatched to the Triton kernel.
    The result dtype is ``torch.promote_types(A.dtype, B.dtype)``.
    """
    logger.debug("GEMS_TSINGMICRO KRON")
    # Two 0-dim tensors: kron degenerates to a plain multiply.
    if A.dim() == 0 and B.dim() == 0:
        return A * B

    # Any empty operand yields an empty result; only the shape is needed.
    if A.numel() == 0 or B.numel() == 0:
        A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
        output_dtype = torch.promote_types(A.dtype, B.dtype)
        return torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # One scalar operand: kron is elementwise scaling of the other operand.
    if A.dim() == 0:
        return A.unsqueeze(0) * B
    if B.dim() == 0:
        return A * B.unsqueeze(0)

    # Normalize both operands to a common rank of at least 2; the trailing
    # two dims are the matrix dims, everything before them is batch.
    A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
    M1, N1 = A_prepared.shape[-2:]
    M2, N2 = B_prepared.shape[-2:]
    M, N = M1 * M2, N1 * N2

    batch_size = math.prod(out_shape[:-2]) if out_shape[:-2] else 1

    output_dtype = torch.promote_types(A.dtype, B.dtype)
    C = torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # Flatten all batch dims so the kernel sees (batch, rows, cols).
    C_reshaped = C.view(-1, M, N)
    A_view = A_prepared.reshape(-1, M1, N1)
    B_view = B_prepared.reshape(-1, M2, N2)

    # The kernel indexes with row/col strides only, so inputs must be dense.
    if not A_view.is_contiguous():
        A_view = A_view.contiguous()
    if not B_view.is_contiguous():
        B_view = B_view.contiguous()

    # Interleaved (a_idx, b_idx) pairs: for output batch i, which A batch and
    # B batch slices to combine.
    # NOTE(review): this loop assigns device-tensor elements one by one from
    # host ints — likely a host/device round-trip per element; consider
    # building on CPU and copying once. Confirm before changing.
    batch_indices = torch.empty(batch_size * 2, device=A.device, dtype=torch.int64)
    for i in range(batch_size):
        a_idx, b_idx = calculate_indices(i, A_prepared.shape, B_prepared.shape)
        batch_indices[i * 2] = a_idx
        batch_indices[i * 2 + 1] = b_idx

    # Per-batch element counts of the (contiguous) flattened views.
    a_batch_stride = M1 * N1
    b_batch_stride = M2 * N2
    c_batch_stride = M * N

    with torch_device_fn.device(A.device):
        # One program per (batch, BLOCK_M x BLOCK_N) output region.
        grid = lambda meta: (
            batch_size
            * triton.cdiv(M, meta["BLOCK_M"])
            * triton.cdiv(N, meta["BLOCK_N"]),
        )

        kron_kernel[grid](
            A_view,
            B_view,
            C_reshaped,
            batch_indices,
            batch_size,
            M,
            N,
            M1,
            M2,
            N1,
            N2,
            A_view.stride(1),
            A_view.stride(2),
            B_view.stride(1),
            B_view.stride(2),
            C_reshaped.stride(1),
            C_reshaped.stride(2),
            a_batch_stride,
            b_batch_stride,
            c_batch_stride,
        )

    # torch.kron of two vectors (or lower) is a vector: drop the padded dims.
    if A.dim() <= 1 and B.dim() <= 1:
        return C.reshape(-1)

    return C