Coverage for src/flag_gems/runtime/backend/_cambricon/ops/cat.py: 0%
213 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import importlib
2import logging
3import math
4import os
5from typing import Callable, List, Mapping, Tuple, Union
7import torch
9from flag_gems.utils.code_cache import code_cache_dir
10from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
12from .vstack import vstack
# Module logger parented under the shared "flag_gems" logger so backend logs
# share one hierarchy; lstrip(".") presumably guards against a leading-dot
# module name — TODO confirm against how this module is loaded.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
class CatKernelGenerator(IndentedBuffer):
    """Generate, cache, and launch Triton kernels implementing ``cat`` for an
    arbitrary number of input tensors on the Cambricon backend.

    Two kernel templates exist, selected in ``__init``:

    * ``MODE 0`` (``_cat_kernel_small``): all inputs have equal ``numel`` and
      the average per-tensor low size is small (<= 128). Each tensor is cut
      into ``BLOCK_LOW``-sized tiles and CTAs walk the tiles round-robin.
    * ``MODE 1`` (``_cat_kernel_parthigh``): the general case. Work is split
      over the ``high`` dimension; each CTA copies a slab of every input.
    """

    # Process-wide cache of compiled wrapper callables, keyed by the
    # specialization key built in __call__. Deliberately a class attribute so
    # all generator instances share one cache.
    overloads: Mapping[str, Callable] = {}

    def __init__(self):
        # Remember the creating pid; NOTE(review): not read elsewhere in this
        # file — possibly intended for fork-safety of the cache. Confirm.
        self.pid = os.getpid()
        self.cache = self.overloads
        super().__init__()

    def __init(
        self,
        tensors: List[torch.Tensor],
        dim: int,
        high_num: int,
        low_cat_accum: List[int],
    ):
        """Record launch parameters and pick the kernel template (MODE)."""
        self.dim = dim
        self.high_num = high_num
        self.low_cat_accum = low_cat_accum
        self.tensor_num = len(tensors)
        even = all(t.numel() == tensors[0].numel() for t in tensors)

        if even and low_cat_accum[-1] // self.tensor_num <= 128:
            # Special case for tensors with small and even low size,
            # which means weak contiguity when storing the out tensor.
            # Divide each tensor into tiles of `BLOCK_LOW` size,
            # and each cta process tiles one by one.
            self.kernel_name = "_cat_kernel_small"
            self.wrapper_name = "_cat_wrapper_small"
            self.MODE = 0
        else:
            # General cases.
            # Divide tasks by high_num, each cta process parts of high of all tensors.
            self.kernel_name = "_cat_kernel_parthigh"
            self.wrapper_name = "_cat_wrapper_parthigh"
            self.MODE = 1

    def __call__(
        self,
        tensors: List[torch.Tensor],
        dim: int,
        high_num: int,
        low_cat_accum: List[int],
    ):
        """Generate (or fetch from cache) the wrapper and run it."""
        self.__init(tensors, dim, high_num, low_cat_accum)
        # BUGFIX: the key must include self.MODE. Two inputs can agree on
        # tensor count, high_num, and total low size while disagreeing on
        # "evenness" (and therefore on the kernel template); without MODE in
        # the key the second call would reuse the wrong cached wrapper, and
        # the MODE-0 kernel assumes an even split of out_cat_num.
        key = f"{self.MODE}_{len(tensors)}_{high_num}_{low_cat_accum[-1]}"
        if key not in self.cache:
            self.codegen()

            filename = f"{self.kernel_name}_{key}.py"
            filepath = code_cache_dir() / filename
            # Atomic write so concurrent processes never import a half file.
            write_atomic(filepath, self.getvalue())

            spec = importlib.util.spec_from_file_location(
                f"_gen_module_{key}", filepath
            )
            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, self.wrapper_name)
            self.cache[key] = overload
        overload = self.cache[key]
        return overload(tensors, dim, high_num, low_cat_accum)

    def gen_imports(self):
        """Emit the import header of the generated module."""
        self.writeline("import math")
        self.writeline("import copy")
        self.newline()
        self.writeline("import torch")
        self.writeline("import triton")
        self.writeline("import triton.language as tl")
        self.newline()
        self.writeline("from flag_gems.runtime import torch_device_fn")
        self.writeline("from flag_gems.runtime.backend import vendor_module")
        self.writeline("from flag_gems.utils import libentry, libtuner")
        self.newline()
        self.writeline("TOTAL_CORE_NUM = vendor_module.TOTAL_CORE_NUM")
        self.newline()
        self.newline()

    def gen_wrapper(self):
        """Emit the host-side wrapper that allocates `out` and launches."""
        self.writeline(
            f"def {self.wrapper_name}(tensors, dim, high_num, low_cat_accum):"
        )
        with self.indent():
            self.writeline("device = tensors[0].device")
            self.writeline("dtype = tensors[0].dtype")
            self.writeline("tensor_num = len(tensors)")
            self.writeline("cat_dim_size = sum([t.shape[dim] for t in tensors])")
            self.writeline("out_shape = list(tensors[0].shape)")
            self.writeline("out_shape[dim] = cat_dim_size")
            self.writeline("out_cat_num = low_cat_accum[-1]")
            self.writeline("out = torch.empty(out_shape, device=device, dtype=dtype)")
            # dim >= 1 is guaranteed by the caller (high_num == 1, which
            # includes dim == 0, is routed to vstack), so stride(dim - 1)
            # is a valid "high" stride here.
            for i in range(self.tensor_num):
                self.writeline(f"in{i}_stride_high = tensors[{i}].stride(dim - 1)")
                self.writeline(f"in{i}_stride_low = tensors[{i}].stride(-1)")
            self.writeline("out_stride_high = out.stride(dim - 1)")
            self.writeline("out_stride_low = out.stride(-1)")
            self.writeline(
                "grid = lambda meta: (TOTAL_CORE_NUM // meta['num_warps'], )"
            )
            self.writeline("with torch_device_fn.device(device):")
            with self.indent():
                self.writeline(
                    f"{self.kernel_name}[grid]({self.gen_kernel_args(is_declare=False)})"
                )
            self.writeline("return out")
        self.newline()
        self.newline()

    def gen_decorators(self):
        """Emit @libentry/@libtuner/@triton.jit decorators for the kernel."""
        self.writeline("@libentry()")
        self.writeline("@libtuner(")
        with self.indent():
            self.writeline("configs=[")
            with self.indent():
                # The multi-line string bodies below become list
                # comprehensions inside the generated `configs=[...]`.
                if self.MODE == 0:
                    self.writeline(
                        """
triton.Config({'BLOCK_LOW': 2 ** i}, num_stages=1, num_warps=1) for i in range(7, 12)
"""
                    )
                elif self.MODE == 1:
                    self.writeline(
                        """
triton.Config({'BLOCK_HIGH': i, 'BLOCK_LOW': 2 ** j}, num_stages=1, num_warps=1)
for i in [6, 11, 22]
for j in range(8, 12)
"""
                    )
            self.writeline("],")
            self.writeline("key=['high_num', 'out_cat_num'],")
            self.writeline("strategy=['log', 'log'],")
            self.writeline("restore_value=['out'],")
            self.writeline(")")
        self.writeline("@triton.jit")

    def gen_kernel(self):
        """Emit the Triton kernel body for the selected MODE."""
        self.writeline(f"def {self.kernel_name}({self.gen_kernel_args()}):")
        with self.indent():
            self.writeline("pid = tl.program_id(0)")
            self.writeline("programs_num = tl.num_programs(0)")
            if self.MODE == 0:
                # Even-split layout: each tensor contributes
                # tiles_per_tensor tiles; CTAs stride over all tiles.
                self.writeline(
                    "tiles_per_tensor = tl.cdiv(high_num * tl.cdiv(out_cat_num, tensor_num), BLOCK_LOW)"
                )
                self.writeline("num_tiles = tiles_per_tensor * tensor_num")
                self.writeline("tiles_per_cta = tl.cdiv(num_tiles, programs_num)")
                self.writeline("for i in range(tiles_per_cta):")
                with self.indent():
                    self.writeline("tile_id = pid + i * programs_num")
                    self.writeline("tensor_id = tile_id // tiles_per_tensor")
                    self.writeline("tile_id = tile_id % tiles_per_tensor")
                    # One `if tensor_id == j` branch per input: Triton has no
                    # pointer indirection, so tensors are dispatched in code.
                    for j in range(self.tensor_num):
                        self.writeline(f"if tensor_id == {j}:")
                        with self.indent():
                            self.writeline(
                                f"low_cat = low_cat_accum{j + 1} - low_cat_accum{j}"
                            )
                            self.writeline("offsets = tl.arange(0, BLOCK_LOW)")
                            self.writeline("in_offsets = tile_id * BLOCK_LOW + offsets")
                            self.writeline("mask = in_offsets < high_num * low_cat")
                            self.writeline(
                                f"data = tl.load(in{j} + in_offsets, mask=mask)"
                            )
                            high_part = "(in_offsets // low_cat) * out_cat_num"
                            low_part = f"low_cat_accum{j} + (in_offsets % low_cat)"
                            self.writeline(f"out_offsets = {high_part} + {low_part}")
                            self.writeline(
                                "tl.store(out + out_offsets, data, mask=mask)"
                            )
            elif self.MODE == 1:
                # General layout: CTAs stride over BLOCK_HIGH slabs of the
                # high dimension and copy every tensor's slice of the slab.
                self.writeline("num_tiles = tl.cdiv(high_num, BLOCK_HIGH)")
                self.writeline("tiles_per_cta = tl.cdiv(num_tiles, programs_num)")
                self.writeline("for i in range(tiles_per_cta):")
                with self.indent():
                    self.writeline("tile_id = pid + i * programs_num")
                    self.writeline("high_offset = tile_id * BLOCK_HIGH")
                    for j in range(self.tensor_num):
                        self.writeline(
                            f"low_cat = low_cat_accum{j + 1}-low_cat_accum{j}"
                        )
                        self.writeline(
                            "for low_offset in range(0, low_cat, BLOCK_LOW):"
                        )
                        with self.indent():
                            self.writeline(
                                "high_offsets = high_offset + tl.arange(0, BLOCK_HIGH)"
                            )
                            self.writeline(
                                "low_offsets = low_offset + tl.arange(0, BLOCK_LOW)"
                            )
                            high_part = f"high_offsets[:, None] * in{j}_stride_high"
                            low_part = f"low_offsets[None, :] * in{j}_stride_low"
                            self.writeline(f"in_offsets = {high_part} + {low_part}")
                            self.writeline(
                                "in_mask = (high_offsets < high_num)[:,None] & (low_offsets < low_cat)[None,:]"
                            )
                            self.writeline(
                                f"data = tl.load(in{j}+in_offsets, mask=in_mask)"
                            )
                            high_part = "high_offsets[:, None] * out_stride_high"
                            low_part = f"(low_cat_accum{j} + low_offsets[None, :]) * out_stride_low"
                            self.writeline(f"out_offsets = {high_part} + {low_part}")
                            self.writeline(
                                "tl.store(out+out_offsets, data, mask=in_mask)"
                            )

    def gen_kernel_args(self, is_declare=True):
        """Build the kernel's argument list.

        With ``is_declare=True`` returns the parameter declaration (including
        constexpr args); otherwise returns the call-site argument expressions
        evaluated in the wrapper's scope.
        """
        in_args = ", ".join(
            f"in{i}" if is_declare else f"tensors[{i}]" for i in range(self.tensor_num)
        )
        low_cat_accum_args = ", ".join(
            f"low_cat_accum{i}" if is_declare else f"low_cat_accum[{i}]"
            for i in range(self.tensor_num + 1)
        )
        stride_args = (
            ", ".join(
                f"in{i}_stride_high, in{i}_stride_low" for i in range(self.tensor_num)
            )
            + ", out_stride_high, out_stride_low"
        )

        kernel_args = f"{in_args}, out, {stride_args}, tensor_num, high_num, {low_cat_accum_args}, out_cat_num, "
        # constexpr args are declared only; the autotuner supplies them.
        ex_args = "BLOCK_LOW: tl.constexpr, num_warps: tl.constexpr"
        if self.MODE == 1:
            ex_args += ", BLOCK_HIGH: tl.constexpr"

        return kernel_args if not is_declare else kernel_args + ex_args

    def codegen(self):
        """Emit the full generated module: imports, wrapper, kernel."""
        self.gen_imports()
        self.gen_wrapper()
        self.gen_decorators()
        self.gen_kernel()
def cat(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate ``tensors`` along dimension ``dim``.

    Mirrors ``torch.cat`` semantics: legacy 1-D empty tensors
    (``shape == torch.Size([0])``) are ignored, dtypes are promoted to a
    common type, and all non-``dim`` sizes must match.

    Args:
        tensors: non-empty sequence of tensors with equal ndim and device.
        dim: concatenation dimension; may be negative.

    Returns:
        The concatenated tensor.

    Raises:
        RuntimeError: if ``tensors`` is empty.
        AssertionError: on invalid ``dim`` or mismatched ndim/device/shapes.
    """
    logger.debug("GEMS_CAMBRICON CAT")

    # Check empty inputs.
    if len(tensors) == 0:
        raise RuntimeError(
            "Expected a non-empty list or tuple/list of non-empty torch.Tensor"
        )
    if len(tensors) == 1:
        return tensors[0]

    # Remove legacy torch.Size([0]) tensors, which cat ignores.
    device = tensors[0].device
    dtype = tensors[0].dtype
    tensors = [t for t in tensors if t.shape != torch.Size([0])]
    if len(tensors) == 0:
        return torch.tensor([], dtype=dtype, device=device)
    elif len(tensors) == 1:
        return tensors[0]

    # Check dimensions.
    ndim = tensors[0].ndim
    assert dim >= -ndim and dim < ndim, f"Invalid concat dimension: {dim}"
    dim %= ndim

    # Check shapes and zero element tensors; promote to a common dtype.
    device = tensors[0].device
    dtypes = [t.dtype for t in tensors]
    dtype = dtypes[0]
    for ty in dtypes[1:]:
        dtype = torch.promote_types(dtype, ty)
    shape = tensors[0].shape
    valid_tensors = []

    for tensor in tensors:
        assert (
            tensor.ndim == ndim
        ), f"Requires same ndim of inputs, but got {ndim} and {tensor.ndim}"
        assert (
            tensor.device == device
        ), f"Requires same device of inputs, but got {device} and {tensor.device}"
        for d_idx, (size, base_size) in enumerate(zip(tensor.shape, shape)):
            assert (
                dim == d_idx or size == base_size
            ), f"Requires same dim sizes of dim {d_idx}, but got {size} and {base_size}"
        # Zero-element tensors contribute nothing; only keep the rest,
        # made contiguous and cast to the promoted dtype.
        if tensor.numel() != 0:
            tensor = tensor.contiguous()
            valid_tensors.append(tensor.to(dtype) if tensor.dtype != dtype else tensor)

    tensor_num = len(valid_tensors)

    # Deal with special cases.
    if tensor_num == 1:
        return valid_tensors[0]

    # Output shape uses the original list: zero-element tensors still
    # contribute their (zero) size along `dim`.
    cat_dim_sizes = [t.shape[dim] for t in tensors]
    out_shape = list(tensors[0].shape)
    out_shape[dim] = sum(cat_dim_sizes)

    if tensor_num == 0:
        return torch.empty(out_shape, dtype=dtype, device=device)

    # Preprocess kernel parameters: sizes of the dims before / after `dim`,
    # and the running offsets of each tensor's slice in the flattened low part.
    high_num = int(math.prod(out_shape[:dim]))
    low_num = int(math.prod(out_shape[dim + 1 :]))
    out_cat_num = 0
    low_cat_accum = [0]

    for size in cat_dim_sizes:
        out_cat_num += size * low_num
        low_cat_accum.append(out_cat_num)

    # Launch kernel.
    if high_num == 1:
        # Vstack and Concat results in the same storage arrangement when high_num == 1.
        valid_tensors = [t.view(t.shape[dim], -1) for t in valid_tensors]
        return vstack(valid_tensors).view(out_shape)
    else:
        # Handle concat with arbitrary numbers of inputs via the template
        # code generator.
        return CatKernelGenerator()(valid_tensors, dim, high_num, low_cat_accum)