Coverage for src/flag_gems/runtime/backend/_cambricon/ops/kron.py: 0%
121 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
# Module-level logger; used by kron() below for debug tracing.
logger = logging.getLogger(__name__)
def _pad_to_same_rank(a_shape, b_shape):
    """Left-pad the shorter shape list with 1s so both have equal rank."""
    if len(a_shape) > len(b_shape):
        b_shape = [1] * (len(a_shape) - len(b_shape)) + b_shape
    elif len(b_shape) > len(a_shape):
        a_shape = [1] * (len(b_shape) - len(a_shape)) + a_shape
    return a_shape, b_shape


def prepare_tensor_for_kron(tensor_a, tensor_b):
    """Normalize operand shapes for the Kronecker product.

    Pads both operands' shapes with leading 1s to the same rank (and, for
    non-empty operands, to at least 2 dims so the kernel has an (M, N)
    plane), and computes the output shape as the element-wise product of
    the padded shapes.

    Args:
        tensor_a, tensor_b: input tensors of any rank; either may be empty.

    Returns:
        (reshaped_a, reshaped_b, out_shape): the reshaped operands and the
        output shape as a tuple.
    """
    a_shape = list(tensor_a.shape)
    b_shape = list(tensor_b.shape)

    if tensor_a.numel() == 0 or tensor_b.numel() == 0:
        # At least one operand is empty, so the result is empty too.
        # A 0-dim tensor always holds exactly one element, so pad its
        # shape with 1, not 0 -- reshape(0) on a scalar would raise.
        # The empty operand's zero dimension still zeroes the product.
        if not a_shape:
            a_shape = [1]
        if not b_shape:
            b_shape = [1]
        a_shape, b_shape = _pad_to_same_rank(a_shape, b_shape)
        out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
        return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape

    # Non-empty: guarantee at least 2 dims so there is an (M, N) plane.
    if len(a_shape) < 2:
        a_shape = [1] * (2 - len(a_shape)) + a_shape
    if len(b_shape) < 2:
        b_shape = [1] * (2 - len(b_shape)) + b_shape

    a_shape, b_shape = _pad_to_same_rank(a_shape, b_shape)
    out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
    return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape
def calculate_indices(batch_idx, shape_a, shape_b):
    """Map a flat output batch index to flat batch indices into A and B.

    Each output batch dimension is the product of the corresponding A and
    B batch dimensions; within one dimension the A coordinate is the
    quotient and the B coordinate the remainder. The per-dimension
    coordinates are then re-flattened in row-major order.
    """
    dims_a = shape_a[:-2] or (1,)
    dims_b = shape_b[:-2] or (1,)

    # Unravel batch_idx over the combined batch dims, last dim fastest.
    coords = []
    rem = batch_idx
    for size_a, size_b in zip(reversed(dims_a), reversed(dims_b)):
        rem, coord = divmod(rem, size_a * size_b)
        coords.append(coord)
    coords.reverse()

    # Split each coordinate into (A, B) parts and re-flatten row-major.
    flat_a = flat_b = 0
    for coord, size_a, size_b in zip(coords, dims_a, dims_b):
        part_a, part_b = divmod(coord, size_b)
        flat_a = flat_a * size_a + part_a
        flat_b = flat_b * size_b + part_b
    return flat_a, flat_b
# Kronecker-product kernel: each program fills one (BLOCK_M, BLOCK_N) region
# of one output batch, walking it in (BLOCK_TILE_M, BLOCK_TILE_N) sub-tiles.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 512, "BLOCK_N": 512, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        ),
        triton.Config(
            {"BLOCK_M": 512, "BLOCK_N": 2048, "BLOCK_TILE_M": 64, "BLOCK_TILE_N": 64}
        ),
        triton.Config(
            {"BLOCK_M": 512, "BLOCK_N": 8192, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        ),
        triton.Config(
            {"BLOCK_M": 512, "BLOCK_N": 32768, "BLOCK_TILE_M": 128, "BLOCK_TILE_N": 64}
        ),
    ],
    key=[
        "M",
        "N",
    ],
)
@triton.jit
def kron_kernel(
    a_ptr,  # A data, addressed via a_batch_stride and a_stride_0/1
    b_ptr,  # B data, addressed via b_batch_stride and b_stride_0/1
    c_ptr,  # output data
    map_ptr,  # int64 pairs: map_ptr[2*b] = A batch idx, map_ptr[2*b+1] = B batch idx
    batch_size: tl.int64,
    M: tl.int64,  # output rows per batch
    N: tl.int64,  # output cols per batch
    M1: tl.int64,
    M2: tl.int64,  # B's row count: splits an output row into (A row, B row)
    N1: tl.int64,
    N2: tl.int64,  # B's col count: splits an output col into (A col, B col)
    a_stride_0: tl.int64,
    a_stride_1: tl.int64,
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    a_batch_stride: tl.int64,
    b_batch_stride: tl.int64,
    c_batch_stride: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_TILE_M: tl.constexpr,
    BLOCK_TILE_N: tl.constexpr,
):
    # Flat 1-D launch: decompose pid into (batch_id, block_m, block_n).
    pid = tle.program_id(0)
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n

    batch_id = pid // num_blocks_per_batch
    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n

    # Look up which A/B batch slices feed this output batch; the mask
    # guards any program whose batch_id falls past the last batch.
    offset = batch_id * 2
    is_valid_batch = batch_id < batch_size
    a_batch_idx = tl.load(map_ptr + offset, mask=is_valid_batch)
    b_batch_idx = tl.load(map_ptr + offset + 1, mask=is_valid_batch)

    num_tiles_m = tl.cdiv(BLOCK_M, BLOCK_TILE_M)
    num_tiles_n = tl.cdiv(BLOCK_N, BLOCK_TILE_N)

    for tile_m in range(num_tiles_m):
        for tile_n in range(num_tiles_n):
            tile_offset_m = tile_m * BLOCK_TILE_M
            tile_offset_n = tile_n * BLOCK_TILE_N

            # Absolute output coordinates covered by this sub-tile.
            current_offs_m = (
                block_m * BLOCK_M + tile_offset_m + tl.arange(0, BLOCK_TILE_M)
            )
            current_offs_n = (
                block_n * BLOCK_N + tile_offset_n + tl.arange(0, BLOCK_TILE_N)
            )

            tile_mask = (
                (current_offs_m[:, None] < M)
                & (current_offs_n[None, :] < N)
                & is_valid_batch
            )

            # Kronecker split: C[m, n] = A[m // M2, n // N2] * B[m % M2, n % N2].
            a_row = current_offs_m[:, None] // M2
            a_col = current_offs_n[None, :] // N2
            b_row = current_offs_m[:, None] % M2
            b_col = current_offs_n[None, :] % N2

            a_idx = (
                a_batch_idx * a_batch_stride + a_row * a_stride_0 + a_col * a_stride_1
            )
            b_idx = (
                b_batch_idx * b_batch_stride + b_row * b_stride_0 + b_col * b_stride_1
            )

            a = tl.load(a_ptr + a_idx, mask=tile_mask)
            b = tl.load(b_ptr + b_idx, mask=tile_mask)
            c = a * b

            c_idx = (
                batch_id * c_batch_stride
                + current_offs_m[:, None] * c_stride_0
                + current_offs_n[None, :] * c_stride_1
            )

            tl.store(c_ptr + c_idx, c, mask=tile_mask)
def kron(A, B):
    """Kronecker product of A and B (equivalent to ``torch.kron``).

    Scalar and empty operands are resolved on the host; all other inputs
    are normalized to batched 2-D views and dispatched to ``kron_kernel``.

    Args:
        A, B: input tensors of any rank on the same device.

    Returns:
        A tensor with the promoted dtype of A and B.
    """
    logger.debug("GEMS_CAMBRICON KRON")
    if A.dim() == 0 and B.dim() == 0:
        # Two scalars: plain multiply, no kernel launch needed.
        return A * B

    if A.numel() == 0 or B.numel() == 0:
        # Empty result: only the output shape matters here.
        _, _, out_shape = prepare_tensor_for_kron(A, B)
        output_dtype = torch.promote_types(A.dtype, B.dtype)
        return torch.empty(out_shape, device=A.device, dtype=output_dtype)

    if A.dim() == 0:
        return A.unsqueeze(0) * B
    if B.dim() == 0:
        return A * B.unsqueeze(0)

    A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
    M1, N1 = A_prepared.shape[-2:]
    M2, N2 = B_prepared.shape[-2:]
    M, N = M1 * M2, N1 * N2

    batch_size = math.prod(out_shape[:-2]) if out_shape[:-2] else 1

    output_dtype = torch.promote_types(A.dtype, B.dtype)
    C = torch.empty(out_shape, device=A.device, dtype=output_dtype)

    # Collapse all batch dims so the kernel sees (batch, rows, cols).
    C_reshaped = C.view(-1, M, N)
    A_view = A_prepared.reshape(-1, M1, N1)
    B_view = B_prepared.reshape(-1, M2, N2)

    # The kernel indexes with flat per-batch strides, so the views must
    # be contiguous.
    if not A_view.is_contiguous():
        A_view = A_view.contiguous()
    if not B_view.is_contiguous():
        B_view = B_view.contiguous()

    # Build the interleaved (a_idx, b_idx) batch map on the host and copy
    # it to the device in one shot, instead of 2 * batch_size individual
    # element writes into a device tensor.
    index_pairs = []
    for i in range(batch_size):
        a_idx, b_idx = calculate_indices(i, A_prepared.shape, B_prepared.shape)
        index_pairs.append(a_idx)
        index_pairs.append(b_idx)
    batch_indices = torch.tensor(index_pairs, device=A.device, dtype=torch.int64)

    a_batch_stride = M1 * N1
    b_batch_stride = M2 * N2
    c_batch_stride = M * N

    with torch_device_fn.device(A.device):
        # One program per (batch, BLOCK_M, BLOCK_N) output region.
        grid = lambda meta: (
            batch_size
            * triton.cdiv(M, meta["BLOCK_M"])
            * triton.cdiv(N, meta["BLOCK_N"]),
        )

        kron_kernel[grid](
            A_view,
            B_view,
            C_reshaped,
            batch_indices,
            batch_size,
            M,
            N,
            M1,
            M2,
            N1,
            N2,
            A_view.stride(1),
            A_view.stride(2),
            B_view.stride(1),
            B_view.stride(2),
            C_reshaped.stride(1),
            C_reshaped.stride(2),
            a_batch_stride,
            b_batch_stride,
            c_batch_stride,
        )

    # torch.kron of two <=1-D inputs is 1-D; undo the 2-D padding.
    if A.dim() <= 1 and B.dim() <= 1:
        return C.reshape(-1)

    return C