Coverage for src/flag_gems/ops/replication_pad3d.py: 48%
50 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import libentry, libtuner
10logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("replication_pad3d"),
    key=["H_out", "W_out"],
)
@triton.jit
def replicationpad3d_kernel(
    x_ptr,  # pointer to input, addressed via the stride_x* arguments
    out_ptr,  # pointer to output, addressed via the stride_o* arguments
    D_in,  # input depth
    H_in,  # input height
    W_in,  # input width
    D_out,  # output depth  (D_in + front + back pad)
    H_out,  # output height (H_in + top + bottom pad)
    W_out,  # output width  (W_in + left + right pad)
    pad_l,  # left padding (W axis)
    pad_t,  # top padding (H axis)
    pad_f,  # front padding (D axis)
    stride_xn,
    stride_xc,
    stride_xd,
    stride_xh,
    stride_xw,
    stride_on,
    stride_oc,
    stride_od,
    stride_oh,
    stride_ow,
    C,  # number of channels; used to unpack the fused grid axis 2
    BLOCK_H: tl.constexpr,  # tile height (from tuning config)
    BLOCK_W: tl.constexpr,  # tile width (from tuning config)
):
    """Replication-pad one (BLOCK_H, BLOCK_W) output tile for one (n, c, d) slice.

    Each out-of-range output coordinate is clamped back to the nearest
    input border index, which replicates the edge values. Only the right/
    bottom pads are implicit here: they fall out of the clamping, so the
    kernel only needs the left/top/front pad offsets.
    """
    pid_w = tl.program_id(0)  # tile index along output W
    pid_h = tl.program_id(1)  # tile index along output H
    pid_ncd = tl.program_id(2)  # fused (n, c, d_out) index

    # Unpack the fused grid axis: launched as N * C * D_out programs.
    d_idx = pid_ncd % D_out
    nc_idx = pid_ncd // D_out
    c_idx = nc_idx % C
    n_idx = nc_idx // C

    # Map output depth -> input depth, clamped to [0, D_in - 1] (replication).
    iz = d_idx - pad_f
    iz = tl.where(iz < 0, 0, iz)
    iz = tl.where(iz > D_in - 1, D_in - 1, iz)

    # Base pointers for this (n, c, depth) slice.
    x_base_ptr = x_ptr + n_idx * stride_xn + c_idx * stride_xc + iz * stride_xd
    out_base_ptr = out_ptr + n_idx * stride_on + c_idx * stride_oc + d_idx * stride_od

    # Output coordinates covered by this tile.
    offs_h = pid_h * BLOCK_H + tl.arange(0, BLOCK_H)
    offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W)

    # Map output rows/cols -> clamped input rows/cols (replication).
    iy = offs_h - pad_t
    iy = tl.where(iy < 0, 0, iy)
    iy = tl.where(iy > H_in - 1, H_in - 1, iy)

    ix = offs_w - pad_l
    ix = tl.where(ix < 0, 0, ix)
    ix = tl.where(ix > W_in - 1, W_in - 1, ix)

    # 2D gather/scatter offsets for the tile.
    x_offset = iy[:, None] * stride_xh + ix[None, :] * stride_xw
    out_offset = offs_h[:, None] * stride_oh + offs_w[None, :] * stride_ow

    # Mask off tile lanes past the output bounds (partial edge tiles).
    mask = (offs_h[:, None] < H_out) & (offs_w[None, :] < W_out)

    vals = tl.load(x_base_ptr + x_offset, mask=mask)
    tl.store(out_base_ptr + out_offset, vals, mask=mask)
def replication_pad3d(x, padding):
    """Pad a 4D or 5D tensor along its last three dims by replicating edges.

    Args:
        x: tensor of shape (C, D, H, W) or (N, C, D, H, W).
        padding: a single int applied to all six sides, or a sequence of six
            ints (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back)
            — the same order torch.nn.functional.pad uses for 3D padding.

    Returns:
        The padded tensor, with the same number of dimensions as `x`.

    Raises:
        ValueError: if `padding` is neither an int nor a length-6 sequence,
            or if `x` is not 4D or 5D.
    """
    logger.debug("GEMS Replication_Pad3d")
    if isinstance(padding, int):
        pad_l = pad_r = pad_t = pad_b = pad_f = pad_ba = padding
    else:
        # Validate up front for a clear message instead of an opaque
        # "not enough values to unpack" error.
        if len(padding) != 6:
            raise ValueError(
                f"replication_pad3d expects 6 padding values, got {len(padding)}"
            )
        pad_l, pad_r, pad_t, pad_b, pad_f, pad_ba = padding

    if x.ndim not in (4, 5):
        raise ValueError(
            f"replication_pad3d expects a 4D or 5D tensor, got {x.ndim}D"
        )

    # The kernel addresses a 5D (N, C, D, H, W) layout; lift 4D input with a
    # unit batch dim and strip it again on return.
    is_4d = x.ndim == 4
    if is_4d:
        x = x.unsqueeze(0)

    N, C, D_in, H_in, W_in = x.shape
    D_out = D_in + pad_f + pad_ba
    H_out = H_in + pad_t + pad_b
    W_out = W_in + pad_l + pad_r

    out = torch.empty((N, C, D_out, H_out, W_out), device=x.device, dtype=x.dtype)

    # One program per (W-tile, H-tile, fused n*c*d) triple; BLOCK_H/BLOCK_W
    # come from the autotuner's chosen config.
    grid = lambda META: (
        triton.cdiv(W_out, META["BLOCK_W"]),
        triton.cdiv(H_out, META["BLOCK_H"]),
        N * C * D_out,
    )

    # Only the left/top/front pads are passed: the kernel derives the
    # right/bottom/back replication by clamping input indices.
    replicationpad3d_kernel[grid](
        x,
        out,
        D_in,
        H_in,
        W_in,
        D_out,
        H_out,
        W_out,
        pad_l,
        pad_t,
        pad_f,
        *x.stride(),
        *out.stride(),
        C,
    )

    return out.squeeze(0) if is_4d else out