Coverage for src/flag_gems/ops/pad.py: 99%

274 statements  

coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

import importlib
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic

logger = logging.getLogger(__name__)

# --------------------------- padding wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Generate the parameter list for the functional wrapper.
    Example: in0, pad, mode, value=0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("pad")
    parameters.append("mode")
    parameters.append("value=0")
    return ", ".join(parameters)

def parameter_for_wrapper_out() -> str:
    """Generate the parameter list for the destination-passing wrapper.
    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value=0
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")
    parameters.append("dst_shape")
    parameters.append("pad_before")
    parameters.append("pad_after")
    parameters.append("mode")
    parameters.append("value=0")

    return ", ".join(parameters)

def parameter_ref_for_wrapper() -> str:
    """Generate the argument list used when calling the destination-passing wrapper.
    Example: in0, out0, dst_shape, pad_before, pad_after, mode, value
    """
    parameters: List[str] = []

    parameters.append("in0")
    parameters.append("out0")
    parameters.append("dst_shape")
    parameters.append("pad_before")
    parameters.append("pad_after")
    parameters.append("mode")
    parameters.append("value")

    return ", ".join(parameters)

def output_ref_for_wrapper() -> str:
    return "out0"


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import math")
    code.writeline("import torch")
    code.writeline("import triton")
    code.writeline("from triton import language as tl")
    code.newline()
    code.writeline("from flag_gems.utils.libentry import libentry")
    code.writeline("from flag_gems.runtime import torch_device_fn")
    code.writeline("from flag_gems.utils import triton_lang_extension as tle")
    code.writeline("from flag_gems.utils.type_utils import type_promotion")
    code.newline()
    code.newline()
    return code

def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("ndim = in0.ndim")
        code.writeline("pad_size = len(pad)")
        code.writeline("assert pad_size % 2 == 0")
        code.newline()
        code.writeline("pad_before = [0 for _ in range(ndim)]")
        code.writeline("pad_after = [0 for _ in range(ndim)]")
        code.newline()
        code.writeline("pad_pair = pad_size // 2")
        code.writeline("for i in range(pad_pair):")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")
        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim):")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")

        code.writeline(
            "out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)"
        )

        # call destination_passing_func
        output_names: str = output_ref_for_wrapper()
        call_str = (
            f"{output_names} = {destination_passing_func_name}"
            f"({parameter_ref_for_wrapper()})"
        )
        code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
        code.newline()
        code.newline()

    return code
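
# A rough sketch of the wrapper this emits (names taken from PadFunction below,
# rank 2 assumed). Note that the torch-style pad tuple is consumed from the last
# dimension backwards, e.g. pad=(1, 2, 3, 4) on shape (H, W) yields
# pad_before=[3, 1] and pad_after=[4, 2]:
#
#     def _pad_wrapper(in0, pad, mode, value=0):
#         ndim = in0.ndim
#         pad_size = len(pad)
#         assert pad_size % 2 == 0
#         pad_before = [0 for _ in range(ndim)]
#         pad_after = [0 for _ in range(ndim)]
#         pad_pair = pad_size // 2
#         for i in range(pad_pair):
#             pad_before[ndim - i - 1] = pad[2 * i]
#             pad_after[ndim - i - 1] = pad[2 * i + 1]
#         dst_shape = list(in0.shape)
#         for i in range(ndim):
#             dst_shape[i] += pad_before[i] + pad_after[i]
#         out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)
#         out0 = _pad_wrapper_out(in0, out0, dst_shape, pad_before, pad_after, mode, value)
#         return out0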


def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("BLOCK_SIZE = 256")
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")

        # bounds of the valid (non-padded) region along each output dimension
        if rank > 0:
            code.writeline("# valid (non-padded) region bounds w.r.t the task space")
            for i in range(rank):
                code.writeline(f"valid_dim{i}_start = pad_before[{i}]")

                code.writeline(f"valid_dim{i}_end = dst_shape[{i}] - pad_after[{i}]")

        code.newline()

        code.writeline("IS_CONSTANT = mode == 'constant'")
        code.writeline("IS_REFLECT = mode == 'reflect'")
        code.writeline("IS_REPLICATE = mode == 'replicate'")
        code.writeline("IS_CIRCULAR = mode == 'circular'")

        code.newline()

        code.writeline("# kernel launch")

        # launch kernel
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"x_shape[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # shape for x")

                    s = ", ".join(f"in_strides0[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for x")

                    s = ", ".join(f"out_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out")

                    s = ", ".join(f"valid_dim{j}_start" for j in range(rank))
                    code.writeline(f"{s}, # valid dim start")

                    s = ", ".join(f"valid_dim{j}_end" for j in range(rank))
                    code.writeline(f"{s}, # valid dim end")

                code.writeline("in0.numel(), ")
                code.writeline("out0.numel(), ")
                code.writeline("value, ")
                code.writeline("IS_CONSTANT, ")
                code.writeline("IS_REFLECT, ")
                code.writeline("IS_REPLICATE, ")
                code.writeline("IS_CIRCULAR, ")
                code.writeline("BLOCK_SIZE, ")
            code.writeline(")")

        code.writeline("return out0")
        code.newline()
        code.newline()
    return code
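
# A sketch of the launch wrapper emitted for rank == 2 (argument order mirrors
# the writeline calls above; names are the ones PadFunction passes in below):
#
#     def _pad_wrapper_out(in0, out0, dst_shape, pad_before, pad_after, mode, value=0):
#         BLOCK_SIZE = 256
#         grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)
#         x_shape = in0.shape
#         in_strides0 = in0.stride()
#         out_strides = out0.stride()
#         valid_dim0_start = pad_before[0]
#         valid_dim0_end = dst_shape[0] - pad_after[0]
#         valid_dim1_start = pad_before[1]
#         valid_dim1_end = dst_shape[1] - pad_after[1]
#         IS_CONSTANT = mode == 'constant'
#         ...
#         with torch_device_fn.device(in0.device):
#             _pad_jit_function[grid](
#                 in0, out0,
#                 x_shape[0], x_shape[1],
#                 in_strides0[0], in_strides0[1],
#                 out_strides[0], out_strides[1],
#                 valid_dim0_start, valid_dim1_start,
#                 valid_dim0_end, valid_dim1_end,
#                 in0.numel(), out0.numel(), value,
#                 IS_CONSTANT, IS_REFLECT, IS_REPLICATE, IS_CIRCULAR,
#                 BLOCK_SIZE,
#             )
#         return out0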


def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")

        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        if rank > 0:
            # shape of x
            shape_args = ", ".join(f"x_shape{j}: int" for j in range(rank))
            code.writeline(f"{shape_args}, # shape for x")

            # strides of x
            stride_args = ", ".join(f"in_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for x")

            # strides of out
            stride_args = ", ".join(f"out_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for out")

            # start of the valid (non-padded) region per dimension
            stride_args = ", ".join(f"valid_dim{j}_start: int" for j in range(rank))
            code.writeline(f"{stride_args}, # valid dim start")

            # end of the valid (non-padded) region per dimension
            stride_args = ", ".join(f"valid_dim{j}_end: int" for j in range(rank))
            code.writeline(f"{stride_args}, # valid dim end")

        code.writeline("in_elem_cnt: tl.constexpr, ")
        code.writeline("out_elem_cnt: tl.constexpr, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")

    code.writeline("):")

    with code.indent():
        code.writeline("pid = tle.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.newline()

        code.writeline("remaining = offset")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()

        code.writeline("if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)")
        code.writeline("if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)")

        code.writeline(
            "cond = ((dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end))"
        )

        for i in range(1, rank):
            code.writeline(
                f"cond &= ((dst_index_{i} >= valid_dim{i}_start) & (dst_index_{i} < valid_dim{i}_end))"
            )

        code.writeline(
            "if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)"
        )

        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start")

        for i in range(rank):
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )

        code.newline()
        code.writeline("if IS_REFLECT: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
valid_dim{i}_start - dst_index_{i}, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
(x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, src_index_{i})"""
                )

        code.newline()
        code.writeline("if IS_REPLICATE: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start, 0, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end, x_shape{i} - 1, src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_CIRCULAR: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
dst_index_{i} - valid_dim{i}_end, src_index_{i})"""
                )

        code.newline()

        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")

        code.writeline("load_cond = src_index_0 < x_shape0")
        for i in range(1, rank):
            code.writeline(f"load_cond &= src_index_{i} < x_shape{i}")

        code.writeline("if IS_CONSTANT: ")
        with code.indent():
            # use explicit comparison and bitwise-and for non-scalar masks
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=((if_pad == 0) & load_cond), other=value)"
            )
        code.writeline("else: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)")

    return code
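
# A hand-worked example of the index mapping above (assumed sizes, not taken
# from any test): a 1-D input of length 4 (x_shape0 = 4) padded by 2 on the
# left only, so valid_dim0_start = 2 and valid_dim0_end = 6. Output position
# dst_index_0 = 0 then reads from source index
#     replicate: 0    (clamped to the first element)
#     reflect:   2    (valid_dim0_start - dst_index_0)
#     circular:  2    (dst_index_0 + x_shape0 - valid_dim0_start)
# while in constant mode the load mask is false there and `value` is used.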


def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    shape = inputs[0].shape
    rank = len(shape)

    # the only runtime determined factor is the rank of the task space
    code = generate_imports(code)
    code = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_padding_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    code = generate_pad_kernel(rank, kernel_name, code)
    return code
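
# Taken together, generate_code lays out one self-contained module per rank.
# A sketch of the resulting file (names are the ones PadFunction passes in below):
#
#     # imports emitted by generate_imports (torch, triton, libentry, ...)
#
#     def _pad_wrapper(in0, pad, mode, value=0):
#         ...  # allocates out0 and forwards to _pad_wrapper_out
#
#     def _pad_wrapper_out(in0, out0, dst_shape, pad_before, pad_after, mode, value=0):
#         ...  # computes the launch grid and calls the kernel
#
#     @libentry()
#     @triton.jit(do_not_specialize=['value'])
#     def _pad_jit_function(in0_ptr, out0_ptr, ...):
#         ...  # the rank-specialized Triton kernel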


class PadFunction:
    """Utility to generate a padding function. It generates a wrapper & JITFunction
    which are specialized according to the rank of the task space (the rank of the
    input tensor). The generated code is written out to the cache directory
    (defaults to ~/.flaggems).
    """

    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # note: kwargs should not be used in JITFunction directly
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_pad_wrapper",
                "_pad_wrapper_out",
                "_pad_jit_function",
                code,
            )

            file_name = f"constant_pad_rank_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, "_pad_wrapper")
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank


_pad_func = PadFunction()
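
# Caching behaviour in practice (a sketch with assumed shapes and a CUDA device):
# the first call for a given rank writes constant_pad_rank_<rank>.py into the
# code cache and imports it; subsequent calls with tensors of the same rank hit
# self.overloads and skip codegen entirely.
#
#     x = torch.randn(4, 5, device="cuda")
#     y = _pad_func(x, (1, 1, 2, 0), "constant", 0.0)  # generates the rank-2 module
#     z = _pad_func(torch.randn(7, 3, device="cuda"),
#                   (0, 2, 1, 1), "replicate", 0.0)    # same rank -> cached overload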


def pad(self, pad, mode="constant", value=None):
    logger.debug("GEMS CONSTANT PAD ND")

    ndim = self.ndim

    if value is None:
        value = 0.0

    if mode == "reflect":
        ndim //= 2
        assert (
            len(pad) == 2 * ndim
        ), f"padding size is expected to be {2 * ndim}, but got {len(pad)}"

        for i in range(ndim):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_l, input_r = (
                self.shape[ndim - (2 * i + 1) - 1],
                self.shape[ndim - (2 * i + 1)],
            )
            assert pad_l < input_l and pad_r < input_r, (
                f"padding size should be less than the corresponding input dimension, "
                f"but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
            )

    if mode == "circular":
        ndim //= 2
        assert (
            len(pad) == 2 * ndim
        ), f"padding size is expected to be {2 * ndim}, but got {len(pad)}"
        for i in range(ndim):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - i - 1]
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."

    out = _pad_func(self, pad, mode, float(value))
    return out


def constant_pad_nd(self, pad_list, value=0):
    return pad(self, pad_list, mode="constant", value=value)
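
# Usage sketch (assumed shapes and device; `pad` pairs apply from the last
# dimension backwards, mirroring torch.nn.functional.pad):
#
#     x = torch.arange(6.0, device="cuda").reshape(2, 3)
#     pad(x, (1, 1), mode="constant", value=0)   # shape (2, 5), zeros on both sides of dim -1
#     pad(x, (1, 1, 1, 1), mode="replicate")     # shape (4, 5), edge values repeated
#     constant_pad_nd(x, (2, 0), value=-1)       # shape (2, 5), two columns of -1 on the left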