Coverage for src/flag_gems/runtime/backend/_cambricon/ops/vstack.py: 0% (184 statements)
import importlib
import logging
import math
import os
from typing import Callable, Mapping

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic

from ..utils import TOTAL_CORE_NUM

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


class VstackKernelCode(IndentedBuffer):
    """
    Vstack kernel template.
    """

    overloads: Mapping[str, Callable] = {}

    def __init__(self):
        self.pid = os.getpid()
        self.cache = self.overloads
        self.kernel_name = "_vstack_jit_kernel"
        self.wrapper_func_name = "_wrapper"
        self.vstack_small_limit = 49152
        super(VstackKernelCode, self).__init__()
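    # Note: `overloads` is a class-level mapping, so `self.cache` aliases a dict
    # shared by every VstackKernelCode instance in this process; compiled wrapper
    # functions are therefore reused across calls that hit the same
    # "{total_size}_{input_num}" cache key (see __call__ below).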
    def __init(self, tensors):
        """Initialize the vstack kernel."""
        self.device = tensors[0].device
        dtypes = [t.dtype for t in tensors]
        dtype = dtypes[0]
        for ty in dtypes[1:]:
            dtype = torch.promote_types(dtype, ty)
        self.dtype = dtype
        for i, tensor in enumerate(tensors):
            assert (
                tensor.device == self.device
                and tensor.dim() == tensors[0].dim()
                and tensors[0].shape[1:] == tensor.shape[1:]
                if tensors[0].dim() > 1
                else tensors[0].shape == tensor.shape
            )
            if tensor.dtype != self.dtype:
                tensors[i] = tensor.to(self.dtype)
        c_tensors = [t.contiguous() for t in tensors]
        self.inputs = []
        self.idxs = [0]
        self.total_size = 0
        for tensor in c_tensors:
            self.total_size += tensor.numel()
            self.idxs.append(self.total_size)
            self.inputs.append(tensor)
        self.deal_num = math.ceil(self.total_size / TOTAL_CORE_NUM)
        self.input_num = len(self.inputs)
        flag = (self.total_size / self.input_num) == self.idxs[1]
        if (
            self.total_size < self.vstack_small_limit
            and self.input_num <= TOTAL_CORE_NUM
            and flag
        ):
            self.is_small = True
        else:
            self.is_small = False
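    # Illustrative example (hypothetical shapes, not from the original source):
    # for three contiguous 2x4 float32 inputs, tensor.numel() is 8 for each, so
    # after __init runs:
    #   self.idxs       == [0, 8, 16, 24]   (flattened start offset of each input)
    #   self.total_size == 24
    #   self.deal_num   == math.ceil(24 / TOTAL_CORE_NUM)   (elements per core)
    # is_small is True only when total_size stays under vstack_small_limit, the
    # number of inputs fits in TOTAL_CORE_NUM, and total_size / input_num equals
    # the first input's element count (i.e. the inputs are evenly sized).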
    def __imports(self):
        """Generate imports for the kernel code."""
        self.tpl(
            """
import math
import torch
import triton
from triton import language as tl
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.runtime.backend import vendor_module
TOTAL_CORE_NUM = vendor_module.TOTAL_CORE_NUM
MAX_NRAM_SIZE = vendor_module.MAX_NRAM_SIZE
        """
        )

    def __wrapper(self):
        """Generate wrapper function for the kernel code."""
        self.newline()
        self.tpl(
            """
def {wrapper_name}(tensors, inputs, idx, total_size, input_num, deal_num, is_small):
    tensors = torch.atleast_2d(tensors)
    num_tensors = len(tensors)
    assert num_tensors > 0
    if num_tensors == 1:
        return tensors[0]
    device = tensors[0].device
    dtype = tensors[0].dtype
    input = [i for i in inputs]
    c_tensors = [t.contiguous() for t in tensors]
    total_rows = sum(tensor.shape[0] for tensor in c_tensors)
    output_shape = list(c_tensors[0].shape)
    output_shape[0] = total_rows
    output = torch.empty(output_shape, device=device, dtype=dtype)
    with torch_device_fn.device(device):
        {kernel_name}[(TOTAL_CORE_NUM,)]({args})
    return output
        """,
            wrapper_name=self.wrapper_func_name,
            kernel_name=self.kernel_name,
            args=self.__kernel_args(is_declare=False),
        )

    def __config(self):
        """Generate config for the kernel code."""
        # generate config key.
        self.newline()
        self.tpl(
            """
@libentry()
@libtuner(
    configs=[
        triton.Config({{'BLOCK_SIZE' : 512}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE' : 2048}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE' : 4096}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE' : 8192}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE' : 10240}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE' : 14336}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE' : 18000}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE' : 22000}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE' : 28000}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE' : 32000}}, num_warps=1),
    ],
    key = [{config_keys}],
)
@triton.jit
        """,
            config_keys="'total_size'",
        )
    def __kernel(self):
        """Generate kernel code body."""
        # configuration.
        self.__config()
        kernel_signature = f"def {self.kernel_name}({self.__kernel_args()}):"
        self.idx_1 = 1
        self.idx_0 = 0
        self.writeline(kernel_signature)
        with self.indent():
            self.writeline("pid_x = tl.program_id(0)")
            self.writeline("block = tl.arange(0, BLOCK_SIZE)")
            # Small case: each core copies exactly one whole input tensor.
            self.writeline("if is_small:")
            with self.indent():
                self.writeline("for i in range(input_num):")
                with self.indent():
                    for i in range(self.input_num):
                        self.writeline(f"if pid_x == {i} and pid_x == i:")
                        with self.indent():
                            self.writeline(
                                f"for num in range(0, idx_{i + 1} - idx_{i}, BLOCK_SIZE):"
                            )
                            with self.indent():
                                self.writeline("in_offset = num + block")
                                self.writeline(f"dst_offset = idx_{i} + num + block")
                                self.writeline(
                                    f"x = tl.load(input_{i} + in_offset, mask = in_offset < idx_{i + 1} - idx_{i})"
                                )
                                self.writeline(
                                    f"tl.store(output + dst_offset, x, mask = dst_offset < idx_{i + 1})"
                                )
            # General case: each core handles a contiguous deal_num-element slice
            # of the flattened output; it replays the allocation of the earlier
            # cores to locate its starting input and offset, then copies its slice,
            # which may span several inputs.
            self.writeline("else:")
            with self.indent():
                self.writeline("candidate_num = idx_1")
                self.writeline("input_iter = 0")
                self.writeline("for pid in range(pid_x + 1):")
                with self.indent():
                    self.writeline("need_num = deal_num")
                    self.writeline("while(need_num > 0):")
                    with self.indent():
                        self.writeline("per_fetch_num = min(candidate_num, need_num)")
                        self.writeline("if pid == pid_x:")
                        with self.indent():
                            self.writeline("if input_iter == 0:")
                            with self.indent():
                                self.writeline("offset = idx_1 - idx_0 - candidate_num")
                                self.writeline("deal_rem = deal_num - per_fetch_num")
                                self.writeline(
                                    "for i in range(0, deal_num, BLOCK_SIZE):"
                                )
                                with self.indent():
                                    self.writeline("in_offset = offset + i + block")
                                    self.writeline("dst_offset = in_offset")
                                    self.writeline(
                                        "x = tl.load(input_0 + in_offset, mask=in_offset < idx_1 - idx_0)"
                                    )
                                    self.writeline(
                                        "tl.store(output + dst_offset, x, mask=dst_offset<idx_1)"
                                    )
                            if self.input_num > 1:
                                self.writeline("else:")
                                with self.indent():
                                    for i in range(1, self.input_num, 1):
                                        idx = i + 1
                                        self.writeline(f"if input_iter == {i}:")
                                        with self.indent():
                                            self.writeline(
                                                f"offset = idx_{idx} - idx_{i} - candidate_num"
                                            )
                                            self.writeline("if need_num != deal_num:")
                                            with self.indent():
                                                self.writeline(
                                                    "deal_rem = deal_num - per_fetch_num"
                                                )
                                                self.writeline(
                                                    "for i in range(0, need_num, BLOCK_SIZE):"
                                                )
                                                with self.indent():
                                                    self.writeline(
                                                        "in_offset = offset + i + block"
                                                    )
                                                    self.writeline(
                                                        f"dst_offset = idx_{i} + in_offset"
                                                    )
                                                    self.writeline(
                                                        f"x = tl.load(input_{i} + in_offset, mask=in_offset < need_num)"
                                                    )
                                                    self.writeline(
                                                        f"tl.store(output + dst_offset, x, mask=dst_offset<idx_{i}+per_fetch_num)"
                                                    )
                                            self.writeline("else:")
                                            with self.indent():
                                                self.writeline(
                                                    "for i in range(0, need_num, BLOCK_SIZE):"
                                                )
                                                with self.indent():
                                                    self.writeline(
                                                        "in_offset = offset + i + block"
                                                    )
                                                    self.writeline(
                                                        f"dst_offset = idx_{i} + in_offset"
                                                    )
                                                    self.writeline(
                                                        f"x = tl.load(input_{i} + in_offset, mask=in_offset < idx_{idx}-idx_{i})"
                                                    )
                                                    self.writeline(
                                                        f"tl.store(output + dst_offset, x, mask=dst_offset<idx_{idx})"
                                                    )
                        self.writeline("candidate_num -= per_fetch_num")
                        self.writeline("need_num -= per_fetch_num")
                        self.writeline("if (candidate_num <= 0):")
                        with self.indent():
                            for i in range(1, self.input_num, 1):
                                idx = i + 1
                                input_idx = i - 1
                                if self.input_num == 2:
                                    self.writeline(
                                        f"candidate_num = idx_{idx} - idx_{i}"
                                    )
                                else:
                                    if i == 1:
                                        self.writeline(f"if input_iter == {input_idx}:")
                                        with self.indent():
                                            self.writeline(
                                                f"candidate_num = idx_{idx} - idx_{i}"
                                            )
                                    else:
                                        if i < self.input_num - 1:
                                            self.writeline(
                                                f"elif input_iter == {input_idx}:"
                                            )
                                            with self.indent():
                                                self.writeline(
                                                    f"candidate_num = idx_{idx} - idx_{i}"
                                                )
                                        else:
                                            self.writeline("else:")
                                            with self.indent():
                                                self.writeline(
                                                    f"candidate_num = idx_{idx} - idx_{i}"
                                                )

                            self.writeline("input_iter += 1")
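    # Rough sketch (reconstructed for illustration, not emitted verbatim) of the
    # small-case branch of the generated Triton kernel for two inputs:
    #
    #     pid_x = tl.program_id(0)
    #     block = tl.arange(0, BLOCK_SIZE)
    #     if is_small:
    #         for i in range(input_num):
    #             if pid_x == 0 and pid_x == i:
    #                 for num in range(0, idx_1 - idx_0, BLOCK_SIZE):
    #                     in_offset = num + block
    #                     dst_offset = idx_0 + num + block
    #                     x = tl.load(input_0 + in_offset, mask = in_offset < idx_1 - idx_0)
    #                     tl.store(output + dst_offset, x, mask = dst_offset < idx_1)
    #             if pid_x == 1 and pid_x == i:
    #                 ...  # same pattern for input_1, writing into [idx_1, idx_2)
    #     else:
    #         ...  # load-balanced path described above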
    def __gen_code(self):
        """Entry point for code generation of vstack."""
        # generate imports.
        self.__imports()
        # generate wrapper function.
        self.__wrapper()

        # generate kernel.
        self.__kernel()

    def __kernel_args(self, is_declare=True):
        input_args = []
        idxs_args = []
        if is_declare:
            for i in range(self.input_num):
                input_args.append(f"input_{i}")
            for i in range(len(self.idxs)):
                idxs_args.append(f"idx_{i}")
        else:
            for i in range(self.input_num):
                input_args.append(f"input[{i}]")
            for i in range(len(self.idxs)):
                idxs_args.append(f"idx[{i}]")
        input_args_str = ", ".join(input_args)
        idxs_args_str = ", ".join(idxs_args)

        extra_args_str = f"{input_args_str}, {idxs_args_str}"
        if is_declare:
            return f"{extra_args_str}, output, total_size, input_num, deal_num, is_small, BLOCK_SIZE: tl.constexpr"
        else:
            return (
                f"{extra_args_str}, output, total_size, input_num, deal_num, is_small"
            )
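    # For example, with two inputs (so len(self.idxs) == 3), __kernel_args()
    # returns the declaration string
    #   "input_0, input_1, idx_0, idx_1, idx_2, output, total_size, input_num, deal_num, is_small, BLOCK_SIZE: tl.constexpr"
    # while __kernel_args(is_declare=False) produces the matching call-site string
    #   "input[0], input[1], idx[0], idx[1], idx[2], output, total_size, input_num, deal_num, is_small"
    # that is substituted into the generated wrapper.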
    def __call__(self, tensors: list) -> torch.Tensor:
        # get overload kernel.
        self.__init(tensors)

        vstack_input_num = str(self.input_num)

        self.kernel_name = self.kernel_name + "_vstack_" + vstack_input_num
        key = f"{self.total_size}_{self.input_num}"
        if key not in self.cache:
            # generate code and cache.
            self.__gen_code()
            file_name = f"vstack_{key}_pid_{self.pid}.py"
            filepath = cache_dir() / file_name
            write_atomic(filepath, self.getvalue())
            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_{key}_pid_{self.pid}", filepath
            )
            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, self.wrapper_func_name)
            self.cache[key] = overload
        overload = self.cache[key]
        return overload(
            tensors,
            self.inputs,
            self.idxs,
            self.total_size,
            self.input_num,
            self.deal_num,
            self.is_small,
        )

def vstack(tensors: list):
    logger.debug("GEMS_CAMBRICON VSTACK")

    return VstackKernelCode()(tensors)
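
# A minimal usage sketch, not part of the original module: it assumes a
# Cambricon MLU device is available (device string "mlu" via torch_mlu) and that
# flag_gems' Cambricon backend is the active runtime, since the generated Triton
# kernel only targets that device. The tensor shapes are arbitrary illustrations.
if __name__ == "__main__":
    a = torch.randn(2, 4, device="mlu")
    b = torch.randn(3, 4, device="mlu")
    out = vstack([a, b])
    # The stacked result should match the eager reference on the same device.
    print(torch.allclose(out, torch.vstack([a, b])))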