Coverage for src/flag_gems/ops/replication_pad3d.py: 48%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import libentry, libtuner 

9 

10logger = logging.getLogger(__name__) 

11 

12 

13@libentry() 

14@libtuner( 

15 configs=runtime.get_tuned_config("replication_pad3d"), 

16 key=["H_out", "W_out"], 

17) 

18@triton.jit 

19def replicationpad3d_kernel( 

20 x_ptr, 

21 out_ptr, 

22 D_in, 

23 H_in, 

24 W_in, 

25 D_out, 

26 H_out, 

27 W_out, 

28 pad_l, 

29 pad_t, 

30 pad_f, 

31 stride_xn, 

32 stride_xc, 

33 stride_xd, 

34 stride_xh, 

35 stride_xw, 

36 stride_on, 

37 stride_oc, 

38 stride_od, 

39 stride_oh, 

40 stride_ow, 

41 C, 

42 BLOCK_H: tl.constexpr, 

43 BLOCK_W: tl.constexpr, 

44): 

45 pid_w = tl.program_id(0) 

46 pid_h = tl.program_id(1) 

47 pid_ncd = tl.program_id(2) 

48 

49 d_idx = pid_ncd % D_out 

50 nc_idx = pid_ncd // D_out 

51 c_idx = nc_idx % C 

52 n_idx = nc_idx // C 

53 

54 iz = d_idx - pad_f 

55 iz = tl.where(iz < 0, 0, iz) 

56 iz = tl.where(iz > D_in - 1, D_in - 1, iz) 

57 

58 x_base_ptr = x_ptr + n_idx * stride_xn + c_idx * stride_xc + iz * stride_xd 

59 out_base_ptr = out_ptr + n_idx * stride_on + c_idx * stride_oc + d_idx * stride_od 

60 

61 offs_h = pid_h * BLOCK_H + tl.arange(0, BLOCK_H) 

62 offs_w = pid_w * BLOCK_W + tl.arange(0, BLOCK_W) 

63 

64 iy = offs_h - pad_t 

65 iy = tl.where(iy < 0, 0, iy) 

66 iy = tl.where(iy > H_in - 1, H_in - 1, iy) 

67 

68 ix = offs_w - pad_l 

69 ix = tl.where(ix < 0, 0, ix) 

70 ix = tl.where(ix > W_in - 1, W_in - 1, ix) 

71 

72 x_offset = iy[:, None] * stride_xh + ix[None, :] * stride_xw 

73 out_offset = offs_h[:, None] * stride_oh + offs_w[None, :] * stride_ow 

74 

75 mask = (offs_h[:, None] < H_out) & (offs_w[None, :] < W_out) 

76 

77 vals = tl.load(x_base_ptr + x_offset, mask=mask) 

78 tl.store(out_base_ptr + out_offset, vals, mask=mask) 

79 

80 

81def replication_pad3d(x, padding): 

82 logger.debug("GEMS Replication_Pad3d") 

83 if isinstance(padding, int): 

84 pad_l = pad_r = pad_t = pad_b = pad_f = pad_ba = padding 

85 else: 

86 pad_l, pad_r, pad_t, pad_b, pad_f, pad_ba = padding 

87 

88 is_4d = x.ndim == 4 

89 if is_4d: 

90 x = x.unsqueeze(0) 

91 

92 N, C, D_in, H_in, W_in = x.shape 

93 D_out, H_out, W_out = ( 

94 D_in + pad_f + pad_ba, 

95 H_in + pad_t + pad_b, 

96 W_in + pad_l + pad_r, 

97 ) 

98 

99 out = torch.empty((N, C, D_out, H_out, W_out), device=x.device, dtype=x.dtype) 

100 

101 grid = lambda META: ( 

102 triton.cdiv(W_out, META["BLOCK_W"]), 

103 triton.cdiv(H_out, META["BLOCK_H"]), 

104 N * C * D_out, 

105 ) 

106 

107 replicationpad3d_kernel[grid]( 

108 x, 

109 out, 

110 D_in, 

111 H_in, 

112 W_in, 

113 D_out, 

114 H_out, 

115 W_out, 

116 pad_l, 

117 pad_t, 

118 pad_f, 

119 *x.stride(), 

120 *out.stride(), 

121 C, 

122 ) 

123 

124 return out.squeeze(0) if is_4d else out