Coverage for src/flag_gems/ops/pad.py: 99%

279 statements  

coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

import importlib
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic

logger = logging.getLogger(__name__)


# --------------------------- padding wrapper generation -----------------------------------

def parameter_for_wrapper() -> str:
    """Generate the parameter list for the functional wrapper.
    Example: in0, pad, mode, value=0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("pad")
    parameters.append("mode")
    parameters.append("value=0")
    return ", ".join(parameters)


def parameter_for_wrapper_out() -> str:
    """Generate the parameter list for the destination-passing wrapper.
    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value=0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")
    parameters.append("dst_shape")
    parameters.append("pad_before")
    parameters.append("pad_after")
    parameters.append("mode")
    parameters.append("value=0")

    return ", ".join(parameters)


def parameter_ref_for_wrapper() -> str:
    """Generate the argument list used to call the destination-passing wrapper.
    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")
    parameters.append("dst_shape")
    parameters.append("pad_before")
    parameters.append("pad_after")
    parameters.append("mode")
    parameters.append("value")

    return ", ".join(parameters)


def output_ref_for_wrapper() -> str:
    return "out0"


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import math")
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("from triton import language as tl")
    code.newline()
    code.writeline("from flag_gems.utils.libentry import libentry")
    code.writeline("from flag_gems.runtime import torch_device_fn")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.writeline("from flag_gems.utils.type_utils import type_promotion")
    code.newline()
    code.newline()
    return code


def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("ndim = in0.ndim")
        code.writeline("pad_size = len(pad)")
        code.writeline("assert pad_size % 2 == 0")
        code.newline()
        code.writeline("pad_before = [0 for _ in range(ndim)]")
        code.writeline("pad_after = [0 for _ in range(ndim)]")
        code.newline()
        code.writeline("pad_pair = pad_size // 2 ")
        code.writeline("for i in range(pad_pair): ")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")
        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim): ")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")

        code.writeline("out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)")

        # call destination_passing_func
        output_names: str = output_ref_for_wrapper()
        call_str = (
            f"{output_names} = {destination_passing_func_name}"
            f"({parameter_ref_for_wrapper()})"
        )
        code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
        code.newline()
        code.newline()

    return code
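
# For illustration only (not part of this module): with wrapper_name="_pad_wrapper" and
# destination_passing_func_name="_pad_wrapper_out" (the names PadFunction uses below),
# generate_functional_padding_wrapper emits roughly the following. As with
# torch.nn.functional.pad, the flat `pad` tuple is interpreted last-dimension-first.
#
#   def _pad_wrapper(in0, pad, mode, value=0):
#       ndim = in0.ndim
#       pad_size = len(pad)
#       assert pad_size % 2 == 0
#
#       pad_before = [0 for _ in range(ndim)]
#       pad_after = [0 for _ in range(ndim)]
#
#       pad_pair = pad_size // 2
#       for i in range(pad_pair):
#           pad_before[ndim - i - 1] = pad[2 * i]
#           pad_after[ndim - i - 1] = pad[2 * i + 1]
#       dst_shape = list(in0.shape)
#       for i in range(ndim):
#           dst_shape[i] += pad_before[i] + pad_after[i]
#       out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)
#       out0 = _pad_wrapper_out(in0, out0, dst_shape, pad_before, pad_after, mode, value)
#       return out0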


def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # launch configuration
        code.writeline("BLOCK_SIZE = 256")
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")

        # valid (non-padded) region boundaries w.r.t. the task index space
        if rank > 0:
            code.writeline("# valid (non-padded) region along each dimension")
            for i in range(rank):
                code.writeline(f"valid_dim{i}_start = pad_before[{i}]")

                code.writeline(f"valid_dim{i}_end = dst_shape[{i}] - pad_after[{i}]")

            code.newline()

            code.writeline("# Check which dimensions have padding")
            for i in range(rank):
                code.writeline(
                    f"dim{i}_has_pad = pad_before[{i}] > 0 or pad_after[{i}] > 0"
                )
        code.writeline("IS_CONSTANT = mode == 'constant'")
        code.writeline("IS_REFLECT = mode == 'reflect'")
        code.writeline("IS_REPLICATE = mode == 'replicate'")
        code.writeline("IS_CIRCULAR = mode == 'circular'")

        code.newline()

        code.writeline("# kernel launch")

        # launch kernel
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"x_shape[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # shape for x")

                    s = ", ".join(f"in_strides0[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for x")

                    s = ", ".join(f"out_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out")

                    s = ", ".join(f"valid_dim{j}_start" for j in range(rank))
                    code.writeline(f"{s}, # valid dim start")

                    s = ", ".join(f"valid_dim{j}_end" for j in range(rank))
                    code.writeline(f"{s}, # valid dim end")

                    s = ", ".join(f"bool(dim{i}_has_pad)" for i in range(rank))
                    code.writeline(f"{s}, # dim has padding flags")

                code.writeline("in0.numel(), ")
                code.writeline("out0.numel(), ")
                code.writeline("value, ")
                code.writeline("IS_CONSTANT, ")
                code.writeline("IS_REFLECT, ")
                code.writeline("IS_REPLICATE, ")
                code.writeline("IS_CIRCULAR, ")
                code.writeline("BLOCK_SIZE, ")
            code.writeline(")")

        code.writeline("return out0")
        code.newline()
        code.newline()
    return code
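
# For illustration only (not part of this module): for a rank-1 input,
# generate_destination_passing_padding_wrapper emits roughly the following
# (again using the names PadFunction passes in below).
#
#   def _pad_wrapper_out(in0, out0, dst_shape, pad_before, pad_after, mode, value=0):
#       BLOCK_SIZE = 256
#       grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)
#
#       x_shape = in0.shape
#       in_strides0 = in0.stride()
#       out_strides = out0.stride()
#       # valid (non-padded) region along each dimension
#       valid_dim0_start = pad_before[0]
#       valid_dim0_end = dst_shape[0] - pad_after[0]
#
#       # Check which dimensions have padding
#       dim0_has_pad = pad_before[0] > 0 or pad_after[0] > 0
#       IS_CONSTANT = mode == 'constant'
#       IS_REFLECT = mode == 'reflect'
#       IS_REPLICATE = mode == 'replicate'
#       IS_CIRCULAR = mode == 'circular'
#
#       # kernel launch
#       with torch_device_fn.device(in0.device):
#           _pad_jit_function[grid](
#               in0, out0,
#               x_shape[0], # shape for x
#               in_strides0[0], # stride for x
#               out_strides[0], # stride for out
#               valid_dim0_start, # valid dim start
#               valid_dim0_end, # valid dim end
#               bool(dim0_has_pad), # dim has padding flags
#               in0.numel(),
#               out0.numel(),
#               value,
#               IS_CONSTANT,
#               IS_REFLECT,
#               IS_REPLICATE,
#               IS_CIRCULAR,
#               BLOCK_SIZE,
#           )
#       return out0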


def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # make the inlined function visible in the context
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")

        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        if rank > 0:
            # shape of the input
            shape_args = ", ".join(f"x_shape{j}: int" for j in range(rank))
            code.writeline(f"{shape_args}, # shape for x")

            # strides of the input
            stride_args = ", ".join(f"in_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for x")

            # strides of the output
            stride_args = ", ".join(f"out_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for out")

            # start of the valid (non-padded) region per dimension
            stride_args = ", ".join(f"valid_dim{j}_start: int" for j in range(rank))
            code.writeline(f"{stride_args}, # valid dim start")

            # end of the valid (non-padded) region per dimension
            stride_args = ", ".join(f"valid_dim{j}_end: int" for j in range(rank))
            code.writeline(f"{stride_args}, # valid dim end")

            for i in range(rank):
                code.writeline(f"dim{i}_has_pad: tl.constexpr, ")

        code.writeline("in_elem_cnt: tl.constexpr, ")
        code.writeline("out_elem_cnt: tl.constexpr, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")

    code.writeline("):")

    with code.indent():
        code.writeline("pid = tle.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.newline()

        # unravel the flat output offset into per-dimension destination indices
        code.writeline("remaining = offset ")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()

        code.writeline("if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)")
        code.writeline("if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)")

        code.writeline(
            "cond = ((dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end))"
        )

        for i in range(1, rank):
            code.writeline(
                f"cond &= ((dst_index_{i} >= valid_dim{i}_start) & (dst_index_{i} < valid_dim{i}_end))"
            )

        code.writeline(
            "if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)"
        )

        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start ")

        for i in range(rank):
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )

        code.newline()
        code.writeline("if IS_REFLECT: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start),
                    valid_dim{i}_start - dst_index_{i}, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end),
                    (x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, src_index_{i})"""
                )

        code.newline()
        code.writeline("if IS_REPLICATE: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), 0, src_index_{i})"
                )
            for i in range(rank):
                end_cond = f"dst_index_{i} >= valid_dim{i}_end"
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & ({end_cond}), "
                    f"x_shape{i} - 1, src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_CIRCULAR: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start),
                    dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end),
                    dst_index_{i} - valid_dim{i}_end, src_index_{i})"""
                )

        code.newline()

        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")

        code.writeline("load_cond = src_index_0 < x_shape0")
        for i in range(1, rank):
            code.writeline(f"load_cond &= src_index_{i} < x_shape{i}")

        code.writeline("if IS_CONSTANT: ")
        with code.indent():
            # use explicit comparison and bitwise-and for non-scalar masks
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=((if_pad == 0) & load_cond), other=value)"
            )
        code.writeline("else: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)")

    return code
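
# Worked example (illustrative, not part of the module): for a 1-D input [a, b, c, d]
# (x_shape0 = 4) padded with pad_before = 2 and pad_after = 2, the generated kernel
# sees valid_dim0_start = 2 and valid_dim0_end = 6 and maps destination indices
# 0..7 back to source indices as follows:
#
#   reflect:   src = valid_start - dst                                    (dst < valid_start)
#              src = (x_shape + valid_start - 1) * 2 - dst - valid_start  (dst >= valid_end)
#              dst = [0 1 2 3 4 5 6 7] -> src = [2 1 0 1 2 3 2 1]  ->  [c b a b c d c b]
#   replicate: clamp to 0 on the left and x_shape - 1 on the right
#              dst = [0 1 2 3 4 5 6 7] -> src = [0 0 0 1 2 3 3 3]  ->  [a a a b c d d d]
#   circular:  src = dst + x_shape - valid_start (left),  src = dst - valid_end (right)
#              dst = [0 1 2 3 4 5 6 7] -> src = [2 3 0 1 2 3 0 1]  ->  [c d a b c d a b]
#
# In constant mode the padded positions are never read from memory: the load mask
# (if_pad == 0) & load_cond makes tl.load substitute `value` for them instead.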


def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    shape = inputs[0].shape
    rank = len(shape)

    # the only runtime determined factor is the rank of the task space
    code = generate_imports(code)
    code = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_padding_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    code = generate_pad_kernel(rank, kernel_name, code)
    return code


class PadFunction:
    """Utility to generate a padding function. It generates a wrapper & a JITFunction,
    which are specialized according to the rank of the task space (the rank of the
    input tensor). The generated code is written out to the cache directory
    (defaults to ~/.flaggems).
    """

    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # note: kwargs should not be used in JITFunction directly
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_pad_wrapper",
                "_pad_wrapper_out",
                "_pad_jit_function",
                code,
            )

            file_name = f"constant_pad_rank_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, "_pad_wrapper")
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank


_pad_func = PadFunction()
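
# Illustrative sketch (assumptions: a CUDA tensor and direct use of the module-level
# _pad_func; not part of the module): PadFunction caches one generated overload per
# input rank.
#
#   x = torch.randn(4, 5, device="cuda")
#   y = _pad_func(x, (1, 1, 2, 0), "constant", 0.0)
#   # The first rank-2 call writes constant_pad_rank_2.py into the code cache
#   # directory, imports it, and stores its _pad_wrapper under key "2"; subsequent
#   # rank-2 calls reuse the cached overload without regenerating code.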


def pad(self, pad, mode="constant", value=None):
    logger.debug("GEMS CONSTANT PAD ND")

    ndim = self.ndim

    if value is None:
        value = 0.0

    pad_pairs = len(pad) // 2

    if mode == "reflect":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert pad_l < input_size and pad_r < input_size, (
                "padding size should be less than the corresponding input dimension, "
                f"but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
            )

    if mode == "circular":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."

    out = _pad_func(self, pad, mode, float(value))
    return out


def constant_pad_nd(self, pad_list, value=0):
    return pad(self, pad_list, mode="constant", value=value)
488 return pad(self, pad_list, mode="constant", value=value)