Coverage for src/flag_gems/runtime/backend/_cambricon/ops/vstack.py: 0%

184 statements  

coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

import importlib
import logging
import math
import os
from typing import Callable, Mapping

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic

from ..utils import TOTAL_CORE_NUM

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


class VstackKernelCode(IndentedBuffer):
    """
    Vstack kernel template.

    Generates a Triton vstack kernel plus a Python wrapper specialized for the
    number of input tensors, writes the module to the code cache, and imports it.
    """

    # Process-wide cache of generated wrapper functions, keyed by
    # f"{total_size}_{input_num}".
    overloads: Mapping[str, Callable] = {}

    def __init__(self):
        self.pid = os.getpid()
        self.cache = self.overloads
        self.kernel_name = "_vstack_jit_kernel"
        self.wrapper_func_name = "_wrapper"
        self.vstack_small_limit = 49152
        super().__init__()

    def __init(self, tensors):
        """Initialize the vstack kernel state from the input tensors."""
        self.device = tensors[0].device
        # Promote all inputs to a common dtype.
        dtypes = [t.dtype for t in tensors]
        dtype = dtypes[0]
        for ty in dtypes[1:]:
            dtype = torch.promote_types(dtype, ty)
        self.dtype = dtype
        for i, tensor in enumerate(tensors):
            # For >1-D inputs the trailing dimensions must match; for 1-D
            # inputs the full shapes must match.
            assert (
                tensor.device == self.device
                and tensor.dim() == tensors[0].dim()
                and tensors[0].shape[1:] == tensor.shape[1:]
                if tensors[0].dim() > 1
                else tensors[0].shape == tensor.shape
            )
            if tensor.dtype != self.dtype:
                tensors[i] = tensor.to(self.dtype)
        c_tensors = [t.contiguous() for t in tensors]
        # self.idxs holds the cumulative element offset of each input within
        # the flattened output.
        self.inputs = []
        self.idxs = [0]
        self.total_size = 0
        for tensor in c_tensors:
            self.total_size += tensor.numel()
            self.idxs.append(self.total_size)
            self.inputs.append(tensor)
        self.deal_num = math.ceil(self.total_size / TOTAL_CORE_NUM)
        self.input_num = len(self.inputs)
        # "Small" case: all inputs are equally sized, there are few enough to
        # assign one per core, and the total stays under the small-tensor limit.
        flag = (self.total_size / self.input_num) == self.idxs[1]
        if (
            self.total_size < self.vstack_small_limit
            and self.input_num <= TOTAL_CORE_NUM
            and flag
        ):
            self.is_small = True
        else:
            self.is_small = False
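
    # Worked example (illustrative; TOTAL_CORE_NUM is hardware-specific, 4 is
    # assumed here): for two float32 tensors of shape (3, 3),
    #     idxs       == [0, 9, 18]
    #     total_size == 18
    #     input_num  == 2
    #     deal_num   == math.ceil(18 / 4) == 5
    #     is_small   == True   (18 < 49152, 2 <= 4, and 18 / 2 == idxs[1] == 9)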

    def __imports(self):
        """Generate imports for the kernel code."""
        self.tpl(
            """
import math
import torch
import triton
from triton import language as tl
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.runtime.backend import vendor_module
TOTAL_CORE_NUM = vendor_module.TOTAL_CORE_NUM
MAX_NRAM_SIZE = vendor_module.MAX_NRAM_SIZE
"""
        )

    def __wrapper(self):
        """Generate the wrapper function for the kernel code."""
        self.newline()
        self.tpl(
            """
def {wrapper_name}(tensors, inputs, idx, total_size, input_num, deal_num, is_small):
    tensors = torch.atleast_2d(tensors)
    num_tensors = len(tensors)
    assert num_tensors > 0
    if num_tensors == 1:
        return tensors[0]
    device = tensors[0].device
    dtype = tensors[0].dtype
    # The generated kernel call below indexes this list as input[0], input[1], ...
    input = [i for i in inputs]
    c_tensors = [t.contiguous() for t in tensors]
    total_rows = sum(tensor.shape[0] for tensor in c_tensors)
    output_shape = list(c_tensors[0].shape)
    output_shape[0] = total_rows
    output = torch.empty(output_shape, device=device, dtype=dtype)
    with torch_device_fn.device(device):
        {kernel_name}[(TOTAL_CORE_NUM,)]({args})
    return output
""",
            wrapper_name=self.wrapper_func_name,
            kernel_name=self.kernel_name,
            args=self.__kernel_args(is_declare=False),
        )
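
    # For two inputs the rendered wrapper call is, schematically:
    #
    #     _vstack_jit_kernel_vstack_2[(TOTAL_CORE_NUM,)](
    #         input[0], input[1], idx[0], idx[1], idx[2],
    #         output, total_size, input_num, deal_num, is_small)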

    def __config(self):
        """Generate the autotuning config for the kernel code."""
        # Generate the config key.
        self.newline()
        self.tpl(
            """
@libentry()
@libtuner(
    configs=[
        triton.Config({{'BLOCK_SIZE': 512}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE': 2048}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE': 4096}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE': 8192}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE': 10240}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE': 14336}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE': 18000}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE': 22000}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE': 28000}}, num_warps=1),
        triton.Config({{'BLOCK_SIZE': 32000}}, num_warps=1),
    ],
    key=[{config_keys}],
)
@triton.jit
""",
            config_keys="'total_size'",
        )
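
    # With config_keys == "'total_size'" the decorator above renders as
    # key=['total_size'], so BLOCK_SIZE is re-tuned per distinct total
    # element count.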

    def __kernel(self):
        """Generate the kernel code body."""
        # Configuration.
        self.__config()
        kernel_signature = f"def {self.kernel_name}({self.__kernel_args()}):"
        self.idx_1 = 1
        self.idx_0 = 0
        self.writeline(kernel_signature)
        with self.indent():
            self.writeline("pid_x = tl.program_id(0)")
            self.writeline("block = tl.arange(0, BLOCK_SIZE)")
            self.writeline("if is_small:")
            with self.indent():
                self.writeline("for i in range(input_num):")
                with self.indent():
                    # Unrolled per input: each branch compares pid_x against
                    # both the compile-time constant and the loop variable.
                    for i in range(self.input_num):
                        self.writeline(f"if pid_x == {i} and pid_x == i:")
                        with self.indent():
                            self.writeline(
                                f"for num in range(0, idx_{i + 1} - idx_{i}, BLOCK_SIZE):"
                            )
                            with self.indent():
                                self.writeline("in_offset = num + block")
                                self.writeline(f"dst_offset = idx_{i} + num + block")
                                self.writeline(
                                    f"x = tl.load(input_{i} + in_offset, mask=in_offset < idx_{i + 1} - idx_{i})"
                                )
                                self.writeline(
                                    f"tl.store(output + dst_offset, x, mask=dst_offset < idx_{i + 1})"
                                )
            self.writeline("else:")
            with self.indent():
                self.writeline("candidate_num = idx_1")
                self.writeline("input_iter = 0")
                self.writeline("for pid in range(pid_x + 1):")
                with self.indent():
                    self.writeline("need_num = deal_num")
                    self.writeline("while need_num > 0:")
                    with self.indent():
                        self.writeline("per_fetch_num = min(candidate_num, need_num)")
                        self.writeline("if pid == pid_x:")
                        with self.indent():
                            self.writeline("if input_iter == 0:")
                            with self.indent():
                                self.writeline("offset = idx_1 - idx_0 - candidate_num")
                                self.writeline("deal_rem = deal_num - per_fetch_num")
                                self.writeline(
                                    "for i in range(0, deal_num, BLOCK_SIZE):"
                                )
                                with self.indent():
                                    self.writeline("in_offset = offset + i + block")
                                    self.writeline("dst_offset = in_offset")
                                    self.writeline(
                                        "x = tl.load(input_0 + in_offset, mask=in_offset < idx_1 - idx_0)"
                                    )
                                    self.writeline(
                                        "tl.store(output + dst_offset, x, mask=dst_offset < idx_1)"
                                    )
                            if self.input_num > 1:
                                self.writeline("else:")
                                with self.indent():
                                    for i in range(1, self.input_num, 1):
                                        idx = i + 1
                                        self.writeline(f"if input_iter == {i}:")
                                        with self.indent():
                                            self.writeline(
                                                f"offset = idx_{idx} - idx_{i} - candidate_num"
                                            )
                                            self.writeline("if need_num != deal_num:")
                                            with self.indent():
                                                self.writeline(
                                                    "deal_rem = deal_num - per_fetch_num"
                                                )
                                                self.writeline(
                                                    "for i in range(0, need_num, BLOCK_SIZE):"
                                                )
                                                with self.indent():
                                                    self.writeline(
                                                        "in_offset = offset + i + block"
                                                    )
                                                    self.writeline(
                                                        f"dst_offset = idx_{i} + in_offset"
                                                    )
                                                    self.writeline(
                                                        f"x = tl.load(input_{i} + in_offset, mask=in_offset < need_num)"
                                                    )
                                                    self.writeline(
                                                        f"tl.store(output + dst_offset, x, mask=dst_offset < idx_{i} + per_fetch_num)"
                                                    )
                                            self.writeline("else:")
                                            with self.indent():
                                                self.writeline(
                                                    "for i in range(0, need_num, BLOCK_SIZE):"
                                                )
                                                with self.indent():
                                                    self.writeline(
                                                        "in_offset = offset + i + block"
                                                    )
                                                    self.writeline(
                                                        f"dst_offset = idx_{i} + in_offset"
                                                    )
                                                    self.writeline(
                                                        f"x = tl.load(input_{i} + in_offset, mask=in_offset < idx_{idx} - idx_{i})"
                                                    )
                                                    self.writeline(
                                                        f"tl.store(output + dst_offset, x, mask=dst_offset < idx_{idx})"
                                                    )
                        self.writeline("candidate_num -= per_fetch_num")
                        self.writeline("need_num -= per_fetch_num")
                        self.writeline("if candidate_num <= 0:")
                        with self.indent():
                            for i in range(1, self.input_num, 1):
                                idx = i + 1
                                input_idx = i - 1
                                if self.input_num == 2:
                                    self.writeline(
                                        f"candidate_num = idx_{idx} - idx_{i}"
                                    )
                                else:
                                    if i == 1:
                                        self.writeline(f"if input_iter == {input_idx}:")
                                        with self.indent():
                                            self.writeline(
                                                f"candidate_num = idx_{idx} - idx_{i}"
                                            )
                                    else:
                                        if i < self.input_num - 1:
                                            self.writeline(
                                                f"elif input_iter == {input_idx}:"
                                            )
                                            with self.indent():
                                                self.writeline(
                                                    f"candidate_num = idx_{idx} - idx_{i}"
                                                )
                                        else:
                                            self.writeline("else:")
                                            with self.indent():
                                                self.writeline(
                                                    f"candidate_num = idx_{idx} - idx_{i}"
                                                )

                            self.writeline("input_iter += 1")
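
    # For input_num == 2 the generator above emits, schematically:
    #
    #     def _vstack_jit_kernel_vstack_2(input_0, input_1, idx_0, idx_1, idx_2,
    #                                     output, total_size, input_num, deal_num,
    #                                     is_small, BLOCK_SIZE: tl.constexpr):
    #         pid_x = tl.program_id(0)
    #         block = tl.arange(0, BLOCK_SIZE)
    #         if is_small:
    #             # one core copies one input, BLOCK_SIZE elements at a time
    #             ...
    #         else:
    #             # each core walks the cumulative idx_* offsets to locate and
    #             # copy its deal_num-sized share of the flattened output
    #             candidate_num = idx_1
    #             input_iter = 0
    #             ...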

    def __gen_code(self):
        """Entry point for vstack code generation."""
        # Generate imports.
        self.__imports()
        # Generate the wrapper function.
        self.__wrapper()

        # Generate the kernel.
        self.__kernel()

    def __kernel_args(self, is_declare=True):
        """Build the kernel argument list: parameter names when declaring the
        kernel, indexed expressions when calling it from the wrapper."""
        input_args = []
        idxs_args = []
        if is_declare:
            for i in range(self.input_num):
                input_args.append(f"input_{i}")
            for i in range(len(self.idxs)):
                idxs_args.append(f"idx_{i}")
        else:
            for i in range(self.input_num):
                input_args.append(f"input[{i}]")
            for i in range(len(self.idxs)):
                idxs_args.append(f"idx[{i}]")
        input_args_str = ", ".join(input_args)
        idxs_args_str = ", ".join(idxs_args)

        extra_args_str = f"{input_args_str}, {idxs_args_str}"
        if is_declare:
            return f"{extra_args_str}, output, total_size, input_num, deal_num, is_small, BLOCK_SIZE: tl.constexpr"
        else:
            return (
                f"{extra_args_str}, output, total_size, input_num, deal_num, is_small"
            )
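
    # Example for input_num == 2 (len(self.idxs) == 3):
    #   declare: "input_0, input_1, idx_0, idx_1, idx_2, output, total_size,
    #             input_num, deal_num, is_small, BLOCK_SIZE: tl.constexpr"
    #   call:    "input[0], input[1], idx[0], idx[1], idx[2], output,
    #             total_size, input_num, deal_num, is_small"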

    def __call__(self, tensors: list) -> torch.Tensor:
        # Get the overload kernel for this input configuration.
        self.__init(tensors)

        self.kernel_name = f"{self.kernel_name}_vstack_{self.input_num}"
        key = f"{self.total_size}_{self.input_num}"
        if key not in self.cache:
            # Generate code and cache it.
            self.__gen_code()
            file_name = f"vstack_{key}_pid_{self.pid}.py"
            filepath = cache_dir() / file_name
            write_atomic(filepath, self.getvalue())
            # Load the generated module.
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_{key}_pid_{self.pid}", filepath
            )
            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, self.wrapper_func_name)
            self.cache[key] = overload
        overload = self.cache[key]
        return overload(
            tensors,
            self.inputs,
            self.idxs,
            self.total_size,
            self.input_num,
            self.deal_num,
            self.is_small,
        )
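
# Cache example: stacking two tensors with 18 elements in total produces the
# key "18_2"; the generated module is written once per process to
# cache_dir() / "vstack_18_2_pid_<pid>.py" and reused on later calls with the
# same key.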


def vstack(tensors: list):
    logger.debug("GEMS_CAMBRICON VSTACK")

    return VstackKernelCode()(tensors)
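
# Minimal usage sketch (illustrative, not part of the module; assumes a
# Cambricon MLU device is available and registered under the "mlu" device
# string):
#
#     import torch
#     a = torch.randn(3, 3, device="mlu")
#     b = torch.randn(3, 3, device="mlu")
#     out = vstack([a, b])  # row-wise stack, shape (6, 3), like torch.vstack((a, b))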