Coverage for src/flag_gems/runtime/backend/_cambricon/ops/flip.py: 0%

275 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import importlib 

2import logging 

3import os 

4from typing import Callable, Mapping 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer 

10 

# Module-level logger nested under the shared "flag_gems" root logger;
# lstrip(".") drops any leading dot from a relative-style __name__ so the
# child logger name stays well-formed.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

class FlipKernelCode(IndentedBuffer):
    """
    Flip kernel template.

    Generates, caches and invokes a JIT Triton kernel that reverses a tensor
    along a set of dimensions.  Adjacent dimensions with the same
    flip/no-flip status are first merged (see ``__init``) so the generated
    kernel only has to handle the reduced "merged" shape; the generated
    source is written to the flag_gems code cache directory and imported as
    a private module.
    """

    # Class-level cache shared by every instance: maps a
    # "<merge_dim_size>_<flip-flag-pattern>" key to the generated wrapper
    # callable, so each distinct flip pattern is code-generated only once
    # per process.
    overloads: Mapping[str, Callable] = {}

    def __init__(self):
        # The pid is baked into the generated file and module names so that
        # concurrent processes do not clobber each other's cache files.
        self.pid = os.getpid()
        # Alias to the shared class-level cache (intentionally shared state).
        self.cache = self.overloads
        self.kernel_name = "_flip_jit_kernel"
        self.wrapper_func_name = "_wrapper"
        super(FlipKernelCode, self).__init__()

    def __init(self, x, dims):
        """Initialize the flip kernel.

        Validates and normalizes ``dims``, then merges runs of adjacent
        dimensions sharing the same flip status into single "merged"
        dimensions, recording ``merge_shapes``, ``merge_strides``,
        ``merge_flip_dims_flags``, ``merge_flip_dim`` and ``merge_dim_size``
        on ``self``.
        """
        dim_size = x.dim()

        flip_dims = list(dims)
        flip_dims_flags = [False for _ in x.stride()]
        for i in range(len(flip_dims)):
            dim = flip_dims[i]
            assert (
                dim >= -dim_size and dim < dim_size
            ), "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -dim_size, dim_size - 1, dim
            )
            if dim < 0:
                flip_dims[i] = dim_size + dim
            # A still-negative ``dim`` indexes the intended flag anyway,
            # because Python list indexing wraps: flags[dim] is
            # flags[dim_size + dim] for a list of length dim_size.
            assert not flip_dims_flags[
                dim
            ], "dim {} appears multiple times in the list of dims".format(dim)
            flip_dims_flags[dim] = True

        # merge shapes and flip_dims_flags by flip flags.
        self.merge_shapes = []
        self.merge_strides = []
        flag = flip_dims_flags[0]
        self.merge_flip_dims_flags = []
        # Count of merged dimensions that are flipped.
        self.merge_flip_dim = 0
        shape = 1
        for i in range(dim_size):
            # Size-1 dims are flip-neutral, so they fold into the current run.
            if (flag == flip_dims_flags[i]) or x.shape[i] == 1:
                shape *= x.shape[i]
            else:
                self.merge_shapes.append(shape)
                # NOTE(review): the stride recorded for the run just closed
                # is taken at index i - 1 (the run's innermost dim); this
                # presumes a contiguous ``x`` — the public ``flip`` entry
                # point makes the input contiguous before calling here.
                self.merge_strides.append(x.stride(i - 1))
                self.merge_flip_dims_flags.append(flag)
                if flag:
                    self.merge_flip_dim += 1
                flag = flip_dims_flags[i]
                shape = x.shape[i]
        # Close the final run; the innermost merged stride is 1.
        self.merge_shapes.append(shape)
        self.merge_strides.append(1)
        self.merge_flip_dims_flags.append(flag)
        if flag:
            self.merge_flip_dim += 1

        self.merge_dim_size = len(self.merge_shapes)

    def __imports(self):
        """Generate imports for the kernel code.

        Emits the import header of the generated module, including the
        vendor-specific core-count and NRAM-size constants.
        """
        self.tpl(
            """
import math
import torch
import triton
from triton import language as tl

from flag_gems.utils import libentry
from flag_gems.runtime.backend import vendor_module
TOTAL_CORE_NUM = vendor_module.utils.TOTAL_CORE_NUM
MAX_NRAM_SIZE = vendor_module.utils.MAX_NRAM_SIZE


"""
        )

    def __wrapper(self):
        """Generate wrapper function for the kernel code.

        The emitted ``_wrapper`` derives the launch grid from the merged
        shapes: the x axis covers the "high" (outer) task count and the y
        axis splits the second-lowest merged dimension, with a special grid
        for the pure 1-d case.
        """
        self.newline()
        self.tpl(
            """
def {wrapper_name}(x, merge_shapes, merge_strides, merge_dim_size):
    if merge_dim_size == 0 or x.numel() <= 1:
        return x.clone()

    low_task = merge_shapes[merge_dim_size - 1]
    sub_dim = 1

    high_task = 1
    if merge_dim_size > 1:
        sub_dim = merge_shapes[merge_dim_size - 2]
        low_task *= sub_dim
    for i in range(merge_dim_size - 2):
        high_task *= merge_shapes[i]
    y = 1
    if high_task < TOTAL_CORE_NUM:
        for i in range(1, sub_dim + 1):
            if sub_dim % i == 0:
                y = i
                if y * high_task >= TOTAL_CORE_NUM:
                    break

    grid = lambda meta: (min(high_task, TOTAL_CORE_NUM), y, )

    # in case of one-dim.
    if (high_task == 1) and (y == 1) and (merge_dim_size == 1):
        if low_task <= 1024:
            grid = lambda meta: (1, 1, )
        else:
            grid = lambda meta: (1, TOTAL_CORE_NUM, )

    out = torch.empty_like(x)
    with torch.cuda.device(x.device):
        {kernel_name}[grid]({args})
    return out
""",
            wrapper_name=self.wrapper_func_name,
            kernel_name=self.kernel_name,
            args=self.__kernel_args(is_declare=False),
        )

    def __config(self):
        """Generate config for the kernel code.

        Emits the ``get_h_dim`` / ``get_w_dim`` heuristics (tile sizes
        derived from the two lowest merged shapes and a fixed working size
        of 3072 elements) plus the ``libentry`` / ``autotune`` /
        ``heuristics`` / ``jit`` decorator stack for the kernel.
        """
        # generate config key.
        merge_shapes_args_str = ", ".join(
            [f"'merge_shape_{i}'" for i in range(self.merge_dim_size)]
        )
        merge_strides_args_str = ", ".join(
            [f"'merge_stride_{i}'" for i in range(self.merge_dim_size)]
        )

        self.newline()
        # NOTE(review): the indentation inside this template was
        # reconstructed from a line-numbered dump — verify the nesting of
        # get_h_dim against the original generated output.
        self.tpl(
            """

def get_h_dim(args):
    merge_dim_size = args['merge_dim_size'];
    high = 0
    if merge_dim_size > 1:
        high = args['merge_shape_{merge_dim_size_2}']
        width = args['merge_shape_{merge_dim_size_1}']
        max_nram_size = 3072
        if max_nram_size >= width:
            tmp_h = max_nram_size // width
            if tmp_h < high:
                return tmp_h
            return high
    return 0

def get_w_dim(args):
    merge_dim_size = args['merge_dim_size'];
    width = args['merge_shape_{merge_dim_size_1}']
    max_nram_size = 3072
    if max_nram_size >= width:
        return width
    return max_nram_size

@libentry()
@triton.autotune(
    configs=[
        triton.Config({{}}, num_stages=3, num_warps=1),
    ],
    key = [{config_keys}],
)
@triton.heuristics(
    values={{
        "H_DIM": get_h_dim,
        "W_DIM": get_w_dim,
    }},
)
@triton.jit
""",
            merge_dim_size_2=str(self.merge_dim_size - 2),
            merge_dim_size_1=str(self.merge_dim_size - 1),
            config_keys=f"'x_ptr', {merge_shapes_args_str}, {merge_strides_args_str}",
        )

    def __kernel_flip_2d(self):
        """Generate kernel for 2d buffer flip.

        Emits the body that flips the lowest two merged dimensions.  The
        second-lowest dimension is split across ``num_y`` programs
        (``step`` rows each).  Two layouts are generated depending on which
        of the two dims is flipped; each layout has a 2-d tiled path
        (``H_DIM != 0``, a (H_DIM, W_DIM) tile flip) and a row-by-row 1-d
        fallback path.
        """
        self.writeline(f"step = merge_shape_{self.merge_dim_size - 2} // num_y")
        self.writeline(
            f"src_offset += pid_y * step * merge_shape_{self.merge_dim_size - 1}"
        )
        if self.merge_flip_dims_flags[self.merge_dim_size - 2]:
            # [flip, no-flip]
            self.writeline("# flip low-2d [flip, no-flip]")
            # Row blocks are written back in reverse program order.
            self.writeline(
                f"dst_offset += (num_y - pid_y - 1) * step * merge_shape_{self.merge_dim_size - 1}"
            )
            # H_DIM != 0 means a full (H_DIM, W_DIM) tile fits the budget.
            self.writeline("if H_DIM != 0:")
            with self.indent():
                self.writeline(
                    "offset = tl.arange(0, H_DIM)[:,None]*W_DIM + tl.arange(0, W_DIM)[None,:]"
                )
                self.writeline("tail = step % H_DIM")
                self.writeline("iter = step // H_DIM")
                self.writeline("for i in range(0, iter):")
                with self.indent():
                    self.writeline("in_offset = src_offset + i * H_DIM*W_DIM")
                    self.writeline(
                        "out_offset = dst_offset + tail * W_DIM + (iter - i - 1) * H_DIM*W_DIM"
                    )
                    self.writeline(
                        "src = tl.load(x_ptr + offset + in_offset, cache_modifier='.cg')"
                    )
                    # Flip along rows (axis 0) for the [flip, no-flip] case.
                    self.writeline("src = tl.flip(src, [0])")
                    self.writeline(
                        "tl.store(out_ptr + offset + out_offset, src, cache_modifier='.cg')"
                    )
                self.writeline("if tail > 0:")
                with self.indent():
                    self.writeline("# process tail.")
                    self.writeline("in_offset = src_offset + iter * H_DIM*W_DIM")
                    self.writeline("out_offset = dst_offset - (H_DIM-tail)*W_DIM")
                    self.writeline("mask = offset < tail*W_DIM")
                    self.writeline(
                        "src = tl.load(x_ptr + offset + in_offset, mask=mask, other=0.0, cache_modifier='.cg')"
                    )
                    self.writeline("src = tl.flip(src, [0])")
                    # After the flip the valid rows sit at the bottom of the
                    # tile, hence the shifted store mask.
                    self.writeline("mask = offset >= (H_DIM - tail) * W_DIM")
                    self.writeline(
                        "tl.store(out_ptr + offset + out_offset, src, mask=mask, cache_modifier='.cg')"
                    )
            self.writeline("else:")
            with self.indent():
                # Fallback: one W_DIM-wide row chunk at a time, rows copied
                # to mirrored positions (no per-element flip needed).
                self.writeline("offset = tl.arange(0, W_DIM)")
                self.writeline(f"iter = merge_shape_{self.merge_dim_size - 1} // W_DIM")
                self.writeline(f"tail = merge_shape_{self.merge_dim_size - 1} % W_DIM")
                self.writeline("src = tl.zeros((W_DIM,), dtype=x_ptr.dtype.element_ty)")
                self.writeline("for i in range(0, step):")
                with self.indent():
                    self.writeline(
                        f"in_offset = src_offset + i * merge_shape_{self.merge_dim_size - 1}"
                    )
                    self.writeline(
                        f"out_offset = dst_offset + (step - i - 1) * merge_shape_{self.merge_dim_size - 1}"
                    )
                    self.writeline("for j in range(0, iter):")
                    with self.indent():
                        self.writeline("new_offset = offset + j*W_DIM")
                        self.writeline(
                            "src = tl.load(x_ptr + in_offset + new_offset, cache_modifier='.cg')"
                        )
                        self.writeline(
                            "tl.store(out_ptr + out_offset + new_offset, src, cache_modifier='.cg')"
                        )
                    self.writeline("if tail > 0:")
                    with self.indent():
                        self.writeline("new_offset = offset + iter*W_DIM")
                        self.writeline("mask = offset < tail")
                        self.writeline(
                            "src = tl.load(x_ptr + in_offset + new_offset, mask=mask, cache_modifier='.cg')"
                        )
                        self.writeline(
                            "tl.store(out_ptr + out_offset + new_offset, src, mask=mask, cache_modifier='.cg')"
                        )
        else:
            # [no-flip, flip]
            self.writeline("# flip low-2d [no-flip, flip]")
            # Rows stay in place; only the innermost dim is reversed.
            self.writeline(
                f"dst_offset += pid_y * step * merge_shape_{self.merge_dim_size - 1}"
            )
            self.writeline("if H_DIM != 0:")
            with self.indent():
                self.writeline(
                    "offset = tl.arange(0, H_DIM)[:,None]*W_DIM + tl.arange(0, W_DIM)[None,:]"
                )
                self.writeline("tail = step % H_DIM")
                self.writeline("iter = step // H_DIM")
                self.writeline("for i in range(0, iter):")
                with self.indent():
                    self.writeline("in_offset = src_offset + i * H_DIM*W_DIM")
                    self.writeline("out_offset = dst_offset + i * H_DIM*W_DIM")
                    self.writeline(
                        "src = tl.load(x_ptr + offset + in_offset, cache_modifier='.cg')"
                    )
                    # Flip along columns (axis 1) for the [no-flip, flip] case.
                    self.writeline("src = tl.flip(src, [1])")
                    self.writeline(
                        "tl.store(out_ptr + offset + out_offset, src, cache_modifier='.cg')"
                    )
                self.writeline("if tail > 0:")
                with self.indent():
                    self.writeline("# process tail.")
                    self.writeline("in_offset = src_offset + iter * H_DIM*W_DIM")
                    self.writeline("out_offset = dst_offset + iter * H_DIM*W_DIM")
                    self.writeline("mask = offset < tail*W_DIM")
                    self.writeline(
                        "src = tl.load(x_ptr + offset + in_offset, mask=mask, other=0.0, cache_modifier='.cg')"
                    )
                    self.writeline("src = tl.flip(src, [1])")
                    self.writeline(
                        "tl.store(out_ptr + offset + out_offset, src, mask=mask, cache_modifier='.cg')"
                    )
            self.writeline("else:")
            with self.indent():
                # Fallback: flip each row chunk of W_DIM elements and write
                # chunks back in reversed order within the row.
                self.writeline("offset = tl.arange(0, W_DIM)")
                self.writeline("src = tl.zeros((W_DIM,), dtype=x_ptr.dtype.element_ty)")
                self.writeline(f"tail = merge_shape_{self.merge_dim_size - 1} % W_DIM")
                self.writeline(f"iter = merge_shape_{self.merge_dim_size - 1} // W_DIM")
                self.writeline("for i in range(0, step):")
                with self.indent():
                    self.writeline(
                        f"in_offset = src_offset + i * merge_shape_{self.merge_dim_size - 1}"
                    )
                    self.writeline(
                        f"out_offset = dst_offset + i * merge_shape_{self.merge_dim_size - 1}"
                    )
                    self.writeline("if tail > 0:")
                    with self.indent():
                        self.writeline("new_offset = in_offset + iter * W_DIM")
                        self.writeline("mask = offset < tail")
                        self.writeline(
                            "src = tl.load(x_ptr + new_offset + offset, mask=mask, cache_modifier='.cg')"
                        )
                        self.writeline("src = tl.flip(src, [0])")
                        self.writeline("mask = offset >= (W_DIM-tail)")
                        self.writeline(
                            "tl.store(out_ptr + out_offset - (W_DIM - tail) + offset, \
                            src, mask=mask, cache_modifier='.cg')"
                        )
                    self.writeline("for j in range(0, iter):")
                    with self.indent():
                        self.writeline("new_in_offset = in_offset + j * W_DIM")
                        self.writeline(
                            "new_out_offset = tail + out_offset + (iter - j - 1) * W_DIM"
                        )
                        self.writeline(
                            "src = tl.load(x_ptr + new_in_offset + offset, cache_modifier='.cg')"
                        )
                        self.writeline("src = tl.flip(src, [0])")
                        self.writeline(
                            "tl.store(out_ptr + new_out_offset + offset, src, cache_modifier='.cg')"
                        )

    def __kernel(self):
        """Generate kernel code body.

        Emits the decorated kernel signature, the outer loop over the
        "high" (outer merged) dimensions that computes src/dst offsets, and
        dispatches to the 2-d flip body or the dedicated 1-d path depending
        on ``merge_dim_size``.
        """
        # configuration.
        self.__config()
        kernel_signature = f"def {self.kernel_name}({self.__kernel_args()}):"
        self.writeline(kernel_signature)
        with self.indent():
            self.writeline("pid_x = tl.program_id(0)")
            self.writeline("num_x = tl.num_programs(0)")
            self.writeline("pid_y = tl.program_id(1)")
            self.writeline("num_y = tl.num_programs(1)")
            # iteration on high dimension.
            self.writeline("for high_id in range(pid_x, high_task, num_x):")
            with self.indent():
                self.writeline("src_offset = 0")
                self.writeline("dst_offset = 0")
                self.writeline("temp_high_id = high_id")
                # get src_offset and dst offset
                if self.merge_dim_size > 2:
                    # Decompose high_id into per-dim indices; flipped outer
                    # dims mirror their destination index.
                    for i in range(self.merge_dim_size - 2):
                        self.writeline(f"tmp_stride = merge_stride_{i} // low_task")
                        self.writeline(f"id_{i} = temp_high_id // tmp_stride")
                        self.writeline("temp_high_id = temp_high_id % tmp_stride")
                        self.writeline(f"src_offset += id_{i} * merge_stride_{i}")
                        if not self.merge_flip_dims_flags[i]:
                            self.writeline(f"dst_offset += id_{i} * merge_stride_{i}")
                        else:
                            self.writeline(
                                f"dst_offset += (merge_shape_{i} - id_{i} -1) * merge_stride_{i}"
                            )
                    self.__kernel_flip_2d()
                elif self.merge_dim_size == 2:
                    self.__kernel_flip_2d()
                elif self.merge_dim_size == 1:
                    # Pure 1-d flip: the single merged dim must be flipped
                    # (otherwise __call__ returns a clone without codegen).
                    assert self.merge_flip_dims_flags[0]
                    self.writeline("offset = tl.arange(0, W_DIM)")
                    self.writeline(
                        f"step = merge_shape_{self.merge_dim_size - 1} // num_y"
                    )
                    self.writeline(
                        f"tail = merge_shape_{self.merge_dim_size - 1} % num_y"
                    )
                    self.writeline("# process step.")
                    self.writeline("src_offset = pid_y * step")
                    self.writeline("dst_offset = tail + (num_y - pid_y - 1) * step")
                    self.writeline("step_iter = step // W_DIM")
                    self.writeline("step_tail = step % W_DIM")
                    self.writeline("for i in range(0, step_iter):")
                    with self.indent():
                        self.writeline("in_offset = src_offset + i * W_DIM")
                        self.writeline(
                            "out_offset = dst_offset + step_tail + (step_iter - i - 1) * W_DIM"
                        )
                        self.writeline(
                            "src = tl.load(x_ptr + offset + in_offset, cache_modifier='.cg')"
                        )
                        self.writeline("src = tl.flip(src, [0])")
                        self.writeline(
                            "tl.store(out_ptr + offset + out_offset, src, cache_modifier='.cg')"
                        )
                    self.writeline("if step_tail > 0:")
                    with self.indent():
                        self.writeline("in_offset = src_offset + step_iter * W_DIM")
                        self.writeline("out_offset = dst_offset")
                        self.writeline("mask = offset < step_tail")
                        self.writeline(
                            "src = tl.load(x_ptr + offset + in_offset, mask=mask, cache_modifier='.cg')"
                        )
                        self.writeline("src = tl.flip(src, [0])")
                        self.writeline("mask = offset >= (W_DIM - step_tail)")
                        self.writeline(
                            "tl.store(out_ptr + offset + out_offset - (W_DIM - step_tail), \
                            src, mask=mask, cache_modifier='.cg')"
                        )
                    # The last program also handles the global tail elements.
                    self.writeline("if pid_y == num_y - 1:")
                    with self.indent():
                        self.writeline("# process tail.")
                        self.writeline("src_offset = num_y * step")
                        self.writeline("dst_offset = 0")
                        self.writeline("tail_iter = tail // W_DIM")
                        self.writeline("tail_remain = tail % W_DIM")
                        self.writeline("for i in range(0, tail_iter):")
                        with self.indent():
                            self.writeline("in_offset = src_offset + i * W_DIM")
                            self.writeline(
                                "out_offset = dst_offset + tail_remain + (tail_iter - i - 1) * W_DIM"
                            )
                            self.writeline(
                                "src = tl.load(x_ptr + offset + in_offset, cache_modifier='.cg')"
                            )
                            self.writeline("src = tl.flip(src, [0])")
                            self.writeline(
                                "tl.store(out_ptr + offset + out_offset, src, cache_modifier='.cg')"
                            )
                        self.writeline("if tail_remain > 0:")
                        with self.indent():
                            self.writeline("in_offset = src_offset + tail_iter * W_DIM")
                            self.writeline("out_offset = dst_offset")
                            self.writeline("mask = offset < tail_remain")
                            self.writeline(
                                "src = tl.load(x_ptr + offset + in_offset, mask=mask, cache_modifier='.cg')"
                            )
                            self.writeline("src = tl.flip(src, [0])")
                            self.writeline("mask = offset >= (W_DIM-tail_remain)")
                            self.writeline(
                                "tl.store(out_ptr + offset + out_offset - (W_DIM - tail_remain), \
                                src, mask=mask, cache_modifier='.cg')"
                            )
                else:
                    raise RuntimeError(f"merge dim size error({self.merge_dim_size})")

    def __gen_code(self):
        """Entry point for code generation of flip."""
        # generate imports.
        self.__imports()
        # generate wrapper function.
        self.__wrapper()

        # generate kernel.
        self.__kernel()

    def __kernel_args(self, is_declare=True):
        """Generate string type of jit kernel arguments.

        With ``is_declare=True`` returns the kernel's parameter list
        (including tl.constexpr annotations); otherwise returns the
        call-site argument string used inside the generated wrapper.
        """
        merge_shapes_args = []
        merge_strides_args = []
        for i in range(self.merge_dim_size):
            if is_declare:
                merge_shapes_args.append(f"merge_shape_{i}")
                merge_strides_args.append(f"merge_stride_{i}")
            else:
                merge_shapes_args.append(f"merge_shapes[{i}]")
                merge_strides_args.append(f"merge_strides[{i}]")
        merge_shapes_args_str = ", ".join(merge_shapes_args)
        merge_strides_args_str = ", ".join(merge_strides_args)

        extra_args_str = f"{merge_shapes_args_str}, {merge_strides_args_str}"
        if is_declare:
            return f"x_ptr, out_ptr, {extra_args_str}, merge_dim_size, high_task: tl.constexpr, \
                low_task: tl.constexpr, H_DIM: tl.constexpr, W_DIM: tl.constexpr"
        else:
            return f"x, out, {extra_args_str}, merge_dim_size, high_task, low_task"

    def __call__(self, x: torch.Tensor, dims) -> torch.Tensor:
        """Call flip kernel.

        Generates (or fetches from the shared cache) the wrapper for this
        flip pattern and invokes it.  Returns a clone when nothing needs
        flipping.
        """
        # initialize the function.
        # note:
        # - This function must be called first and only once.
        self.__init(x, dims)
        if (self.merge_flip_dim == 0) or (self.merge_dim_size == 0 or x.numel() <= 1):
            return x.clone()
        # get overload kernel.
        flip_dim_str = "_".join([str(i) for i in self.merge_flip_dims_flags])
        self.kernel_name = self.kernel_name + "_flip_" + flip_dim_str
        key = f"{self.merge_dim_size}_{flip_dim_str}"
        if key not in self.cache:
            # generate code and cache.
            self.__gen_code()

            # Write the generated module to the flag_gems code cache, then
            # import it as an anonymous module (exec of our own generated
            # code, not untrusted input).
            file_name = f"flip_{key}_pid_{self.pid}.py"
            with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(self.getvalue())
            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_{key}_pid_{self.pid}", f.name
            )
            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, self.wrapper_func_name)
            self.cache[key] = overload

        overload = self.cache[key]
        return overload(x, self.merge_shapes, self.merge_strides, self.merge_dim_size)

524 

525 

def flip(A: torch.Tensor, dims) -> torch.Tensor:
    """Return a copy of ``A`` reversed along the dimensions in ``dims``.

    Delegates to a freshly constructed :class:`FlipKernelCode`, which
    generates (and caches per flip pattern) a Triton kernel for the merged
    shape of the input.
    """
    logger.debug("GEMS_CAMBRICON FLIP")
    # The generated kernel works on a contiguous layout.
    tensor = A if A.is_contiguous() else A.contiguous()
    return FlipKernelCode()(tensor, dims)