Coverage for src/flag_gems/runtime/backend/_cambricon/ops/cat.py: 0%

213 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import importlib 

2import logging 

3import math 

4import os 

5from typing import Callable, List, Mapping, Tuple, Union 

6 

7import torch 

8 

9from flag_gems.utils.code_cache import code_cache_dir 

10from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

11 

12from .vstack import vstack 

13 

# Package-scoped child logger; lstrip(".") drops any leading dot so the child
# name stays well-formed if __name__ is ever a relative module path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

class CatKernelGenerator(IndentedBuffer):
    """Code generator for `torch.cat` Triton kernels on Cambricon.

    The number of input tensors is only known at call time, so the kernel
    source (one ``in{i}`` pointer argument per input) is produced by code
    generation, written into the code cache directory, imported as a fresh
    module, and memoized in ``overloads``.

    Two strategies exist, selected in ``__init``:
      MODE 0 (`_cat_kernel_small`)    — evenly sized inputs with a small low
                                        dimension; flat per-tensor tiling.
      MODE 1 (`_cat_kernel_parthigh`) — general case; tiles over the high
                                        dimension for every input.
    """

    # Process-wide cache mapping a signature key to the generated wrapper.
    # Class attribute, so all instances share it.
    overloads: Mapping[str, Callable] = {}

    def __init__(self):
        self.pid = os.getpid()
        self.cache = self.overloads
        super().__init__()

    def __init(
        self,
        tensors: List[torch.Tensor],
        dim: int,
        high_num: int,
        low_cat_accum: List[int],
    ):
        """Record the call parameters and pick the kernel strategy (MODE)."""
        self.dim = dim
        self.high_num = high_num
        self.low_cat_accum = low_cat_accum
        self.tensor_num = len(tensors)
        # True when every input contributes the same number of elements.
        even = all(t.numel() == tensors[0].numel() for t in tensors)

        if even and low_cat_accum[-1] // self.tensor_num <= 128:
            # Special case for tensors with small and even low size, which
            # means weak contiguity when storing the out tensor. Divide each
            # tensor into tiles of `BLOCK_LOW` size; each cta processes its
            # tiles one by one.
            self.kernel_name = "_cat_kernel_small"
            self.wrapper_name = "_cat_wrapper_small"
            self.MODE = 0
        else:
            # General case. Divide tasks by high_num; each cta processes a
            # part of the high dimension of all tensors.
            self.kernel_name = "_cat_kernel_parthigh"
            self.wrapper_name = "_cat_wrapper_parthigh"
            self.MODE = 1

    def __call__(
        self,
        tensors: List[torch.Tensor],
        dim: int,
        high_num: int,
        low_cat_accum: List[int],
    ):
        """Generate (or fetch cached) wrapper and run it on `tensors`."""
        self.__init(tensors, dim, high_num, low_cat_accum)
        # BUGFIX: the key must include MODE. Two calls can agree on
        # (tensor count, high_num, total low size) yet choose different
        # strategies in __init (the `even` check differs), so a key that
        # omits MODE could hand back a wrapper generated for the wrong
        # kernel shape.
        key = f"{self.MODE}_{len(tensors)}_{high_num}_{low_cat_accum[-1]}"
        if key not in self.cache:
            self.codegen()

            filename = f"{self.kernel_name}_{key}.py"
            filepath = code_cache_dir() / filename
            write_atomic(filepath, self.getvalue())

            spec = importlib.util.spec_from_file_location(
                f"_gen_module_{key}", filepath
            )
            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, self.wrapper_name)
            self.cache[key] = overload
        overload = self.cache[key]
        return overload(tensors, dim, high_num, low_cat_accum)

    def gen_imports(self):
        """Emit the import header of the generated module."""
        self.writeline("import math")
        self.writeline("import copy")
        self.newline()
        self.writeline("import torch")
        self.writeline("import triton")
        self.writeline("import triton.language as tl")
        self.newline()
        self.writeline("from flag_gems.runtime import torch_device_fn")
        self.writeline("from flag_gems.runtime.backend import vendor_module")
        self.writeline("from flag_gems.utils import libentry, libtuner")
        self.newline()
        self.writeline("TOTAL_CORE_NUM = vendor_module.TOTAL_CORE_NUM")
        self.newline()
        self.newline()

    def gen_wrapper(self):
        """Emit the Python wrapper that allocates `out` and launches the kernel."""
        self.writeline(
            f"def {self.wrapper_name}(tensors, dim, high_num, low_cat_accum):"
        )
        with self.indent():
            self.writeline("device = tensors[0].device")
            self.writeline("dtype = tensors[0].dtype")
            self.writeline("tensor_num = len(tensors)")
            self.writeline("cat_dim_size = sum([t.shape[dim] for t in tensors])")
            self.writeline("out_shape = list(tensors[0].shape)")
            self.writeline("out_shape[dim] = cat_dim_size")
            self.writeline("out_cat_num = low_cat_accum[-1]")
            self.writeline("out = torch.empty(out_shape, device=device, dtype=dtype)")
            # NOTE(review): stride(dim - 1) assumes dim >= 1; the caller only
            # reaches this generator when high_num > 1, which implies dim >= 1
            # — confirm if new call sites are added.
            for i in range(self.tensor_num):
                self.writeline(f"in{i}_stride_high = tensors[{i}].stride(dim - 1)")
                self.writeline(f"in{i}_stride_low = tensors[{i}].stride(-1)")
            self.writeline("out_stride_high = out.stride(dim - 1)")
            self.writeline("out_stride_low = out.stride(-1)")
            self.writeline(
                "grid = lambda meta: (TOTAL_CORE_NUM // meta['num_warps'], )"
            )
            self.writeline("with torch_device_fn.device(device):")
            with self.indent():
                self.writeline(
                    f"{self.kernel_name}[grid]({self.gen_kernel_args(is_declare=False)})"
                )
            self.writeline("return out")
        self.newline()
        self.newline()

    def gen_decorators(self):
        """Emit @libentry/@libtuner/@triton.jit decorators with MODE-specific configs."""
        self.writeline("@libentry()")
        self.writeline("@libtuner(")
        with self.indent():
            self.writeline("configs=[")
            with self.indent():
                if self.MODE == 0:
                    # Only BLOCK_LOW is tuned in the flat-tile kernel.
                    self.writeline(
                        """
triton.Config({'BLOCK_LOW': 2 ** i}, num_stages=1, num_warps=1) for i in range(7, 12)
"""
                    )
                elif self.MODE == 1:
                    # Tune both the high-dim tile and the low-dim tile.
                    self.writeline(
                        """
triton.Config({'BLOCK_HIGH': i, 'BLOCK_LOW': 2 ** j}, num_stages=1, num_warps=1)
for i in [6, 11, 22]
for j in range(8, 12)
"""
                    )
            self.writeline("],")
            self.writeline("key=['high_num', 'out_cat_num'],")
            self.writeline("strategy=['log', 'log'],")
            self.writeline("restore_value=['out'],")
            self.writeline(")")
        self.writeline("@triton.jit")

    def gen_kernel(self):
        """Emit the Triton kernel body for the selected MODE."""
        self.writeline(f"def {self.kernel_name}({self.gen_kernel_args()}):")
        with self.indent():
            self.writeline("pid = tl.program_id(0)")
            self.writeline("programs_num = tl.num_programs(0)")
            if self.MODE == 0:
                # Flat tiling: each input is a contiguous run of
                # high_num * low_cat elements; tiles are assigned to ctas
                # round-robin and scattered into `out` via div/mod address
                # arithmetic.
                self.writeline(
                    "tiles_per_tensor = tl.cdiv(high_num * tl.cdiv(out_cat_num, tensor_num), BLOCK_LOW)"
                )
                self.writeline("num_tiles = tiles_per_tensor * tensor_num")
                self.writeline("tiles_per_cta = tl.cdiv(num_tiles, programs_num)")
                self.writeline("for i in range(tiles_per_cta):")
                with self.indent():
                    self.writeline("tile_id = pid + i * programs_num")
                    self.writeline("tensor_id = tile_id // tiles_per_tensor")
                    self.writeline("tile_id = tile_id % tiles_per_tensor")
                    # One statically unrolled branch per input tensor.
                    for j in range(self.tensor_num):
                        self.writeline(f"if tensor_id == {j}:")
                        with self.indent():
                            self.writeline(
                                f"low_cat = low_cat_accum{j + 1} - low_cat_accum{j}"
                            )
                            self.writeline("offsets = tl.arange(0, BLOCK_LOW)")
                            self.writeline("in_offsets = tile_id * BLOCK_LOW + offsets")
                            self.writeline("mask = in_offsets < high_num * low_cat")
                            self.writeline(
                                f"data = tl.load(in{j} + in_offsets, mask=mask)"
                            )
                            high_part = "(in_offsets // low_cat) * out_cat_num"
                            low_part = f"low_cat_accum{j} + (in_offsets % low_cat)"
                            self.writeline(f"out_offsets = {high_part} + {low_part}")
                            self.writeline(
                                "tl.store(out + out_offsets, data, mask=mask)"
                            )
            elif self.MODE == 1:
                # High-dim tiling: each cta owns BLOCK_HIGH rows and copies
                # every input's slice of those rows into `out`.
                self.writeline("num_tiles = tl.cdiv(high_num, BLOCK_HIGH)")
                self.writeline("tiles_per_cta = tl.cdiv(num_tiles, programs_num)")
                self.writeline("for i in range(tiles_per_cta):")
                with self.indent():
                    self.writeline("tile_id = pid + i * programs_num")
                    self.writeline("high_offset = tile_id * BLOCK_HIGH")
                    # Statically unrolled per-input copy loop.
                    for j in range(self.tensor_num):
                        self.writeline(
                            f"low_cat = low_cat_accum{j + 1}-low_cat_accum{j}"
                        )
                        self.writeline(
                            "for low_offset in range(0, low_cat, BLOCK_LOW):"
                        )
                        with self.indent():
                            self.writeline(
                                "high_offsets = high_offset + tl.arange(0, BLOCK_HIGH)"
                            )
                            self.writeline(
                                "low_offsets = low_offset + tl.arange(0, BLOCK_LOW)"
                            )
                            high_part = f"high_offsets[:, None] * in{j}_stride_high"
                            low_part = f"low_offsets[None, :] * in{j}_stride_low"
                            self.writeline(f"in_offsets = {high_part} + {low_part}")
                            self.writeline(
                                "in_mask = (high_offsets < high_num)[:,None] & (low_offsets < low_cat)[None,:]"
                            )
                            self.writeline(
                                f"data = tl.load(in{j}+in_offsets, mask=in_mask)"
                            )
                            high_part = "high_offsets[:, None] * out_stride_high"
                            low_part = f"(low_cat_accum{j} + low_offsets[None, :]) * out_stride_low"
                            self.writeline(f"out_offsets = {high_part} + {low_part}")
                            self.writeline(
                                "tl.store(out+out_offsets, data, mask=in_mask)"
                            )

    def gen_kernel_args(self, is_declare=True):
        """Build the kernel argument list.

        With ``is_declare=True`` returns the formal parameter list (including
        constexpr tuning knobs, supplied by the autotuner at launch); with
        ``is_declare=False`` returns the matching actual-argument expressions
        used at the call site inside the wrapper.
        """
        in_args = ", ".join(
            f"in{i}" if is_declare else f"tensors[{i}]" for i in range(self.tensor_num)
        )
        # Prefix sums need tensor_num + 1 entries (leading 0 plus one per input).
        low_cat_accum_args = ", ".join(
            f"low_cat_accum{i}" if is_declare else f"low_cat_accum[{i}]"
            for i in range(self.tensor_num + 1)
        )
        stride_args = (
            ", ".join(
                f"in{i}_stride_high, in{i}_stride_low" for i in range(self.tensor_num)
            )
            + ", out_stride_high, out_stride_low"
        )

        kernel_args = f"{in_args}, out, {stride_args}, tensor_num, high_num, {low_cat_accum_args}, out_cat_num, "
        ex_args = "BLOCK_LOW: tl.constexpr, num_warps: tl.constexpr"
        if self.MODE == 1:
            ex_args += ", BLOCK_HIGH: tl.constexpr"

        return kernel_args if not is_declare else kernel_args + ex_args

    def codegen(self):
        """Emit the complete generated module: imports, wrapper, decorators, kernel."""
        self.gen_imports()
        self.gen_wrapper()
        self.gen_decorators()
        self.gen_kernel()

250 

251 

def cat(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate `tensors` along `dim` (Cambricon backend).

    Mirrors `torch.cat` semantics: validates ndim/device/shape agreement,
    promotes all inputs to a common dtype, drops legacy `torch.Size([0])`
    placeholder tensors, then dispatches either to `vstack` (when there are
    no dimensions before `dim`) or to the generated cat kernel.

    Args:
        tensors: non-empty sequence of tensors to concatenate.
        dim: concatenation dimension; may be negative.

    Returns:
        The concatenated tensor.

    Raises:
        RuntimeError: if `tensors` is empty.
        AssertionError: on dimension, shape, or device mismatches.
    """
    logger.debug("GEMS_CAMBRICON CAT")

    # Check empty inputs.
    if len(tensors) == 0:
        raise RuntimeError(
            "Expected a non-empty list or tuple/list of non-empty torch.Tensor"
        )
    if len(tensors) == 1:
        return tensors[0]

    # Remove torch.Size([0]) tensors (torch.cat's legacy empty placeholder,
    # which is allowed to mismatch the other inputs' ndim).
    device = tensors[0].device
    dtype = tensors[0].dtype
    tensors = [t for t in tensors if t.shape != torch.Size([0])]

    if len(tensors) == 0:
        return torch.tensor([], dtype=dtype, device=device)
    elif len(tensors) == 1:
        return tensors[0]

    # Check and normalize the concat dimension.
    ndim = tensors[0].ndim
    assert dim >= -ndim and dim < ndim, f"Invalid concat dimension: {dim}"
    dim %= ndim

    # Validate shapes and devices; promote to a common dtype; keep only the
    # tensors that actually contribute elements.
    device = tensors[0].device
    dtypes = [t.dtype for t in tensors]
    dtype = dtypes[0]
    for ty in dtypes[1:]:
        dtype = torch.promote_types(dtype, ty)
    shape = tensors[0].shape
    valid_tensors = []

    for tensor in tensors:
        assert (
            tensor.ndim == ndim
        ), f"Requires same ndim of inputs, but got {ndim} and {tensor.ndim}"
        assert (
            tensor.device == device
        ), f"Requires same device of inputs, but got {device} and {tensor.device}"
        # All dims except the concat dim must agree with the first tensor.
        for d_idx, (size, base_size) in enumerate(zip(tensor.shape, shape)):
            assert (
                dim == d_idx or size == base_size
            ), f"Requires same dim sizes of dim {d_idx}, but got {size} and {base_size}"
        if tensor.numel() != 0:
            # The kernels index flat, contiguous storage.
            tensor = tensor.contiguous()
            valid_tensors.append(tensor.to(dtype) if tensor.dtype != dtype else tensor)

    tensor_num = len(valid_tensors)

    # Deal with special cases.
    if tensor_num == 1:
        return valid_tensors[0]

    # The output counts every input's cat-dim size (empties contribute 0).
    cat_dim_sizes = [t.shape[dim] for t in tensors]
    out_shape = list(tensors[0].shape)
    out_shape[dim] = sum(cat_dim_sizes)

    if tensor_num == 0:
        return torch.empty(out_shape, dtype=dtype, device=device)

    # Preprocess kernel parameters: high_num is the product of dims before
    # `dim`, low_num the product of dims after it; low_cat_accum holds the
    # prefix sums of each input's per-row footprint (cat_dim_size * low_num).
    high_num = int(math.prod(out_shape[:dim]))
    low_num = int(math.prod(out_shape[dim + 1 :]))
    out_cat_num = 0
    low_cat_accum = [0]

    for size in cat_dim_sizes:
        out_cat_num += size * low_num
        low_cat_accum.append(out_cat_num)

    # Launch kernel.
    if high_num == 1:
        # Vstack and concat result in the same storage arrangement when
        # high_num == 1.
        valid_tensors = [t.view(t.shape[dim], -1) for t in valid_tensors]
        return vstack(valid_tensors).view(out_shape)
    else:
        # Handle concat with an arbitrary number of inputs via the template
        # code generator.
        return CatKernelGenerator()(valid_tensors, dim, high_num, low_cat_accum)

337 return CatKernelGenerator()(valid_tensors, dim, high_num, low_cat_accum)