Coverage for src/flag_gems/ops/upsample_linear1d_backward.py: 31%

71 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

10@triton.jit 

11def upsample_linear1d_backward_kernel( 

12 grad_out_ptr, 

13 grad_in_ptr, 

14 n, 

15 c, 

16 in_w, 

17 out_w, 

18 go_stride_n, 

19 go_stride_c, 

20 go_stride_w, 

21 gi_stride_n, 

22 gi_stride_c, 

23 gi_stride_w, 

24 align_corners: tl.constexpr, 

25 BLOCK: tl.constexpr, 

26): 

27 pid = tl.program_id(0) 

28 offs = pid * BLOCK + tl.arange(0, BLOCK) 

29 

30 total = n * c * in_w 

31 mask = offs < total 

32 

33 x_in = offs % in_w 

34 tmp = offs // in_w 

35 c_idx = tmp % c 

36 n_idx = tmp // c 

37 

38 x_in_f = x_in.to(tl.float32) 

39 in_w_f = tl.cast(in_w, tl.float32) 

40 out_w_f = tl.cast(out_w, tl.float32) 

41 

42 if align_corners: 

43 if in_w > 1: 

44 center = x_in_f * (out_w_f - 1.0) / (in_w_f - 1.0) 

45 else: 

46 center = tl.zeros([BLOCK], dtype=tl.float32) 

47 else: 

48 center = (x_in_f + 0.5) * out_w_f / in_w_f - 0.5 

49 

50 base = tl.floor(center).to(tl.int32) 

51 

52 go_base = grad_out_ptr + n_idx * go_stride_n + c_idx * go_stride_c 

53 

54 acc = tl.zeros([BLOCK], dtype=tl.float32) 

55 

56 for i in range(-2, 3): 

57 x_out = base + i 

58 valid = (x_out >= 0) & (x_out < out_w) 

59 x_out_f = x_out.to(tl.float32) 

60 

61 if align_corners: 

62 if out_w > 1: 

63 x_real = x_out_f * (in_w_f - 1.0) / (out_w_f - 1.0) 

64 else: 

65 x_real = tl.zeros([BLOCK], dtype=tl.float32) 

66 else: 

67 x_real = (x_out_f + 0.5) * in_w_f / out_w_f - 0.5 

68 

69 x0_f = tl.floor(x_real) 

70 w1 = x_real - x0_f 

71 w0 = 1.0 - w1 

72 

73 x0_i = tl.maximum(x0_f, 0.0).to(tl.int32) 

74 x1_i = tl.minimum(x0_f + 1.0, in_w_f - 1.0).to(tl.int32) 

75 

76 g = tl.load( 

77 go_base + x_out * go_stride_w, 

78 mask=mask & valid, 

79 other=0.0, 

80 ).to(tl.float32) 

81 

82 same = x0_i == x1_i 

83 is_x0 = x_in.to(tl.int32) == x0_i 

84 is_x1 = x_in.to(tl.int32) == x1_i 

85 

86 acc += tl.where(same & is_x0, g * (w0 + w1), 0.0) 

87 acc += tl.where(~same & is_x0, g * w0, 0.0) 

88 acc += tl.where(~same & is_x1, g * w1, 0.0) 

89 

90 gi_ptr = ( 

91 grad_in_ptr + n_idx * gi_stride_n + c_idx * gi_stride_c + x_in * gi_stride_w 

92 ) 

93 tl.store(gi_ptr, acc, mask=mask) 

94 

95 

96def upsample_linear1d_backward( 

97 grad_output: torch.Tensor, 

98 output_size, 

99 input_size, 

100 align_corners: bool, 

101 scale_factors=None, 

102) -> torch.Tensor: 

103 logger.debug("GEMS UPSAMPLE_LINEAR1D_BACKWARD") 

104 

105 if len(input_size) == 3: 

106 n, c, in_w = input_size 

107 elif len(input_size) == 2: 

108 n, c, in_w = input_size[0], 1, input_size[1] 

109 elif len(input_size) == 1: 

110 n, c, in_w = 1, 1, input_size[0] 

111 else: 

112 raise ValueError 

113 

114 if output_size is not None: 

115 out_w = output_size[0] 

116 else: 

117 assert scale_factors is not None 

118 out_w = int(in_w * scale_factors[0]) 

119 

120 assert grad_output.shape[-1] == out_w 

121 

122 grad_out_3d = grad_output.contiguous().view(n, c, out_w) 

123 

124 grad_in = torch.zeros( 

125 (n, c, in_w), 

126 device=grad_output.device, 

127 dtype=grad_output.dtype, 

128 ) 

129 

130 go_stride_n, go_stride_c, go_stride_w = grad_out_3d.stride() 

131 gi_stride_n, gi_stride_c, gi_stride_w = grad_in.stride() 

132 

133 BLOCK = 512 

134 grid = (triton.cdiv(n * c * in_w, BLOCK),) 

135 

136 upsample_linear1d_backward_kernel[grid]( 

137 grad_out_3d, 

138 grad_in, 

139 n, 

140 c, 

141 in_w, 

142 out_w, 

143 go_stride_n, 

144 go_stride_c, 

145 go_stride_w, 

146 gi_stride_n, 

147 gi_stride_c, 

148 gi_stride_w, 

149 align_corners, 

150 BLOCK=BLOCK, 

151 ) 

152 

153 return grad_in