Coverage for src/flag_gems/ops/upsample_nearest3d.py: 49%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10 

11device = device.name 

12logger = logging.getLogger(__name__) 

13 

14 

15@triton.autotune( 

16 configs=runtime.get_tuned_config("upsample_nearest3d"), 

17 key=["N", "C", "OD", "OH", "OW"], 

18) 

19@triton.heuristics(runtime.get_heuristic_config("upsample_nearest3d")) 

20@triton.jit 

21def upsample_nearest3d_kernel( 

22 ptr_o, 

23 ptr_i, 

24 N, 

25 C, 

26 OD, 

27 OH, 

28 OW, 

29 ID, 

30 IH, 

31 IW, 

32 reciprocal_scale_d, 

33 reciprocal_scale_h, 

34 reciprocal_scale_w, 

35 BLOCK_SIZE: tl.constexpr, 

36 SAME_D: tl.constexpr, 

37 SAME_H: tl.constexpr, 

38 SAME_W: tl.constexpr, 

39 USE_INT32_IDX: tl.constexpr, 

40): 

41 if USE_INT32_IDX: 

42 pid0 = tl.program_id(axis=0) 

43 else: 

44 pid0 = tl.program_id(axis=0).to(tl.int64) 

45 

46 nc_stride = tl.num_programs(axis=1) 

47 NC = N * C 

48 nc_iter = tl.program_id(axis=1) 

49 

50 idx = pid0 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

51 total_spatial_size = OD * OH * OW 

52 

53 mask = idx < total_spatial_size 

54 

55 ow = idx % OW 

56 oh = (idx // OW) % OH 

57 od = idx // (OW * OH) 

58 

59 if SAME_D: 

60 id = od 

61 else: 

62 id = tl.minimum( 

63 tl.math.floor(od.to(tl.float32) * reciprocal_scale_d).to(tl.int32), ID - 1 

64 ) 

65 

66 if SAME_H: 

67 ih = oh 

68 else: 

69 ih = tl.minimum( 

70 tl.math.floor(oh.to(tl.float32) * reciprocal_scale_h).to(tl.int32), IH - 1 

71 ) 

72 

73 if SAME_W: 

74 iw = ow 

75 else: 

76 iw = tl.minimum( 

77 tl.math.floor(ow.to(tl.float32) * reciprocal_scale_w).to(tl.int32), IW - 1 

78 ) 

79 

80 offset_o = nc_iter * (OD * OH * OW) + idx 

81 offset_i = nc_iter * (ID * IH * IW) + (id * IH * IW + ih * IW + iw) 

82 

83 src_nc_stride = nc_stride * (ID * IH * IW) 

84 dst_nc_stride = nc_stride * (OD * OH * OW) 

85 

86 while nc_iter < NC: 

87 data = tl.load(ptr_i + offset_i, mask=mask) 

88 tl.store(ptr_o + offset_o, data, mask=mask) 

89 

90 offset_i += src_nc_stride 

91 offset_o += dst_nc_stride 

92 nc_iter += nc_stride 

93 

94 

95def upsample_nearest3d( 

96 input: torch.Tensor, 

97 output_size: Tuple[int, int, int], 

98 scales_d: Optional[float] = None, 

99 scales_h: Optional[float] = None, 

100 scales_w: Optional[float] = None, 

101) -> torch.Tensor: 

102 logger.debug("GEMS UPSAMPLE NEAREST3D") 

103 assert input.device.type == device 

104 assert input.ndim == 5, "The ndim of input must be 5" 

105 

106 OD, OH, OW = output_size 

107 N, C, ID, IH, IW = input.shape 

108 

109 def calculate_scale(in_sz, out_sz, s): 

110 if s is not None: 

111 return float(torch.tensor(1.0 / s, dtype=torch.float32).item()) 

112 return float( 

113 ( 

114 torch.tensor(in_sz, dtype=torch.float32) 

115 / torch.tensor(out_sz, dtype=torch.float32) 

116 ).item() 

117 ) 

118 

119 reciprocal_scale_d = calculate_scale(ID, OD, scales_d) 

120 reciprocal_scale_h = calculate_scale(IH, OH, scales_h) 

121 reciprocal_scale_w = calculate_scale(IW, OW, scales_w) 

122 

123 output = torch.empty((N, C, OD, OH, OW), device=input.device, dtype=input.dtype) 

124 

125 total_threads = OD * OH * OW 

126 grid = lambda meta: ( 

127 triton.cdiv(total_threads, meta["BLOCK_SIZE"]), 

128 triton.cdiv(N * C, 4), 

129 ) 

130 

131 with torch_device_fn.device(input.device): 

132 upsample_nearest3d_kernel[grid]( 

133 output, 

134 input, 

135 N, 

136 C, 

137 OD, 

138 OH, 

139 OW, 

140 ID, 

141 IH, 

142 IW, 

143 reciprocal_scale_d, 

144 reciprocal_scale_h, 

145 reciprocal_scale_w, 

146 ) 

147 return output