# Coverage report artifact for src/flag_gems/runtime/backend/_kunlunxin/ops/kron.py: 0%
# 117 statements (coverage.py v7.6.9, created at 2026-03-16 02:02 +0800)
1import math
3import torch
4import triton
5import triton.language as tl
7# from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
def _pad_shape_left(shape, rank):
    """Left-pad *shape* with 1s until it has at least *rank* dims."""
    return [1] * (rank - len(shape)) + shape


def prepare_tensor_for_kron(tensor_a, tensor_b):
    """Reshape both operands to a common rank and compute the Kronecker
    output shape.

    Non-empty inputs are padded (with leading length-1 dims) to at least
    rank 2 — the kernel works on batched matrices — and then to a common
    rank.  Empty inputs skip the rank-2 padding and are only aligned to a
    common rank, so the output keeps the original dimensionality.

    Returns:
        (reshaped_a, reshaped_b, out_shape) where out_shape is the
        elementwise product of the two padded shapes.
    """
    a_shape = list(tensor_a.shape)
    b_shape = list(tensor_b.shape)

    if tensor_a.numel() == 0 or tensor_b.numel() == 0:
        # Give 0-d operands one explicit dim.  A 0-d tensor always holds
        # exactly one element, so its shape must become [1], not [0]:
        # reshape(0) on a 1-element tensor raises.  (The original used
        # [0] here, which crashed on kron(scalar, empty); out_shape is
        # unchanged since the other operand contributes the 0.)
        if not a_shape:
            a_shape = [1]
        if not b_shape:
            b_shape = [1]

        rank = max(len(a_shape), len(b_shape))
        a_shape = _pad_shape_left(a_shape, rank)
        b_shape = _pad_shape_left(b_shape, rank)

        out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
        return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape

    # Non-empty path: guarantee at least matrix (rank-2) shape, then align.
    a_shape = _pad_shape_left(a_shape, 2)
    b_shape = _pad_shape_left(b_shape, 2)

    rank = max(len(a_shape), len(b_shape))
    a_shape = _pad_shape_left(a_shape, rank)
    b_shape = _pad_shape_left(b_shape, rank)

    out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
    return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape
def calculate_indices(batch_idx, shape_a, shape_b):
    """Map a flattened output batch index to the flattened batch indices
    of the two operands.

    Batch dims of the output are the elementwise products of the operands'
    batch dims; within each output dim, A's index advances once per full
    cycle of B's size (row-major Kronecker layout).
    """
    dims_a = shape_a[:-2] or (1,)
    dims_b = shape_b[:-2] or (1,)
    dim_pairs = list(zip(dims_a, dims_b))

    # Unravel batch_idx over the output batch dims (product of each pair),
    # least-significant dim first, then restore original order.
    digits = []
    rem = batch_idx
    for da, db in reversed(dim_pairs):
        rem, digit = divmod(rem, da * db)
        digits.append(digit)
    digits.reverse()

    # Re-ravel the per-dim digits into flat indices for A and B.
    flat_a = 0
    flat_b = 0
    for digit, (da, db) in zip(digits, dim_pairs):
        q, r = divmod(digit, db)
        flat_a = flat_a * da + q
        flat_b = flat_b * db + r

    return flat_a, flat_b
def heur_block_n(args):
    """Heuristic for BLOCK_N: use the full column count, capped at 8192."""
    n_cols = args["N"]
    return n_cols if n_cols < 8192 else 8192
def heur_block_m(args):
    """Heuristic for BLOCK_M: split M into roughly 12 chunks, rounded up
    to the next power of two (as required for tl.arange block sizes)."""
    rows_per_block = triton.cdiv(args["M"], 12)
    return triton.next_power_of_2(rows_per_block)
# @triton.autotune(configs=runtime.get_tuned_config("kron"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def kron_kernel(
    a_ptr,  # A, viewed as (num_a_batches, M1, N1)
    b_ptr,  # B, viewed as (num_b_batches, M2, N2)
    c_ptr,  # output C, viewed as (batch_size, M, N)
    map_ptr,  # int64 pairs: (a_batch_idx, b_batch_idx) per output batch
    batch_size: tl.int64,
    M: tl.int64,  # output rows = M1 * M2
    N: tl.int64,  # output cols = N1 * N2
    M1: tl.int64,
    M2: tl.int64,
    N1: tl.int64,
    N2: tl.int64,
    a_stride_0: tl.int64,
    a_stride_1: tl.int64,
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    a_batch_stride: tl.int64,
    b_batch_stride: tl.int64,
    c_batch_stride: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Compute one (BLOCK_M, BLOCK_N) tile of a batched Kronecker product.

    For output element (i, j) of batch b:
        C[b, i, j] = A[map[2b], i // M2, j // N2] * B[map[2b+1], i % M2, j % N2]

    The 1-D grid is laid out as batch-major: batch, then M-tiles, then
    N-tiles.
    """
    pid = tle.program_id(0)
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n
    # Decompose the flat program id into (batch, tile-row, tile-col).
    batch_id = pid // num_blocks_per_batch
    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n
    offs_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Guard both the tile edges and any trailing out-of-range batch ids.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) & (batch_id < batch_size)
    # Each output batch maps to one A batch and one B batch via map_ptr.
    offset = batch_id * 2
    is_valid = batch_id < batch_size
    a_batch_idx = tl.load(map_ptr + offset, mask=is_valid)
    b_batch_idx = tl.load(map_ptr + offset + 1, mask=is_valid)
    # Kronecker index split: A advances every M2 (resp. N2) output steps,
    # B cycles through its own rows/cols within each A element.
    a_row = offs_m[:, None] // M2
    a_col = offs_n[None, :] // N2
    b_row = offs_m[:, None] % M2
    b_col = offs_n[None, :] % N2
    a_idx = a_batch_idx * a_batch_stride + a_row * a_stride_0 + a_col * a_stride_1
    b_idx = b_batch_idx * b_batch_stride + b_row * b_stride_0 + b_col * b_stride_1
    a = tl.load(a_ptr + a_idx, mask=mask)
    b = tl.load(b_ptr + b_idx, mask=mask)
    c = a * b
    c_idx = (
        batch_id * c_batch_stride
        + offs_m[:, None] * c_stride_0
        + offs_n[None, :] * c_stride_1
    )
    tl.store(c_ptr + c_idx, c, mask=mask)
def kron(A, B):
    """Kronecker product of two tensors (torch.kron equivalent).

    Scalar and empty operands are resolved on the host; everything else
    is padded to batched matrices via ``prepare_tensor_for_kron`` and
    dispatched to ``kron_kernel``.  The result dtype follows
    ``torch.promote_types(A.dtype, B.dtype)``.
    """
    if A.dim() == 0 and B.dim() == 0:
        return A * B

    if A.numel() == 0 or B.numel() == 0:
        # Empty result: only the output shape is needed, no kernel launch.
        _, _, out_shape = prepare_tensor_for_kron(A, B)
        output_dtype = torch.promote_types(A.dtype, B.dtype)
        return torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # A scalar operand reduces to a plain broadcast multiply.
    if A.dim() == 0:
        return A.unsqueeze(0) * B
    if B.dim() == 0:
        return A * B.unsqueeze(0)

    A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
    M1, N1 = A_prepared.shape[-2:]
    M2, N2 = B_prepared.shape[-2:]
    M, N = M1 * M2, N1 * N2

    batch_size = math.prod(out_shape[:-2]) if out_shape[:-2] else 1

    output_dtype = torch.promote_types(A.dtype, B.dtype)
    C = torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # Collapse all batch dims so the kernel sees 3-D (batch, rows, cols).
    C_reshaped = C.view(-1, M, N)
    A_view = A_prepared.reshape(-1, M1, N1)
    B_view = B_prepared.reshape(-1, M2, N2)

    # The kernel indexes with simple strides; make inputs contiguous.
    if not A_view.is_contiguous():
        A_view = A_view.contiguous()
    if not B_view.is_contiguous():
        B_view = B_view.contiguous()

    # Precompute every (a_batch, b_batch) pair on the host and upload in
    # one transfer.  (The original wrote each scalar into a device tensor
    # individually, costing a host-device round trip per element.)
    index_pairs = [
        idx
        for i in range(batch_size)
        for idx in calculate_indices(i, A_prepared.shape, B_prepared.shape)
    ]
    batch_indices = torch.tensor(index_pairs, device=A.device, dtype=torch.int64)

    a_batch_stride = M1 * N1
    b_batch_stride = M2 * N2
    c_batch_stride = M * N
    with torch_device_fn.device(A.device):
        # One program per (batch, M-tile, N-tile), flattened to a 1-D grid.
        grid = lambda meta: (
            batch_size
            * triton.cdiv(M, meta["BLOCK_M"])
            * triton.cdiv(N, meta["BLOCK_N"]),
        )

        kron_kernel[grid](
            A_view,
            B_view,
            C_reshaped,
            batch_indices,
            batch_size,
            M,
            N,
            M1,
            M2,
            N1,
            N2,
            A_view.stride(1),
            A_view.stride(2),
            B_view.stride(1),
            B_view.stride(2),
            C_reshaped.stride(1),
            C_reshaped.stride(2),
            a_batch_stride,
            b_batch_stride,
            c_batch_stride,
        )

    # Two vector (or smaller) inputs yield a 1-D kron; drop the padded dims.
    if A.dim() <= 1 and B.dim() <= 1:
        return C.reshape(-1)

    return C