Coverage for src/flag_gems/runtime/backend/_cambricon/ops/pad.py: 0%

344 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7import triton 

8import triton.language as tl 

9 

10from flag_gems.utils import libentry 

11from flag_gems.utils.code_cache import code_cache_dir 

12from flag_gems.utils.code_utils import IndentedBuffer 

13 

14from ..utils import TOTAL_CORE_NUM 

15 

# Child of the shared "flag_gems" logger; lstrip(".") guards against a
# leading dot if __name__ were a relative module path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

17 

18 

19# --------------------------- padding wrapper genration ----------------------------------- 

def parameter_for_wrapper() -> str:
    """Return the comma-joined parameter list for the functional wrapper.

    Example: "in0, pad, mode, value=0"
    """
    return ", ".join(("in0", "pad", "mode", "value=0"))

31 

32 

def parameter_for_wrapper_out() -> str:
    """Return the comma-joined parameter list for the destination-passing wrapper.

    Example: "in0, out0, dst_shape, pad_before, pad_after, mode, value=0"
    """
    names = (
        "in0",
        "out0",
        "dst_shape",
        "pad_before",
        "pad_after",
        "mode",
        "value=0",
    )
    return ", ".join(names)

48 

49 

def parameter_ref_for_wrapper() -> str:
    """Return the argument list used when the functional wrapper calls the
    destination-passing wrapper.

    Example: "in0, out0, dst_shape, pad_before, pad_after, mode, value"
    """
    refs = (
        "in0",
        "out0",
        "dst_shape",
        "pad_before",
        "pad_after",
        "mode",
        "value",
    )
    return ", ".join(refs)

65 

66 

def output_ref_for_wrapper() -> str:
    """Name of the output variable in the generated wrapper code."""
    return "out0"

69 

70 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated module into *code*."""
    third_party = (
        "import math",
        "import torch",
        "import triton",
        "from triton import language as tl",
    )
    for line in third_party:
        code.writeline(line)
    code.newline()

    flag_gems_lines = (
        "from flag_gems.utils.libentry import libentry",
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils import triton_lang_extension as tle",
        "from flag_gems.utils.type_utils import type_promotion",
    )
    for line in flag_gems_lines:
        code.writeline(line)
    code.newline()
    code.newline()
    return code

84 

85 

def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the functional wrapper ``wrapper_name(in0, pad, mode, value=0)``.

    The generated function normalizes ``pad`` into per-dimension
    before/after lists, allocates the padded output tensor, and delegates
    the actual fill to ``destination_passing_func_name``.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("ndim = in0.ndim")
        code.writeline("pad_size = len(pad)")
        code.writeline("assert pad_size % 2 == 0")
        code.newline()
        code.writeline("pad_before = [0 for _ in range(ndim)]")
        code.writeline("pad_after = [0 for _ in range(ndim)]")
        code.newline()
        # pad pairs are given last-dimension-first, hence the reversed index
        code.writeline("pad_pair = pad_size // 2 ")
        code.writeline("for i in range(pad_pair): ")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")
        # destination shape = input shape grown by the pad amounts
        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim): ")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")

        code.writeline(
            ("out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)")
        )

        # call destination_passing_func
        output_names: str = output_ref_for_wrapper()
        call_str = (
            f"{output_names} = {destination_passing_func_name}"
            f"({parameter_ref_for_wrapper()})"
        )
        code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
        code.newline()
        code.newline()

    return code

132 

133 

def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the destination-passing wrapper that launches ``kernel_name``.

    The generated function computes the per-dimension valid-index window
    ([start, end) region copied from the input), resolves the padding mode
    to boolean flags, and launches the kernel with one program per
    BLOCK_SIZE chunk of the output.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("BLOCK_SIZE = 2048")
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")

        # per-dim valid window: indices in [valid_dim{i}_start, valid_dim{i}_end)
        # come from the input; everything outside is padding
        if rank > 0:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            for i in range(rank):
                code.writeline(f"valid_dim{i}_start = pad_before[{i}]")
                code.writeline(f"valid_dim{i}_end = dst_shape[{i}] - pad_after[{i}]")
        code.newline()

        # padding mode resolved to flags consumed as constexpr by the kernel
        code.writeline("IS_CONSTANT = mode == 'constant'")
        code.writeline("IS_REFLECT = mode == 'reflect'")
        code.writeline("IS_REPLICATE = mode == 'replicate'")
        code.writeline("IS_CIRCULAR = mode == 'circular'")

        code.newline()

        # grid
        code.writeline("# kernel launch")

        # launch kernel
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"x_shape[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # shape for x")

                    s = ", ".join(f"in_strides0[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for x")

                    s = ", ".join(f"out_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out")

                    s = ", ".join(f"valid_dim{j}_start" for j in range(rank))
                    code.writeline(f"{s}, # valid dim start")

                    s = ", ".join(f"valid_dim{j}_end" for j in range(rank))
                    code.writeline(f"{s}, # valid dim end")

                code.writeline("in0.numel(), ")
                code.writeline("out0.numel(), ")
                code.writeline("value, ")
                code.writeline("IS_CONSTANT, ")
                code.writeline("IS_REFLECT, ")
                code.writeline("IS_REPLICATE, ")
                code.writeline("IS_CIRCULAR, ")
                code.writeline("BLOCK_SIZE, ")
                code.writeline(")")

        code.writeline("return out0")
    code.newline()
    code.newline()
    return code

215 

216 

def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: "IndentedBuffer",
) -> "IndentedBuffer":
    """Emit a rank-specialized Triton kernel that fills the padded output.

    The generated kernel maps each linear output offset to a multi-dim
    output index, decides whether it falls in the pad region, computes the
    mirrored/clamped/wrapped source index per mode, and gathers from the
    input (or the constant ``value``).

    Args:
        rank: number of dimensions of the task space (>= 1 in practice).
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer the source is written into; returned for chaining.
    """
    # make the inlined function visible in the context
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        if rank > 0:
            # shape of the (unpadded) input
            shape_args = ", ".join(f"x_shape{j}: int" for j in range(rank))
            code.writeline(f"{shape_args}, # shape for x")

            # input strides
            stride_args = ", ".join(f"in_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for x")

            # output strides
            stride_args = ", ".join(f"out_strides{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # stride for out")

            # first/last valid (non-pad) output index along each dim
            stride_args = ", ".join(f"valid_dim{j}_start: int" for j in range(rank))
            code.writeline(f"{stride_args}, # valid dim start")

            stride_args = ", ".join(f"valid_dim{j}_end: int" for j in range(rank))
            code.writeline(f"{stride_args}, # valid dim end")

        code.writeline("in_elem_cnt: tl.constexpr, ")
        code.writeline("out_elem_cnt: tl.constexpr, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")

    code.writeline("):")

    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.newline()

        # linear output offset -> multi-dimensional output index
        code.writeline("remaining = offset ")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()

        # if_pad == 1 marks lanes that fall outside the valid window
        code.writeline("if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)")
        code.writeline("if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)")

        code.writeline(
            "cond = ((dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end))"
        )
        for i in range(1, rank):
            code.writeline(
                f"cond &= ((dst_index_{i} >= valid_dim{i}_start) & (dst_index_{i} < valid_dim{i}_end))"
            )

        code.writeline(
            "if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)"
        )

        # default (constant-mode) source index: shift back by the pad amount
        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start ")

        for i in range(rank):
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )

        # reflect: mirror indices around the valid window edges
        code.newline()
        code.writeline("if IS_REFLECT: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
                valid_dim{i}_start - dst_index_{i}, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
                (x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, src_index_{i})"""
                )

        # replicate: clamp to the first/last input element
        code.newline()
        code.writeline("if IS_REPLICATE: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start, 0, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end, x_shape{i} - 1, src_index_{i})"
                )

        # circular: wrap around the input extent
        code.newline()
        code.writeline("if IS_CIRCULAR: ")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} < valid_dim{i}_start,
                dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"""
                )
            for i in range(rank):
                code.writeline(
                    f"""src_index_{i} = tl.where(dst_index_{i} >= valid_dim{i}_end,
                dst_index_{i} - valid_dim{i}_end, src_index_{i})"""
                )

        code.newline()

        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")

        # BUGFIX: seed load_cond from dimension 0 explicitly. The previous
        # code reused the stale loop variable `i` (left at rank-1 by the
        # src_offset loop above), so for rank > 1 the dim-0 bound check was
        # skipped and dim rank-1 was checked twice.
        code.writeline("load_cond = src_index_0 < x_shape0")
        for i in range(1, rank):
            code.writeline(f"load_cond &= src_index_{i} < x_shape{i}")

        code.writeline("if IS_CONSTANT: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=((if_pad == 0) & load_cond), other=value)"
            )
        code.writeline("else: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)")

    return code

369 

370 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the full generated module: imports, functional wrapper,
    destination-passing wrapper, and the Triton pad kernel."""
    # The rank of the first input is the only runtime-determined
    # specialization factor.
    task_rank = len(inputs[0].shape)

    code = generate_imports(code)
    code = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_padding_wrapper(
        task_rank, destination_passing_func_name, kernel_name, code
    )
    return generate_pad_kernel(task_rank, kernel_name, code)

391 

392 

class PadFunction:
    """Generate, cache and dispatch rank-specialized pad implementations.

    For each input rank, a functional wrapper plus a Triton JIT kernel are
    generated, written to the code cache directory (defaults to
    ~/.flaggems) and imported; later calls with the same rank reuse the
    cached overload.
    """

    def __init__(self):
        # pid is baked into generated file/module names so concurrent
        # processes do not clobber each other's cache entries.
        self.pid = os.getpid()
        # str(rank) -> generated "_pad_wrapper" callable.
        # (Annotated as a plain dict: it is mutated, so the read-only
        # typing.Mapping protocol would be the wrong contract.)
        self.overloads: dict = {}

    def __call__(self, *args, **kwargs):
        # note: kwargs should not be used in JITFunction directly
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            # generate the specialized module source & import it
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_pad_wrapper",
                "_pad_wrapper_out",
                "_pad_jit_function",
                code,
            )

            file_name = f"constant_pad_rank_{key}_pid_{self.pid}.py"
            file_path = code_cache_dir() / file_name

            with open(file_path, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load from the cache file (use the explicit path rather than
            # the closed file handle's .name)
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            spec.loader.exec_module(m)
            overload = m._pad_wrapper
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def arg_key(self, *args):
        """Cache key: the maximum rank among the tensor arguments."""
        tensors = [item for item in args if torch.is_tensor(item)]
        return max(item.ndim for item in tensors)

442 

443 

# Module-level singleton used by pad() for the generic code-generated path.
_pad_func = PadFunction()

445 

446 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**n}, num_stages=s)
        for n in range(10, 16, 2)
        for s in [1, 3]
    ],
    key=["inp_elements"],
)
@triton.jit
def pad_1d_constant_kernel(
    inp_ptr,
    out_ptr,
    inp_elements,
    pad_value,
    pad_left,
    pad_right,
    BLOCK_SIZE: tl.constexpr,
):
    # 1D constant padding: each program strides over the output in
    # BLOCK_SIZE chunks (stride = num_programs * BLOCK_SIZE).
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    out_elements = pad_left + inp_elements + pad_right
    for off in range(start, out_elements, step):
        # shift to input coordinates; lanes outside [0, inp_elements) are
        # pad positions and receive pad_value via the load's `other`
        inp_offset = off + tl.arange(0, BLOCK_SIZE) - pad_left
        # NOTE(review): tensor `and` relies on Triton's chained-boolean
        # support; `&` is the conventional spelling — confirm for the
        # Triton version this backend pins.
        inp_mask = inp_offset >= 0 and inp_offset < inp_elements
        inp = tl.load(inp_ptr + inp_offset, mask=inp_mask, other=pad_value)
        out_offset = off + tl.arange(0, BLOCK_SIZE)
        out_mask = out_offset < out_elements
        tl.store(out_ptr + out_offset, inp, mask=out_mask)

478 

479 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": n}, num_stages=s)
        for n in [1, 4, 8, 12, 16, 24]
        for s in [1, 3]
    ],
    key=["H", "W"],
)
@triton.jit
def pad_2d_constant_kernel(
    inp_ptr,
    out_ptr,
    H,
    W: tl.constexpr,
    pad_value,
    pad_left: tl.constexpr,
    pad_right: tl.constexpr,
    pad_top,
    pad_bottom,
    BLOCK_H: tl.constexpr,
):
    # 2D constant padding: programs stride over output rows in BLOCK_H
    # chunks; each iteration handles the full padded row width (out_W).
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    block_start = pid * BLOCK_H
    step = num_jobs * BLOCK_H
    out_W: tl.constexpr = pad_left + W + pad_right
    out_H = pad_top + H + pad_bottom
    for batch_idx in range(block_start, out_H, step):
        # input-space coordinates; rows/cols outside [0, H) x [0, W) are
        # pad positions filled with pad_value via the load's `other`
        offset_h = tl.arange(0, BLOCK_H) + batch_idx - pad_top
        offset_w = tl.arange(0, out_W) - pad_left
        offsets = offset_h[:, None] * W + offset_w[None, :]
        # NOTE(review): tensor `and` relies on Triton's chained-boolean
        # support; `&` is the conventional spelling — confirm for the
        # Triton version this backend pins.
        mask = (offset_h[:, None] >= 0 and offset_h[:, None] < H) and (
            offset_w[None, :] >= 0 and offset_w[None, :] < W
        )
        inp = tl.load(inp_ptr + offsets, mask=mask, other=pad_value)

        out_offset_c = tl.arange(0, out_W)
        out_offset_n = tl.arange(0, BLOCK_H) + batch_idx
        out_offsets = out_offset_n[:, None] * out_W + out_offset_c[None, :]
        out_mask = out_offset_n[:, None] < out_H and out_offset_c[None, :] < out_W
        tl.store(out_ptr + out_offsets, inp, mask=out_mask)

522 

523 

def pad(self, pad, mode="constant", value=None):
    """Pad entry point for the Cambricon backend (torch.nn.functional.pad).

    Args:
        self: input tensor.
        pad: flat sequence of pad widths in (left, right) pairs, ordered
            from the last dimension backwards.
        mode: 'constant', 'reflect', 'replicate' or 'circular'.
        value: fill value for constant mode; defaults to 0.0.

    Constant mode with rank 1-3 input is served by the hand-written
    kernels above; everything else falls through to the code-generated
    `_pad_func` path at the bottom.
    """
    logger.debug("GEMS_CAMBRICON CONSTANT PAD ND")

    ndim = self.ndim
    pad_size = len(pad)
    assert pad_size % 2 == 0

    if value is None:
        value = 0.0

    if mode == "constant":
        # expand the flat pad list into per-dim before/after amounts
        pad_before = [0 for _ in range(ndim)]
        pad_after = [0 for _ in range(ndim)]
        pad_pair = pad_size // 2
        for i in range(pad_pair):
            pad_before[ndim - i - 1] = pad[2 * i]
            pad_after[ndim - i - 1] = pad[2 * i + 1]

        inp_shape = list(self.shape)
        out_shape = list(self.shape)
        for i in range(ndim):
            out_shape[i] += pad_before[i] + pad_after[i]
        out = torch.empty(out_shape, dtype=self.dtype, device=self.device)

        if ndim == 1:
            grid = lambda meta: (
                min(triton.cdiv(out_shape[0], meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            pad_1d_constant_kernel[grid](
                self.contiguous(),
                out,
                inp_shape[0],
                value,
                pad_before[-1],
                pad_after[-1],
            )
            return out

        if ndim == 2:
            grid = lambda meta: (
                min(triton.cdiv(out_shape[0], meta["BLOCK_H"]), TOTAL_CORE_NUM),
            )
            pad_2d_constant_kernel[grid](
                self.contiguous(),
                out,
                inp_shape[0],
                inp_shape[1],
                value,
                pad_before[-1],
                pad_after[-1],
                pad_before[-2],
                pad_after[-2],
            )
            return out

        if ndim == 3:
            # fill the fully-padded leading/trailing slabs along dim 0 ...
            out[: pad_before[0]] = torch.full(
                out[0 : pad_before[0]].shape,
                value,
                dtype=self.dtype,
                device=self.device,
            )
            out[pad_before[0] + inp_shape[0] :] = torch.full(
                out[pad_before[0] + inp_shape[0] :].shape,
                value,
                dtype=self.dtype,
                device=self.device,
            )

            # ... then pad each interior 2D slice with the 2D kernel
            for i in range(pad_before[0], pad_before[0] + inp_shape[0]):
                grid = lambda meta: (
                    min(triton.cdiv(out_shape[1], meta["BLOCK_H"]), TOTAL_CORE_NUM),
                )
                pad_2d_constant_kernel[grid](
                    self[i - pad_before[0]].contiguous(),
                    out[i],
                    inp_shape[1],
                    inp_shape[2],
                    value,
                    pad_before[-1],
                    pad_after[-1],
                    pad_before[-2],
                    pad_after[-2],
                )
            return out

    if mode == "reflect":
        # NOTE(review): `ndim //= 2` treats half of the input rank as the
        # padded dims (e.g. 4D input + len-4 pad, 3D input + len-2 pad).
        # This mis-sizes the check for 5D reflect3d inputs (5 // 2 == 2,
        # expecting len(pad) == 4 instead of 6) — confirm intended ranks.
        ndim //= 2
        assert (
            len(pad) == 2 * ndim
        ), f"padding size is expected to be {2 * ndim}, but got {len(pad)}"

        for i in range(ndim):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_l, input_r = (
                self.shape[ndim - (2 * i + 1) - 1],
                self.shape[ndim - (2 * i + 1)],
            )
            # reflect padding must be strictly smaller than the input dim
            assert (
                pad_l < input_l and pad_r < input_r
            ), \
            f"padding size should be less than the corresponding input dimension, \
            but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"

    if mode == "circular":
        ndim //= 2
        assert (
            len(pad) == 2 * ndim
        ), f"padding size is expected to be {2 * ndim}, but got {len(pad)}"
        for i in range(ndim):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - i - 1]
            # circular padding may wrap around at most once
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."

    # generic code-generated path (also serves constant mode with ndim > 3
    # and the replicate mode, which has no explicit validation branch above)
    out = _pad_func(self, pad, mode, float(value))
    return out

642 

643 

def constant_pad_nd(self, pad_list, value=0):
    """Dispatch aten-style constant_pad_nd to the generic pad entry point."""
    result = pad(self, pad_list, mode="constant", value=value)
    return result