Coverage for src/flag_gems/runtime/backend/_hygon/ops/upsample_nearest2d.py: 0%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10 

# Rebind ``device`` from the imported runtime object to its ``name``; the
# wrapper below compares this value against ``input.device.type``.
device = device.name
# Module-level logger; the wrapper emits a debug line on every call.
logger = logging.getLogger(__name__)

13 

14 

# Nearest-neighbour 2D resize kernel.
#
# Launch grid: axis 0 tiles the OH * OW positions of one output plane in
# chunks of BLOCK_SIZE; axis 1 programs walk the N * C planes with a stride of
# tl.num_programs(1), so every plane is visited regardless of how many axis-1
# programs are launched.
# All offsets are flat NCHW indices, so ptr_i and ptr_o must both point to
# contiguous (N, C, H, W) buffers.
# BLOCK_SIZE comes from the autotune configs and SAME_H / SAME_W from the
# heuristic config (presumably "input dim == output dim" flags -- confirm in
# this backend's runtime config).
@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
    USE_INT32_IDX: tl.constexpr,
):
    # In-plane indices stay 32-bit when the caller says the sizes allow it.
    if USE_INT32_IDX:
        pid_x = tl.program_id(axis=0)
    else:
        pid_x = tl.program_id(axis=0).to(tl.int64)

    idx = pid_x * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last axis-0 block may run past the end of the plane.
    mask = idx < OH * OW
    ow = idx % OW
    oh = idx // OW % OH

    # Source row/col: truncate dst * reciprocal_scale (equal to floor for
    # non-negative products), clamped to the last valid input index.
    if SAME_H:
        ih = oh
    else:
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)

    pid_y = tl.program_id(axis=1)
    num_pid_y = tl.num_programs(axis=1)
    nc_iter = pid_y
    total_nc = N * C

    # Widen to int64 BEFORE multiplying. Program ids (and size arguments that
    # fit in 32 bits) are int32, so casting the finished product -- as this
    # code used to -- let e.g. num_pid_y * IH * IW wrap around on large inputs.
    num_pid_y_64 = num_pid_y.to(tl.int64)
    nc_iter_64 = nc_iter.to(tl.int64)
    src_stride_step = num_pid_y_64 * IH * IW
    dst_stride_step = num_pid_y_64 * OH * OW

    # Plane base offset (int64) plus the position inside the plane.
    current_ptr_i = ptr_i + nc_iter_64 * IH * IW + (ih * IW + iw)
    current_ptr_o = ptr_o + nc_iter_64 * OH * OW + (oh * OW + ow)

    while nc_iter < total_nc:
        val = tl.load(current_ptr_i, mask=mask)
        tl.store(current_ptr_o, val, mask=mask)
        nc_iter += num_pid_y
        current_ptr_i += src_stride_step
        current_ptr_o += dst_stride_step

74 

75 

def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Resize an (N, C, IH, IW) tensor to (N, C, OH, OW) by nearest neighbour.

    Args:
        input: 4-D tensor on the backend device. Any memory layout is
            accepted; it is made contiguous before launch because the kernel
            computes flat NCHW offsets and receives no strides.
        output_size: ``(OH, OW)``; both must be positive.
        scales_h: Optional height scale factor. As in PyTorch, it is used
            (as ``1 / scales_h``) only when positive; otherwise ``IH / OH``.
        scales_w: Same as ``scales_h`` for the width dimension.

    Returns:
        A new contiguous tensor of shape ``(N, C, OH, OW)`` with the dtype
        and device of ``input``.

    Raises:
        AssertionError: on a wrong device, a non-4-D input, or an invalid
            ``output_size``.
    """
    logger.debug("GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"

    OH, OW = output_size
    # Zero output size previously crashed with ZeroDivisionError in IH / OH.
    assert OH > 0 and OW > 0, "The output_size must be positive"

    # The kernel indexes input as a dense NCHW buffer.
    input = input.contiguous()
    N, C, IH, IW = input.shape

    # Mirror PyTorch's compute_scales_value: non-positive scales are ignored.
    if scales_h is not None and scales_h > 0:
        reciprocal_scale_h = 1 / scales_h
    else:
        reciprocal_scale_h = IH / OH
    if scales_w is not None and scales_w > 0:
        reciprocal_scale_w = 1 / scales_w
    else:
        reciprocal_scale_w = IW / OW

    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    # Empty batch/channel dimension: nothing to compute, skip the launch.
    if output.numel() == 0:
        return output

    total_threads = OH * OW
    num_planes = N * C
    use_int32 = (num_planes * OH * OW) < 2**31

    # Axis 1 targets ~4 planes per program but is capped at 65535 (the
    # y-dimension grid limit on CUDA; assumed to be a safe bound on this
    # backend too). The kernel loops over planes with stride
    # num_programs(1), so capping never drops work.
    max_grid_y = 65535
    grid = lambda META: (
        triton.cdiv(total_threads, META["BLOCK_SIZE"]),
        min(triton.cdiv(num_planes, 4), max_grid_y),
    )

    with torch_device_fn.device(input.device):
        upsample_nearest2d_kernel[grid](
            output,
            input,
            N,
            C,
            OH,
            OW,
            IH,
            IW,
            reciprocal_scale_h,
            reciprocal_scale_w,
            USE_INT32_IDX=use_int32,
        )
    return output