Coverage for src/flag_gems/runtime/backend/_cambricon/ops/upsample_nearest2d.py: 0%

108 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10 

11from ..utils import MAX_GRID_SIZE_X, TOTAL_CORE_NUM 

12 

# Module logger: a child of the package-level "flag_gems" logger, named after
# this module's ``__name__`` with any leading dots stripped.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Rebind the imported ``device`` object to its ``name`` attribute; the host
# entry point below compares it against ``input.device.type``.
device = device.name

15 

16 

@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest2d"),
    key=["N", "C", "OH", "OW"],
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
):
    """Generic nearest-neighbour 2D upsample, one output element per lane.

    Each program handles ``BLOCK_SIZE`` consecutive flat output indices,
    decodes them as (n, c, oh, ow), maps (oh, ow) to a source (ih, iw) and
    copies the element from ``ptr_i`` to ``ptr_o``.

    Args:
        ptr_o: Output pointer, addressed as dense (N, C, OH, OW).
        ptr_i: Input pointer, addressed as dense (N, C, IH, IW).
        N, C: Batch and channel extents.
        OH, OW: Output height and width.
        IH, IW: Input height and width.
        reciprocal_scale_h, reciprocal_scale_w: Multipliers that turn an
            output row/column index into a source row/column index.
        BLOCK_SIZE: Number of output elements handled per program.
        SAME_H, SAME_W: When set, the source row/column equals the output
            row/column and the scaling arithmetic is skipped. The host
            launcher in this file does not pass these (nor BLOCK_SIZE);
            presumably they come from the tuned/heuristic configs -- confirm.
    """
    # Collapse the 2D launch grid into one linear program index (x fastest).
    linear_pid = tl.program_id(axis=1) * tl.num_programs(0) + tl.program_id(axis=0)
    flat = linear_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Peel the flat index apart innermost-first: ow, then oh, then c, then n.
    ow = flat % OW
    row = flat // OW
    oh = row % OH
    plane = row // OH
    c = plane % C
    # The trailing ``% N`` keeps n in range even for lanes whose flat index
    # lies past N*C*OH*OW; note that the load/store below are unmasked.
    n = plane // C % N

    if SAME_H:
        ih = oh
    else:
        # tl.floor() cannot be found in 2.3.1, using int trunc
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)

    # Both tensors share the same (n, c) plane index; only the spatial
    # extents differ between source and destination.
    nc = n * C + c
    src = (nc * IH + ih) * IW + iw
    dst = (nc * OH + oh) * OW + ow
    tl.store(ptr_o + dst, tl.load(ptr_i + src))

56 

57 

def configs2():
    """Build the autotune candidates used by the two ``*_opt`` kernels.

    Returns:
        A list with one ``triton.Config`` per (num_stages, BLOCK_H) pair,
        ordered stage-major, each using a single warp.
    """
    row_tile_sizes = (8, 16, 32, 64, 128, 256)
    stage_counts = (1, 3)

    candidates = []
    for stages in stage_counts:
        for rows in row_tile_sizes:
            candidates.append(
                triton.Config({"BLOCK_H": rows}, num_warps=1, num_stages=stages)
            )
    return candidates

66 

67 

@triton.autotune(configs=configs2(), key=["N", "C", "OH", "OW"])
@triton.jit
def upsample_nearest2d_kernel_opt(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW: tl.constexpr,
    IH,
    IW: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Exact-2x nearest upsample, with the N*C planes partitioned across programs.

    Each program takes a contiguous chunk of (n, c) planes and walks every
    plane in tiles of ``BLOCK_H`` output rows. A tile is produced by loading
    ``BLOCK_H // 2`` full input rows and duplicating every element twice
    along the width and twice along the height. The host launcher in this
    file only selects this kernel when IH/OH == 0.5 and IW/OW == 0.5.

    Args:
        ptr_o: Output pointer, addressed as dense (N, C, OH, OW).
        ptr_i: Input pointer, addressed as dense (N, C, IH, IW).
        N, C: Batch and channel extents.
        OH, OW: Output height and width (OW is compile-time constant).
        IH, IW: Input height and width (IW is compile-time constant).
        BLOCK_H: Output rows per tile, chosen by the autotuner.
    """
    job = tl.program_id(axis=0)
    job_count = tl.num_programs(axis=0)

    # Ceil-divide the planes over the programs; the last chunks may be short
    # or empty, in which case the loop below runs zero times.
    plane_total = N * C
    planes_per_job = (plane_total + job_count - 1) // job_count
    plane_lo = job * planes_per_job
    plane_hi = min(plane_lo + planes_per_job, plane_total)

    tiles_per_plane = (OH + BLOCK_H - 1) // BLOCK_H
    for t in range((plane_hi - plane_lo) * tiles_per_plane):
        plane = plane_lo + (t // tiles_per_plane)
        row0 = (t % tiles_per_plane) * BLOCK_H

        dst_base = plane * OH * OW
        src_base = plane * IH * IW

        # Source tile: BLOCK_H // 2 rows starting at row0 // 2, full width.
        src_rows = row0 // 2 + tl.arange(0, BLOCK_H // 2)
        src_cols = tl.arange(0, IW)
        src = src_base + src_rows[:, None] * IW + src_cols

        # Destination tile: BLOCK_H rows starting at row0, full width.
        dst_rows = row0 + tl.arange(0, BLOCK_H)
        dst_cols = tl.arange(0, OW)
        dst = dst_base + dst_rows[:, None] * OW + dst_cols

        # Rows past the end of the input are masked off (last tile of a plane).
        block = tl.load(ptr_i + src, mask=(src_rows[:, None] < IH))

        # Repeat every element twice along the width ...
        widened = block.reshape(BLOCK_H // 2, OW // 2, 1).broadcast_to(
            BLOCK_H // 2, OW // 2, 2
        )
        single_rows = widened.reshape(BLOCK_H // 2, 1, OW)
        # ... then every row twice along the height.
        tile = single_rows.broadcast_to(BLOCK_H // 2, 2, OW).reshape(BLOCK_H, OW)

        tl.store(ptr_o + dst, tile, mask=(dst_rows[:, None] < OH))

114 

115 

@triton.autotune(configs=configs2(), key=["N", "C", "OH", "OW"])
@triton.jit
def upsample_nearest2d_kernel_opt_tile_h(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW: tl.constexpr,
    IH,
    IW: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Exact-2x nearest upsample, with the output rows partitioned across programs.

    Every program visits all N*C planes, but within each plane it only
    handles the row tiles starting at ``pid * BLOCK_H`` and then every
    ``BLOCK_H * num_programs`` rows after that. A tile is produced by loading
    ``BLOCK_H // 2`` full input rows and duplicating every element twice
    along the width and twice along the height. The host launcher in this
    file only selects this kernel when IH/OH == 0.5 and IW/OW == 0.5.

    Args:
        ptr_o: Output pointer, addressed as dense (N, C, OH, OW).
        ptr_i: Input pointer, addressed as dense (N, C, IH, IW).
        N, C: Batch and channel extents.
        OH, OW: Output height and width (OW is compile-time constant).
        IH, IW: Input height and width (IW is compile-time constant).
        BLOCK_H: Output rows per tile, chosen by the autotuner.
    """
    job = tl.program_id(axis=0)
    job_count = tl.num_programs(axis=0)

    first_row = job * BLOCK_H
    row_stride = BLOCK_H * job_count
    # Number of tiles this program owns inside one plane (ceil division).
    tiles_per_plane = (OH - first_row + row_stride - 1) // row_stride

    for t in range(N * C * tiles_per_plane):
        plane = t // tiles_per_plane
        row0 = (t % tiles_per_plane) * row_stride + first_row

        dst_base = plane * OH * OW
        src_base = plane * IH * IW

        # Source tile: BLOCK_H // 2 rows starting at row0 // 2, full width.
        src_rows = row0 // 2 + tl.arange(0, BLOCK_H // 2)
        src_cols = tl.arange(0, IW)
        src = src_base + src_rows[:, None] * IW + src_cols

        # Destination tile: BLOCK_H rows starting at row0, full width.
        dst_rows = row0 + tl.arange(0, BLOCK_H)
        dst_cols = tl.arange(0, OW)
        dst = dst_base + dst_rows[:, None] * OW + dst_cols

        # Rows past the end of the input are masked off.
        block = tl.load(ptr_i + src, mask=(src_rows[:, None] < IH))

        # Repeat every element twice along the width ...
        widened = block.reshape(BLOCK_H // 2, OW // 2, 1).broadcast_to(
            BLOCK_H // 2, OW // 2, 2
        )
        single_rows = widened.reshape(BLOCK_H // 2, 1, OW)
        # ... then every row twice along the height.
        tile = single_rows.broadcast_to(BLOCK_H // 2, 2, OW).reshape(BLOCK_H, OW)

        tl.store(ptr_o + dst, tile, mask=(dst_rows[:, None] < OH))

161 

162 

def _reciprocal_scale(scale: Optional[float], in_size: int, out_size: int) -> float:
    """Return the multiplier that maps an output index to its source index."""
    # Mirror ATen: an absent or non-positive user scale falls back to the size
    # ratio instead of dividing by zero or producing a negative step.
    if scale is not None and scale > 0:
        return 1 / scale
    return in_size / out_size


def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbour upsample of a 4-D (N, C, IH, IW) tensor.

    When the output is exactly twice the input in both spatial dimensions,
    one of the two specialised 2x kernels is launched on ``TOTAL_CORE_NUM``
    programs; otherwise the generic per-element kernel is used.

    Args:
        input: 4-D tensor on the backend device. Non-contiguous inputs are
            made contiguous first because the kernels compute dense NCHW
            offsets.
        output_size: ``(OH, OW)`` of the result.
        scales_h: Optional height scale factor; when given and positive its
            reciprocal is used as the row step, otherwise ``IH / OH``.
        scales_w: Optional width scale factor; same rule with ``IW / OW``.

    Returns:
        A new contiguous tensor of shape ``(N, C, OH, OW)`` with the dtype
        and device of ``input``.

    Raises:
        AssertionError: If the device type, rank, or ``output_size`` length
            is not as expected.
    """
    logger.debug("GEMS_CAMBRICON UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape
    reciprocal_scale_h = _reciprocal_scale(scales_h, IH, OH)
    reciprocal_scale_w = _reciprocal_scale(scales_w, IW, OW)

    # The kernels index both tensors as dense NCHW and ignore strides, so a
    # strided/channels_last input must be packed first (no-op if already dense).
    input = input.contiguous()
    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    if output.numel() == 0:
        # Nothing to compute; also avoids a zero-sized grid, for which the
        # cdiv in grid_fn below would divide by zero.
        return output

    with torch_device_fn.device(input.device):
        exact_2x = (
            reciprocal_scale_h == 0.5
            and reciprocal_scale_w == 0.5
            and IH / OH == 0.5
            and IW / OW == 0.5
        )
        if exact_2x:
            # NOTE(review): 48 is the plane-count threshold between the
            # plane-partitioned and the row-partitioned kernel; its origin is
            # not visible here -- confirm against the target core count.
            if N * C > 48:
                upsample_nearest2d_kernel_opt[TOTAL_CORE_NUM,](
                    output, input, N, C, OH, OW, IH, IW
                )
            else:
                upsample_nearest2d_kernel_opt_tile_h[TOTAL_CORE_NUM,](
                    output, input, N, C, OH, OW, IH, IW
                )
        else:
            total_threads = N * C * OH * OW

            # incase grid check error
            def grid_fn(META):
                # Spill into a second grid dimension once the x dimension
                # would exceed MAX_GRID_SIZE_X.
                num_threads = triton.cdiv(total_threads, META["BLOCK_SIZE"])
                grid_x = min(num_threads, MAX_GRID_SIZE_X)
                grid_y = triton.cdiv(num_threads, grid_x)
                return (
                    grid_x,
                    grid_y,
                )

            upsample_nearest2d_kernel[grid_fn](
                output,
                input,
                N,
                C,
                OH,
                OW,
                IH,
                IW,
                reciprocal_scale_h,
                reciprocal_scale_w,
            )

    return output