Coverage for src/flag_gems/runtime/backend/_cambricon/ops/stack.py: 0%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import importlib 

2import logging 

3import math 

4import os 

5import textwrap 

6from typing import Callable, List, Mapping, Tuple, Union 

7 

8import torch 

9 

10from flag_gems.utils.code_cache import cache_dir 

11from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

12 

13from ..utils import TOTAL_CORE_NUM 

14from .vstack import vstack 

15 

# Module-scoped child of the shared "flag_gems" logger; lstrip removes any
# leading dots from relative-style module names before namespacing.
_child_name = __name__.lstrip(".")
logger = logging.getLogger("flag_gems").getChild(_child_name)

17 

18 

def get_dtype_size(dtype: "torch.dtype") -> int:
    """Return the size in bytes of a single element of ``dtype``.

    Uses ``Tensor.element_size()`` on a zero-dim tensor, which reports the
    true storage size for every allocatable dtype — including bool and
    complex types (``torch.finfo`` on a complex dtype reports the bits of
    the *component* float, which would under-count by half).

    Raises:
        ValueError: if ``dtype`` is not a valid torch dtype.
    """
    try:
        return torch.empty((), dtype=dtype).element_size()
    except (TypeError, RuntimeError) as err:
        # Preserve the original contract: invalid inputs raise ValueError.
        raise ValueError(f"Unsupported dtype: {dtype}") from err

30 

31 

class StackKernelCode(IndentedBuffer):
    """
    Stack kernel template.

    Generates, caches, and loads a Triton kernel (plus a Python wrapper)
    specialized for one ``torch.stack`` shape/dtype configuration on the
    Cambricon backend.  Generated source is written to the flag_gems code
    cache directory and imported as a throwaway module; the resulting
    wrapper callables are memoized in ``overloads``.
    """

    # Class-level cache shared by all instances: specialization key -> the
    # generated wrapper callable.
    overloads: Mapping[str, Callable] = {}

    def __init__(self):
        # pid is embedded in generated file/module names so that concurrent
        # processes never clobber each other's cache files.
        self.pid = os.getpid()
        self.cache = self.overloads
        self.kernel_name = "_stack_jit_kernel"
        self.wrapper_func_name = "_wrapper"
        super(StackKernelCode, self).__init__()

    def __imports(self) -> None:
        """Generate imports for the kernel code."""
        # NOTE: TOTAL_CORE_NUM / MAX_NRAM_SIZE are resolved at runtime from
        # the vendor module of the active backend in the generated file.
        tpl = """\
        import math
        import torch
        import triton
        from triton import language as tl
        from typing import List, Tuple, Union
        from flag_gems.utils import libentry
        from flag_gems.runtime.backend import vendor_module
        TOTAL_CORE_NUM = vendor_module.TOTAL_CORE_NUM
        MAX_NRAM_SIZE = vendor_module.MAX_NRAM_SIZE

        """
        self.tpl(textwrap.dedent(tpl))

    def __wrapper(self) -> None:
        """Generate wrapper function for the kernel code.

        The emitted wrapper re-validates its inputs (it is the public entry
        point of the generated module), computes the high/low split around
        the stack dimension, builds the launch grid for whichever of the
        three kernel branches the autotuner selected, and launches the
        kernel.  Braces meant to survive into the generated source are
        doubled because the template goes through ``str.format``-style
        substitution in ``self.tpl``.
        """
        self.newline()
        tpl = """\
        def {wrapper_name}(
            tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
        ) -> torch.Tensor:
            if len(tensors) == 0:
                raise RuntimeError("stack expected a non-empty TensorList")

            inp_shapes = [list(_.shape) for _ in tensors]
            inp0_shape = inp_shapes[0]
            for i, s in enumerate(inp_shapes[1:]):
                if (dim < -tensors[i + 1].dim() - 1) or (dim > tensors[i + 1].dim()):
                    raise IndexError(
                        "Dimension out of range (expected to be in range of [{{}}, {{}}], but got {{}})".format(
                            -tensors[i + 1].dim() - 1, tensors[i + 1].dim(), dim
                        )
                    )
                if s != inp0_shape:
                    raise RuntimeError(
                        f"stack expects each tensor to be equal size, \
but got {{inp0_shape}} at entry 0 and {{s}} at entry {{i+1}}"
                    )

            if dim < 0:
                dim = dim + len(inp0_shape) + 1
            out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]
            high = int(math.prod(out_shape[:dim]))
            low = int(math.prod(out_shape[dim+1:]))
            tensor_num = len(tensors)
            out0 = torch.empty(out_shape, dtype=tensors[0].dtype, device=tensors[0].device)
            def grid(meta):
                if meta['BLOCK_SIZE']>0:
                    task_x = high
                    task_y = tensor_num
                    task_z = triton.cdiv(low ,meta['BLOCK_SIZE'])
                    return (task_x, task_y, task_z)
                else:
                    total_task = high * tensor_num
                    if meta['LOW_NUM']>0:
                        core_used = triton.cdiv(total_task // meta['LOW_NUM'], meta['TASK_PER_CORE'])
                    elif meta['N_LOW_NUM']>0:
                        core_used = triton.cdiv(high, meta['TASK_PER_CORE'])
                    return (core_used,)
            {kernel_name}[grid](
                out0,
                *tensors,
                high,
                tensor_num,
                low,
            )
            return out0
        """
        self.tpl(
            textwrap.dedent(tpl),
            wrapper_name=self.wrapper_func_name,
            kernel_name=self.kernel_name,
        )

    def __config(self, tensor_num: int, high: int, low: int, dtype: "torch.dtype") -> None:
        """Generate autotune config (``cfggen``) for the kernel code.

        Exactly one of the three branch parameters is made non-zero, which
        selects the corresponding branch in the generated kernel.
        """
        dtyp_bytest = get_dtype_size(dtype)
        # Since the kernel has three branches, each branch has its own parameters,
        # so for a certain branch, the other parameters can be directly set to zero.

        # 1) N_LOW_NUM branch: NRAM can hold at least one set of `tensor_num * low * dtyp_bytest`.
        # This parameter is used to indicate how many `tensor_num * low` are processed in a single core.

        # 2) LOW_NUM branch: NRAM can hold at least one set of `low * dtyp_bytest`,
        # but cannot hold tensor_num * low * dtyp_bytest.
        # This parameter is used to indicate how many `low` are processed in a single core.

        # 3) BLOCK_SIZE branch: NRAM is not enough to store a set of `low`,
        # so it can only loop multiple times to process a set of `low`.
        # This parameter is used to indicates how many elements to load
        # at a time when looping over and processing low.
        tpl = """\
        def cfggen():
            N_LOW_NUM = {n_low_num_options}
            LOW_NUM = {low_num_options}
            BLOCK_SIZE = {block_size_options}
            warps = [1]
            num_stages = {num_stages}
            configs = [
                triton.Config(
                    {{
                        "BLOCK_SIZE": block_size,
                        "N_LOW_NUM": n_low_num,
                        "LOW_NUM": low_num,
                    }},
                    num_warps=w,
                    num_stages=s)
                for block_size in BLOCK_SIZE
                for n_low_num in N_LOW_NUM
                for low_num in LOW_NUM
                for w in warps for s in num_stages
            ]
            return configs
        performance_related_keys = {keys}
        """

        # If `tensor_num * low * dtyp_bytest` is less than `nram_threshold`,
        # use N_LOW_NUM branch, otherwise LOW_NUM branch.
        # NOTE(review): 170000 looks like the usable NRAM budget in bytes
        # after pipeline/temporary overhead — confirm against MAX_NRAM_SIZE.
        nram_threshold = 170000
        # The maximum number of elements in triton is 1048576.
        max_elements_num = 1048576
        # after removing the overhead of pipeline and temporary variables.
        if tensor_num * low * dtyp_bytest <= nram_threshold:
            # N_LOW_NUM is capped by per-core work, NRAM capacity, and the
            # Triton per-tensor element limit; pick the tightest bound.
            n_low_per_core = math.ceil(high / TOTAL_CORE_NUM)
            limited_by_nram = nram_threshold // dtyp_bytest // (tensor_num * low)
            limited_by_triton = max_elements_num // (tensor_num * low)
            best_opt = min(n_low_per_core, limited_by_triton, limited_by_nram)
            self.tpl(
                textwrap.dedent(tpl),
                n_low_num_options=f"{[best_opt]}",
                low_num_options=r"[0]",
                block_size_options=r"[0]",
                num_stages=r"[1]",
                keys=r'["high"]',
            )
        elif low * dtyp_bytest <= nram_threshold:
            self.tpl(
                textwrap.dedent(tpl),
                n_low_num_options=r"[0]",
                low_num_options=r"[1,2,3]",
                block_size_options=r"[0]",
                num_stages=f"{[1]}",
                keys=r'["high", "tensor_num", "low"]',
            )
        else:
            self.tpl(
                textwrap.dedent(tpl),
                n_low_num_options=r"[0]",
                low_num_options=r"[0]",
                block_size_options=r"[8192, 16384, 32768, 65536, 131072, 262144]",
                num_stages=r"[1]",
                keys=r'["low"]',
            )

    def __kernel(self, tensor_num: int) -> None:
        """Generate kernel code body.

        Emits three pieces into the generated module: a ``stack_heuristics``
        helper that splits work across cores, two jitted device helpers
        (``load_trans_store``, ``load_and_store``), and the main autotuned
        kernel with the three NRAM-capacity branches described in
        ``__config``.  Per-input-tensor load code is unrolled here as text
        (``if tensor_idx == k: ... tl.load(in_k ...)``) because the number
        of input pointers is baked into the kernel signature.
        NOTE(review): ``tl.empty`` appears to be a Cambricon Triton
        extension — confirm against the vendor Triton fork.
        """
        tpl = """\
        def stack_heuristics(args, need_key):
            ret = {{
                'TASK_PER_CORE': 0,
                'TASK_LAST_CORE_REPEAT': 0,
                'TASK_LAST_CORE_REMAIN': 0,
            }}
            total_task = args['high']*args['tensor_num']
            if args['LOW_NUM']>0:
                LOW_NUM = args['LOW_NUM'] if total_task > args['LOW_NUM'] else total_task
                ret['TASK_PER_CORE'] = triton.cdiv(total_task // LOW_NUM, TOTAL_CORE_NUM)
                assert ret['TASK_PER_CORE']>0, ret['TASK_PER_CORE']
                core_used = triton.cdiv(total_task // LOW_NUM, ret['TASK_PER_CORE'])
                task_last_core = total_task-(core_used-1)*ret['TASK_PER_CORE']*LOW_NUM
                ret['TASK_LAST_CORE_REPEAT'] = task_last_core//LOW_NUM
                ret['TASK_LAST_CORE_REMAIN'] = task_last_core%LOW_NUM
            elif args['N_LOW_NUM']>0:
                ret['TASK_PER_CORE'] = triton.cdiv(args['high'], TOTAL_CORE_NUM)
                core_used = triton.cdiv(args['high'], ret['TASK_PER_CORE'])
                ret['TASK_LAST_CORE_REPEAT'] = args['high'] -(core_used-1)*ret['TASK_PER_CORE']
            return ret[need_key]

        @triton.jit()
        def load_trans_store(
            low: tl.constexpr,
            tensor_num: tl.constexpr,
            {tensors},
            offset,
            buffer,
            buffer_offset,
            output,
            out_offset,
        ):
            if low >64:
        {low_gt_64_code}
                tl.store(output+offset*tensor_num+out_offset, buffer)
            else:
        {low_le_64_code}
                tl.store(output+offset*tensor_num+out_offset, tl.trans(buffer, 1, 0, 2))

        @triton.jit()
        def load_and_store(
            output_ptr,
            buffer,
            buffer_offset,
            task_id,
            LOW_NUM: tl.constexpr,
            low: tl.constexpr,
            LOW_OFFSET: tl.constexpr,
            tensor_num: tl.constexpr,
            {tensors}
        ):
            for low_idx in tl.range(LOW_NUM):
                cur_low_id = task_id + low_idx
                tensor_idx = cur_low_id%tensor_num
                high_idx = cur_low_id//tensor_num
                load_start = high_idx *low
        {load_and_store_code}
            tl.store(output_ptr+buffer_offset, buffer)

        @libentry()
        @triton.autotune(configs=cfggen(), key=performance_related_keys)
        @triton.heuristics(
            {{
                "TASK_PER_CORE": lambda args: stack_heuristics(args, "TASK_PER_CORE"),
                "TASK_LAST_CORE_REPEAT": lambda args: stack_heuristics(args, "TASK_LAST_CORE_REPEAT"),
                "TASK_LAST_CORE_REMAIN": lambda args: stack_heuristics(args, "TASK_LAST_CORE_REMAIN"),
            }}
        )
        @triton.jit()
        def {kernel_name}(
            output,
            {tensors},
            high: tl.constexpr,
            tensor_num: tl.constexpr,
            low: tl.constexpr,
            N_LOW_NUM: tl.constexpr,
            LOW_NUM: tl.constexpr,
            TASK_PER_CORE: tl.constexpr,
            TASK_LAST_CORE_REPEAT: tl.constexpr,
            TASK_LAST_CORE_REMAIN: tl.constexpr,
            BLOCK_SIZE: tl.constexpr):
            if N_LOW_NUM>0:
                # The memory space is sufficient to hold at least one set of "tensor_num* low * type_bytes"
                core_idx = tl.program_id(0)
                core_used = tl.num_programs(0)
                if core_idx>=core_used:
                    return
                in_offset = core_idx*TASK_PER_CORE*low
                if low >64:
                    buffer_repeat = tl.empty(shape=[N_LOW_NUM, tensor_num, low], dtype=output.dtype.element_ty)
                else:
                    buffer_repeat = tl.empty(shape=[tensor_num, N_LOW_NUM, low], dtype=output.dtype.element_ty)
                buffer_repeat_offset = tl.arange(0, N_LOW_NUM)[:, None]*low+tl.arange(0, low)[None,:]
                out_repeat_offset= \\
                    tl.arange(0, N_LOW_NUM)[:,None,None]*low*tensor_num+\\
                    tl.arange(0, tensor_num)[None,:,None]*low+\\
                    tl.arange(0, low)[None, None,:]
                if core_idx !=core_used -1:
                    for repeat_idx in range(TASK_PER_CORE//N_LOW_NUM):
                        repeat_offset = in_offset + repeat_idx*N_LOW_NUM*low
                        load_trans_store(low, tensor_num, {tensors},repeat_offset, buffer_repeat,\\
                            buffer_repeat_offset,output,out_repeat_offset)
                    if (TASK_PER_CORE%N_LOW_NUM) > 0:
                        normal_remain_offset = in_offset + (TASK_PER_CORE//N_LOW_NUM)*N_LOW_NUM*low
                        if low >64:
                            buffer_normal_remain = tl.empty(shape=[TASK_PER_CORE%N_LOW_NUM,tensor_num, low], \\
                                dtype=output.dtype.element_ty)
                        else:
                            buffer_normal_remain = tl.empty(shape=[tensor_num,TASK_PER_CORE%N_LOW_NUM, low], \\
                                dtype=output.dtype.element_ty)
                        buffer_normal_remain_offset = tl.arange(0, TASK_PER_CORE%N_LOW_NUM)[:, None]*low + \\
                            tl.arange(0, low)[None,:]
                        out_normal_remain_offset= \\
                            tl.arange(0, TASK_PER_CORE%N_LOW_NUM)[:,None,None]*low*tensor_num+\\
                            tl.arange(0, tensor_num)[None,:,None]*low+\\
                            tl.arange(0, low)[None, None,:]
                        load_trans_store(low, tensor_num, {tensors},normal_remain_offset, buffer_normal_remain,\\
                            buffer_normal_remain_offset,output,out_normal_remain_offset)
                else:
                    for repeat_idx in range(TASK_LAST_CORE_REPEAT//N_LOW_NUM):
                        repeat_offset = in_offset + repeat_idx*N_LOW_NUM*low
                        load_trans_store(low, tensor_num, {tensors},repeat_offset, buffer_repeat,\\
                            buffer_repeat_offset,output,out_repeat_offset)
                    if (TASK_LAST_CORE_REPEAT%N_LOW_NUM) >0 :
                        last_core_remain_offset = in_offset + (TASK_LAST_CORE_REPEAT//N_LOW_NUM)*N_LOW_NUM*low
                        if low >64:
                            buffer_last_core_remain = \\
                                tl.empty(shape=[TASK_LAST_CORE_REPEAT%N_LOW_NUM,tensor_num, low], \\
                                    dtype=output.dtype.element_ty)
                        else:
                            buffer_last_core_remain = \\
                                tl.empty(shape=[tensor_num,TASK_LAST_CORE_REPEAT%N_LOW_NUM, low], \\
                                    dtype=output.dtype.element_ty)
                        buffer_last_core_remain_offset = \\
                            tl.arange(0, TASK_LAST_CORE_REPEAT%N_LOW_NUM)[:, None]*low + \\
                            tl.arange(0, low)[None,:]
                        out_last_core_remain_offset= \\
                            tl.arange(0, TASK_LAST_CORE_REPEAT%N_LOW_NUM)[:,None,None]*low*tensor_num+\\
                            tl.arange(0, tensor_num)[None,:,None]*low+\\
                            tl.arange(0, low)[None, None,:]
                        load_trans_store(low, tensor_num, {tensors},last_core_remain_offset, \\
                            buffer_last_core_remain, buffer_last_core_remain_offset, \\
                            output,out_last_core_remain_offset)
            elif LOW_NUM>0:
                # The memory space is sufficient to hold at least one set of "low * type_bytes"
                core_idx = tl.program_id(0)
                core_used = tl.num_programs(0)
                if core_idx>=core_used:
                    return
                dtype = output.dtype.element_ty
                buffer = tl.empty(shape=[LOW_NUM,low], dtype=dtype)
                buffer_offset = tl.arange(0, LOW_NUM)[:,None]*low+tl.arange(0, low)[None,:]
                LOW_OFFSET = tl.arange(0, low)
                if core_idx != core_used-1:
                    for cycles_idx in range(TASK_PER_CORE):
                        task_id = core_idx*TASK_PER_CORE*LOW_NUM+cycles_idx*LOW_NUM
                        out_ptr = output + task_id*low
                        load_and_store(
                            out_ptr,
                            buffer,
                            buffer_offset,
                            task_id,
                            LOW_NUM,
                            low,
                            LOW_OFFSET,
                            tensor_num,
                            {tensors}
                        )
                else:
                    base_task_id = core_idx*TASK_PER_CORE*LOW_NUM
                    for cycles_idx in range(TASK_LAST_CORE_REPEAT):
                        task_id= base_task_id+cycles_idx*LOW_NUM
                        out_ptr = output + task_id*low
                        load_and_store(
                            out_ptr,
                            buffer,
                            buffer_offset,
                            task_id,
                            LOW_NUM,
                            low,
                            LOW_OFFSET,
                            tensor_num,
                            {tensors}
                        )
                    task_id = base_task_id+TASK_LAST_CORE_REPEAT*LOW_NUM
                    output_ptr = output + task_id*low
                    for low_idx in tl.range(TASK_LAST_CORE_REMAIN):
                        cur_low_id = task_id + low_idx
                        tensor_idx = cur_low_id%tensor_num
                        high_idx = cur_low_id//tensor_num
                        load_start = high_idx *low
        {low_num_gt_0_last_core_code}
                    tl.store(output_ptr+buffer_offset, buffer, mask=buffer_offset<TASK_LAST_CORE_REMAIN*low)
            elif BLOCK_SIZE>0:
                # Insufficient memory space to hold a set of "low* type_bytes"
                high_idx = tl.program_id(0)
                tensor_idx = tl.program_id(1)
                output_ptr = output + high_idx*(low*tensor_num)+tensor_idx*low
                offset_in_loop = tl.program_id(2)*BLOCK_SIZE+tl.arange(0, BLOCK_SIZE)
                x = tl.empty(shape=[BLOCK_SIZE,],dtype=output.dtype.element_ty)
        {block_size_gt_0_code}
                tl.store(output_ptr+offset_in_loop, x, mask=offset_in_loop<low)
        """

        def add_indent(cleaned_str, indent_size):
            # Re-indent every line of a generated snippet so it lands at the
            # right nesting depth inside the kernel template above.
            return "\n".join(
                [f"{' ' * indent_size}{line}" for line in cleaned_str.split("\n")]
            )

        # One named pointer argument per input tensor, e.g. "in_0, in_1, ...".
        tensors = ", ".join([f"in_{idx}" for idx in range(tensor_num)])
        # Unrolled per-tensor load dispatch used by both LOW_NUM code paths.
        load_form_inputs = textwrap.dedent(
            """\
            if tensor_idx == 0:
                buffer[low_idx,:] = tl.load(in_0+load_start+LOW_OFFSET)\n"""
            + "\n".join(
                [
                    f"""\
            elif tensor_idx == {idx}:
                buffer[low_idx,:] = tl.load(in_{idx}+load_start+LOW_OFFSET)"""
                    for idx in range(1, tensor_num - 1)
                ]
            )
            + "\n"
            + f"""\
            else:
                buffer[low_idx,:] = tl.load(in_{tensor_num - 1}+load_start+LOW_OFFSET)"""
        )
        self.tpl(
            textwrap.dedent(tpl),
            kernel_name=self.kernel_name,
            tensors=tensors,
            low_gt_64_code="\n".join(
                [
                    f"{' ' * 8}buffer[:,{idx},:]=tl.load(in_{idx}+offset+buffer_offset)"
                    for idx in range(tensor_num)
                ]
            ),
            low_le_64_code="\n".join(
                [
                    f"{' ' * 8}buffer[{idx},:,:]=tl.load(in_{idx}+offset+buffer_offset)"
                    for idx in range(tensor_num)
                ]
            ),
            load_and_store_code=add_indent(load_form_inputs, 8),
            low_num_gt_0_last_core_code=add_indent(load_form_inputs, 16),
            block_size_gt_0_code=add_indent(
                textwrap.dedent(
                    """\
                    if tensor_idx == 0:
                        x = tl.load(in_0+high_idx *low+offset_in_loop,mask=offset_in_loop<low)\n"""
                    + "\n".join(
                        [
                            f"""\
                    elif tensor_idx == {idx}:
                        x = tl.load(in_{idx}+high_idx *low+offset_in_loop,mask=offset_in_loop<low)"""
                            for idx in range(1, tensor_num - 1)
                        ]
                    )
                    + "\n"
                    + f"""\
                    else:
                        x = tl.load(in_{tensor_num - 1}+high_idx *low+offset_in_loop,mask=offset_in_loop<low)"""
                ),
                8,
            ),
        )

    def __gen_code(self, tensor_num: int, high: int, low: int, dtype: "torch.dtype") -> None:
        """Entry point for code generation of stack."""
        # generate imports.
        self.__imports()
        # generate config.
        self.__config(tensor_num, high, low, dtype)
        # generate kernel.
        self.__kernel(tensor_num)
        # generate wrapper function.
        self.__wrapper()

    def __call__(
        self, tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
    ) -> torch.Tensor:
        """Generate (or reuse) the specialized kernel and run it.

        Specialization is keyed by (tensor count, high, low, dtype); the
        generated module is written to the code cache and imported once,
        after which the wrapper is served from ``self.cache``.
        """
        assert dim != 0, "StackKernel template does not optimize `dim=0`."
        tensor_num = len(tensors)
        inp0_shape = list(tensors[0].shape)
        out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]
        # high = product of dims before the stack axis, low = product after;
        # the kernel views the problem as a (high, tensor_num, low) copy.
        high = int(math.prod(out_shape[:dim]))
        low = int(math.prod(out_shape[dim + 1 :]))
        dtype = tensors[0].dtype
        self.kernel_name = f"{self.kernel_name}_num_{tensor_num}"
        key = f"num_{tensor_num}_high_{high}_low_{low}_dtype_{dtype}"
        for tensor in tensors[1:]:
            assert tensor.dtype == dtype, f"{tensor.dtype} != {dtype}"
        if key not in self.cache:
            # generate code and cache.
            self.__gen_code(tensor_num, high, low, dtype)
            file_name = f"{cache_dir()}/stack_{key}_pid_{self.pid}.py"
            write_atomic(file_name, self.getvalue())
            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_{key}_pid_{self.pid}", file_name
            )
            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            spec.loader.exec_module(m)
            overload = getattr(m, self.wrapper_func_name)
            self.cache[key] = overload
        overload = self.cache[key]
        return overload(tensors, dim)

514 

515 

def stack(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Stack ``tensors`` along a new dimension ``dim`` (Cambricon backend).

    ``dim == 0`` (after normalization) is delegated to ``vstack`` plus a
    view; every other ``dim`` goes through the generated ``StackKernelCode``
    Triton kernel.

    Raises:
        RuntimeError: on an empty tensor list or mismatched shapes.
        IndexError: if ``dim`` is outside ``[-ndim - 1, ndim]``.
    """
    logger.debug("GEMS_CAMBRICON STACK")

    if len(tensors) == 0:
        raise RuntimeError("stack expected a non-empty TensorList")

    inp_shapes = [list(_.shape) for _ in tensors]
    inp0_shape = inp_shapes[0]
    # Validate `dim` against entry 0 up front, so a single-tensor call is
    # range-checked too (matching torch.stack); entries 1..n-1 keep their
    # original per-entry check order below.
    if (dim < -tensors[0].dim() - 1) or (dim > tensors[0].dim()):
        raise IndexError(
            "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -tensors[0].dim() - 1, tensors[0].dim(), dim
            )
        )
    for i, s in enumerate(inp_shapes[1:]):
        if (dim < -tensors[i + 1].dim() - 1) or (dim > tensors[i + 1].dim()):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -tensors[i + 1].dim() - 1, tensors[i + 1].dim(), dim
                )
            )
        if s != inp0_shape:
            raise RuntimeError(
                f"stack expects each tensor to be equal size, but got {inp0_shape} at entry 0 and {s} at entry {i + 1}"
            )

    if dim < 0:
        # A negative dim counts from the end of the *output*, which has one
        # extra dimension.
        dim = dim + len(inp0_shape) + 1
    out_shape = inp0_shape[:dim] + [len(tensors)] + inp0_shape[dim:]
    if dim == 0:
        # Leading-dim stack is just a vstack reshaped to the output shape.
        return vstack(tensors).view(out_shape)
    # The generated kernel assumes densely packed inputs.
    tensors = [
        tensor if tensor.is_contiguous() else tensor.contiguous() for tensor in tensors
    ]
    return StackKernelCode()(tensors, dim)