# src/flag_gems/runtime/backend/_kunlunxin/ops/pad.py

import importlib
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import cache_dir
from flag_gems.utils.code_utils import IndentedBuffer

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
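# This module builds, at call time, a Triton kernel plus two Python wrappers that
# implement torch.nn.functional.pad / constant_pad_nd for the Kunlunxin backend.
# The generated source is specialized to the input rank, written to the code cache,
# imported, and memoized by PadFunction below.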

# --------------------------- padding wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Generate the parameter list for the functional wrapper.
    Example: in0, pad, mode, value=0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("pad")
    parameters.append("mode")
    parameters.append("value=0")
    return ", ".join(parameters)


def parameter_for_wrapper_out() -> str:
    """Generate the parameter list for the destination-passing wrapper.
    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value=0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")
    parameters.append("dst_shape")
    parameters.append("pad_before")
    parameters.append("pad_after")
    parameters.append("mode")
    parameters.append("value=0")

    return ", ".join(parameters)


def parameter_ref_for_wrapper() -> str:
    """Generate the argument list used to call the destination-passing wrapper.
    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")
    parameters.append("dst_shape")
    parameters.append("pad_before")
    parameters.append("pad_after")
    parameters.append("mode")
    parameters.append("value")

    return ", ".join(parameters)


def output_ref_for_wrapper() -> str:
    return "out0"


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import math")
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("from triton import language as tl")
    code.newline()
    code.writeline("from flag_gems.utils.libentry import libentry")
    code.writeline("from flag_gems.runtime import torch_device_fn")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.writeline("from flag_gems.utils.type_utils import type_promotion")
    code.newline()
    code.newline()
    return code
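# Taken together, the three generators below emit a small Python module of roughly
# this shape (the names are the ones PadFunction.__call__ passes in further down):
#
#     def _wrapper(in0, pad, mode, value=0): ...
#         # converts `pad` into pad_before/pad_after, allocates out0, calls _wrapper_out
#     def _wrapper_out(in0, out0, dst_shape, pad_before, pad_after, mode, value): ...
#         # computes the launch grid and launches the Triton kernel
#     @libentry()
#     @triton.jit(do_not_specialize=["value"])
#     def _jit_function(...): ...
#         # the rank-specialized pad kernel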

def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("ndim = in0.ndim")
        code.writeline("pad_size = len(pad)")
        code.writeline("assert pad_size % 2 == 0")
        code.newline()
        code.writeline("pad_before = [0 for _ in range(ndim)]")
        code.writeline("pad_after = [0 for _ in range(ndim)]")
        code.newline()
        code.writeline("pad_pair = pad_size // 2 ")
        code.writeline("for i in range(pad_pair): ")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")
        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim): ")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")

        code.writeline(
            "out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)"
        )

        # call destination_passing_func
        output_names: str = output_ref_for_wrapper()
        call_str = (
            f"{output_names} = {destination_passing_func_name}"
            f"({parameter_ref_for_wrapper()})"
        )
        code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
        code.newline()
        code.newline()

    return code
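# Worked example of the pad -> (pad_before, pad_after, dst_shape) mapping produced by
# the wrapper above: for in0 of shape (2, 3) and pad = (1, 1, 2, 0) (the
# torch.nn.functional.pad convention, last dimension first):
#     pad_before = [2, 1], pad_after = [0, 1], dst_shape = [4, 5]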

def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # launch configuration
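        # e.g. out0.numel() == 1000 gives BLOCK_SIZE = next_power_of_2(ceil(1000 / 12))
        # = 128 and grid = (ceil(1000 / 128), 1, 1) = (8, 1, 1)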

        code.writeline("num_ctas = 12")
        code.writeline(
            "BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(out0.numel(), num_ctas))"
        )
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")

        # valid (non-padded) region bounds per output dimension
        if rank > 0:
            code.writeline("# valid (non-padded) region of the output, per dimension")
            for i in range(rank):
                code.writeline(f"valid_dim{i}_start = pad_before[{i}]")

                code.writeline(f"valid_dim{i}_end = dst_shape[{i}] - pad_after[{i}]")

            code.newline()

        code.writeline("IS_CONSTANT = mode == 'constant'")
        code.writeline("IS_REFLECT = mode == 'reflect'")
        code.writeline("IS_REPLICATE = mode == 'replicate'")
        code.writeline("IS_CIRCULAR = mode == 'circular'")

        code.newline()

        # kernel launch, with Kunlunxin Triton-XPU environment flags set around it
        code.writeline("# kernel launch")
        code.writeline("import os")
        code.writeline('os.environ["TRITONXPU_OTHER_SIM"] = "1"')
        code.writeline('os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"')
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"x_shape[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # shape for x")

                    s = ", ".join(f"in_strides0[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for x")

                    s = ", ".join(f"out_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out")

                    s = ", ".join(f"valid_dim{j}_start" for j in range(rank))
                    code.writeline(f"{s}, # valid dim start")

                    s = ", ".join(f"valid_dim{j}_end" for j in range(rank))
                    code.writeline(f"{s}, # valid dim end")

                code.writeline("in0.numel(), ")
                code.writeline("out0.numel(), ")
                code.writeline("value, ")
                code.writeline("IS_CONSTANT, ")
                code.writeline("IS_REFLECT, ")
                code.writeline("IS_REPLICATE, ")
                code.writeline("IS_CIRCULAR, ")
                code.writeline("BLOCK_SIZE, ")
                code.writeline("buffer_size_limit=512, ")
                code.writeline(")")

        code.writeline('if "TRITONXPU_OTHER_SIM" in os.environ: ')
        with code.indent():
            code.writeline('del os.environ["TRITONXPU_OTHER_SIM"]')

        code.writeline('if "TRITONXPU_STORE_MASK_SIM" in os.environ: ')
        with code.indent():
            code.writeline('del os.environ["TRITONXPU_STORE_MASK_SIM"]')

        code.writeline("return out0")
        code.newline()
        code.newline()

    return code
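# The kernel generated below is element-wise over the output tensor: each program
# instance covers BLOCK_SIZE flat output offsets, decomposes them into per-dimension
# output indices via the output strides, flags elements that fall in the padded
# region, remaps them to a source index according to the padding mode, and finishes
# with a masked load/store.
#
# Decomposition example for dst_shape = [4, 5] (contiguous, out_strides = (5, 1)):
#     offset 13 -> dst_index_0 = 13 // 5 = 2, remainder 3 -> dst_index_1 = 3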

def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")

        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        if rank > 0:
            # shape of the input
            shape_args = ", ".join(f"x_shape{j}: tl.constexpr" for j in range(rank))
            code.writeline(f"{shape_args}, # shape for x")

            # strides of the input
            stride_args = ", ".join(f"in_strides{j}: tl.constexpr" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for x")

            # strides of the output
            stride_args = ", ".join(
                f"out_strides{j}: tl.constexpr" for j in range(rank)
            )
            code.writeline(f"{stride_args}, # stride for out")

            # start of the valid (non-padded) region per dimension
            stride_args = ", ".join(
                f"valid_dim{j}_start: tl.constexpr" for j in range(rank)
            )
            code.writeline(f"{stride_args}, # valid dim start")

            # end of the valid (non-padded) region per dimension
            stride_args = ", ".join(
                f"valid_dim{j}_end: tl.constexpr" for j in range(rank)
            )
            code.writeline(f"{stride_args}, # valid dim end")

        code.writeline("in_elem_cnt: tl.constexpr, ")
        code.writeline("out_elem_cnt: tl.constexpr, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")

    code.writeline("):")

    with code.indent():
        code.writeline("pid = tle.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.newline()

        code.writeline("remaining = offset ")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()

        code.writeline("if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)")
        code.writeline("if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)")

        code.writeline(
            "cond = ((dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end))"
        )

        for i in range(1, rank):
            code.writeline(
                f"cond &= ((dst_index_{i} >= valid_dim{i}_start) & (dst_index_{i} < valid_dim{i}_end))"
            )

        code.writeline(
            "if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)"
        )

        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start ")

        for i in range(rank):
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )

        code.newline()
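        # Reflect mode mirrors indices about the first/last valid element without
        # repeating the edge. E.g. for x_shape = 5 and pad_before = 2 (valid output
        # region [2, 7)):
        #     dst 0 -> src 2 - 0 = 2,                      dst 1 -> src 1
        #     dst 7 -> src (5 + 2 - 1) * 2 - 7 - 2 = 3,    dst 8 -> src 2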

        code.writeline("if IS_REFLECT: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
                    valid_dim{i}_start - dst_index_{i}, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
                    (x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, src_index_{i})"""
                )

        code.newline()
        code.writeline("if IS_REPLICATE: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start, 0, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end, x_shape{i} - 1, src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_CIRCULAR: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
                    dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
                    dst_index_{i} - valid_dim{i}_end, src_index_{i})"""
                )

        code.newline()

        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")

        # start the bounds check from dimension 0 (the previous version reused the
        # leftover loop variable here, which skipped the dim-0 check for rank > 1)
        code.writeline("load_cond = src_index_0 < x_shape0")
        for i in range(1, rank):
            code.writeline(f"load_cond &= src_index_{i} < x_shape{i}")

        code.writeline("if IS_CONSTANT: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=((if_pad == 0) & load_cond), other=value)"
            )
        code.writeline("else: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)")

    return code

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    shape = inputs[0].shape
    rank = len(shape)

    # the only runtime-determined factor is the rank of the task space
    code = generate_imports(code)
    code = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_padding_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    code = generate_pad_kernel(rank, kernel_name, code)
    return code

class PadFunction:
    """Utility to generate the wrapper & JITFunction for the pad operation,
    specialized according to the rank of the input tensor.
    The generated code is written out to the cache directory (defaults to ~/.flaggems).
    """

    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # note: kwargs should not be used in JITFunction directly
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_wrapper",
                "_wrapper_out",
                "_jit_function",
                code,
            )

            file_name = f"constant_pad_rank_{key}_pid_{self.pid}.py"

            with open(cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, "_wrapper")
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank

_pad_func = PadFunction()
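# The functional entry point below follows torch.nn.functional.pad's calling
# convention: `pad` lists (left, right) sizes starting from the last dimension.
# For reflect mode each padding size must be smaller than the corresponding input
# dimension; for circular mode it must not exceed it (no wrapping more than once).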

def pad(self, pad, mode="constant", value=None):
    logger.debug("GEMS CONSTANT PAD ND")

    ndim = self.ndim

    if value is None:
        value = 0.0

    if mode == "reflect":
        ndim //= 2
        assert (
            len(pad) == 2 * ndim
        ), f"padding size is expected to be {2 * ndim}, but got {len(pad)}"

        for i in range(ndim):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_l, input_r = (
                self.shape[ndim - (2 * i + 1) - 1],
                self.shape[ndim - (2 * i + 1)],
            )
            assert pad_l < input_l and pad_r < input_r, (
                f"padding size should be less than the corresponding input dimension, "
                f"but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
            )

    if mode == "circular":
        ndim //= 2
        assert (
            len(pad) == 2 * ndim
        ), f"padding size is expected to be {2 * ndim}, but got {len(pad)}"
        for i in range(ndim):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - i - 1]
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."

    out = _pad_func(self, pad, mode, float(value))
    return out

def constant_pad_nd(self, pad, value=0):
    logger.debug("GEMS CONSTANT PAD ND")
    # `pad` here is the padding-size argument and shadows the module-level `pad`
    # function, so dispatch through the shared PadFunction instance directly.
    return _pad_func(self, pad, "constant", float(value))
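# Usage sketch (illustrative only; assumes this Kunlunxin backend is active and the
# tensor lives on the device string used by this backend, shown here as "cuda"):
#
#     x = torch.randn(2, 3, device="cuda")
#     y = pad(x, (1, 1, 2, 0), mode="constant", value=0.0)   # y.shape == (4, 5)
#
# The first call for a given input rank generates, caches, and imports the Triton
# wrapper; subsequent calls with the same rank reuse the cached overload.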