Coverage for src/flag_gems/experimental_ops/upsample_nearest1d.py: 0%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import math 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7 

# Triton kernel: one program per (row, W-tile), where row = n * C + c indexes
# the flattened (N, C) plane. Rows live on program axis 0 because CUDA allows
# up to 2**31 - 1 blocks there but only 65535 on axis 1, so large N * C still
# launches. Each program writes BLOCK_W consecutive output positions along W,
# each copied from its nearest source column:
#   USE_SCALE=True  -> src = floor(dst * inv_scale)     (float32 math)
#   USE_SCALE=False -> src = floor(dst * W_IN / W_OUT)  (exact integer math)
# clamped to W_IN - 1. All index arithmetic is int64 so dst * W_IN and the
# strided offsets cannot overflow on long signals or large tensors.
@triton.jit
def _upsample_nearest1d_kernel(
    in_ptr,
    out_ptr,
    C,
    W_IN,
    W_OUT,
    in_stride_n,
    in_stride_c,
    in_stride_w,
    out_stride_n,
    out_stride_c,
    out_stride_w,
    inv_scale,
    USE_SCALE: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    row = tl.program_id(0).to(tl.int64)  # flattened n * C + c
    n = row // C
    c = row % C

    dst_w = tl.program_id(1).to(tl.int64) * BLOCK_W + tl.arange(0, BLOCK_W)
    mask = dst_w < W_OUT

    if USE_SCALE:
        # Explicit scale factor: floor(dst / scale) == floor(dst * inv_scale).
        src_w = (dst_w.to(tl.float32) * inv_scale).to(tl.int64)
    else:
        # Only the output width is known: exact floor(dst * W_IN / W_OUT).
        src_w = (dst_w * W_IN) // W_OUT
    src_w = tl.minimum(src_w, W_IN - 1)

    in_offs = n * in_stride_n + c * in_stride_c + src_w * in_stride_w
    out_offs = n * out_stride_n + c * out_stride_c + dst_w * out_stride_w

    val = tl.load(in_ptr + in_offs, mask=mask, other=0)
    tl.store(out_ptr + out_offs, val, mask=mask)


# CUDA caps gridDim.y (and .z) at 65535 blocks; gridDim.x allows 2**31 - 1.
_MAX_GRID_Y = 65535


def _first_scale(scales):
    """Return the width scale carried by ``scales`` as a float, or None.

    ``scales`` may be None, a number, or a sequence whose first entry is the
    width scale. An empty sequence or a leading None yields None.
    """
    if scales is None:
        return None
    if isinstance(scales, (list, tuple)):
        if len(scales) == 0 or scales[0] is None:
            return None
        return float(scales[0])
    return float(scales)


def _upsample_nearest1d_impl(
    input: torch.Tensor, output_size=None, scales=None, out: torch.Tensor = None
):
    """Nearest-neighbour 1D upsampling shared by the public entry points.

    Args:
        input: 3-D CUDA tensor of shape (N, C, W_in); arbitrary strides.
        output_size: sequence of length 1 holding W_out, or None.
        scales: width scale as a number or as a sequence whose first entry is
            used, or None. If ``output_size`` is None, the output width is
            floor(W_in * scale). Whenever the scale is positive it also drives
            the source-index mapping (src = floor(dst / scale)), even if
            ``output_size`` is given, matching ATen's ``compute_scales_value``.
            Otherwise src = floor(dst * W_in / W_out).
        out: optional preallocated (N, C, W_out) tensor on the same device
            and with the same dtype as ``input``.

    Returns:
        The output tensor (``out`` itself when it was supplied).

    Raises:
        ValueError: on non-CUDA or non-3D input, missing or invalid sizing
            arguments, a non-positive output width, or an ``out`` tensor whose
            device, shape or dtype does not match.
    """
    if not input.is_cuda:
        raise ValueError("Input tensor must be on CUDA device.")
    if input.dim() != 3:
        raise ValueError("upsample_nearest1d expects a 3D tensor of shape (N, C, W).")
    N, C, W_in = input.shape

    scale = _first_scale(scales)
    if output_size is not None:
        if not isinstance(output_size, (list, tuple)) or len(output_size) != 1:
            raise ValueError(
                "output_size must be a sequence of length 1 for 1D upsampling."
            )
        W_out = int(output_size[0])
        # ATen still maps indices with an explicit positive scale even when the
        # output width is given; only otherwise does it fall back to W_in/W_out.
        use_scale = scale is not None and scale > 0
    else:
        if scales is None:
            raise ValueError("Either output_size or scales must be provided.")
        if scale is None:
            raise ValueError("Invalid scales for 1D upsampling.")
        if scale <= 0:
            raise ValueError("Scale factor must be positive.")
        W_out = int(math.floor(W_in * scale))
        use_scale = True

    if W_out <= 0:
        raise ValueError("Computed output width must be positive.")

    if out is None:
        out = torch.empty((N, C, W_out), device=input.device, dtype=input.dtype)
    else:
        if not out.is_cuda:
            raise ValueError("Output tensor must be on CUDA device.")
        if out.device != input.device:
            raise ValueError("Output tensor must be on the same device as input.")
        if list(out.shape) != [N, C, W_out]:
            raise ValueError(
                f"Output tensor has incorrect shape, expected ({N}, {C}, {W_out})."
            )
        if out.dtype != input.dtype:
            raise ValueError("Output tensor must have the same dtype as input.")

    # Nothing to compute; also avoids a launch with a zero-sized grid axis.
    if N * C == 0:
        return out

    in_stride_n, in_stride_c, in_stride_w = input.stride()
    out_stride_n, out_stride_c, out_stride_w = out.stride()

    # Grow the tile for very long outputs so the W-tile count on grid axis 1
    # never exceeds the 65535-block CUDA limit.
    BLOCK_W = max(256, triton.next_power_of_2(triton.cdiv(W_out, _MAX_GRID_Y)))
    grid = (N * C, triton.cdiv(W_out, BLOCK_W))
    _upsample_nearest1d_kernel[grid](
        input,
        out,
        C,
        W_in,
        W_out,
        in_stride_n,
        in_stride_c,
        in_stride_w,
        out_stride_n,
        out_stride_c,
        out_stride_w,
        (1.0 / scale) if use_scale else 1.0,
        USE_SCALE=use_scale,
        BLOCK_W=BLOCK_W,
    )
    return out

134 

135 

def upsample_nearest1d(input: torch.Tensor, output_size=None, scales=None):
    """Upsample a (N, C, W) tensor along W by nearest neighbour.

    The result goes into a freshly allocated tensor. See
    ``_upsample_nearest1d_impl`` for how ``output_size`` and ``scales`` are
    interpreted.
    """
    result = _upsample_nearest1d_impl(input, output_size, scales, None)
    return result

140 

141 

def upsample_nearest1d_vec(input: torch.Tensor, output_size=None, scales=None):
    """Vector-overload entry point for nearest-neighbour 1D upsampling.

    In this overload ``scales`` is normally a sequence of scale factors. It is
    forwarded unchanged, and the shared implementation reads its first entry.
    """
    result = _upsample_nearest1d_impl(input, output_size, scales, None)
    return result

147 

148 

def upsample_nearest1d_out(
    input: torch.Tensor, output_size=None, scales=None, *, out: torch.Tensor
):
    """Nearest-neighbour 1D upsampling written into the caller-supplied ``out``.

    ``out`` is validated and filled in place, and the same tensor object is
    returned.
    """
    filled = _upsample_nearest1d_impl(input, output_size, scales, out)
    return filled