Coverage for src/flag_gems/runtime/backend/_sunrise/ops/upsample_nearest2d.py: 0%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device 

9 

# Rebind the imported `device` object to its `name` attribute; from here on
# `device` is that name (presumably a string such as the backend's device
# type -- confirm against flag_gems.runtime), and it is compared with
# `input.device.type` in `upsample_nearest2d`.
device = device.name
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

12 

13 

def configs():
    # Build the autotune search space: every combination of block size and
    # warp count, block size varying slowest (same order as a nested
    # "for block size, for warps" comprehension).
    block_sizes = (128, 256, 512, 1024)
    warp_counts = (4, 8, 16, 32)
    candidates = []
    for block_size in block_sizes:
        for num_warps in warp_counts:
            candidates.append(
                triton.Config({"BLOCK_SIZE": block_size}, num_warps=num_warps)
            )
    return candidates

20 

21 

# Nearest-neighbour 2D upsampling kernel.
#
# Each program handles BLOCK_SIZE consecutive flat output indices, decomposes
# them into (n, c, oh, ow) with `ow` varying fastest, picks the source pixel
# (ih, iw) and copies one element per lane.
#
#   ptr_o / ptr_i            : output / input base pointers
#   sno, sco, sho, swo       : output strides for the n, c, h, w axes
#   sni, sci, shi, swi       : input strides for the n, c, h, w axes
#   N, C, OH, OW, IH, IW     : batch, channels, output and input spatial sizes
#   reciprocal_scale_h / _w  : factor mapping an output coordinate to an
#                              input coordinate along h / w
#   SAME_H / SAME_W          : set by the heuristics when the output and input
#                              extents match, so the coordinate is reused as-is
#
# NOTE: the load and store are unmasked. Lanes whose flat index runs past
# N * C * OH * OW are folded back into range by the `% N` on the batch index,
# so they re-copy already valid elements rather than access out of bounds.
@triton.autotune(configs=configs(), key=["N", "C", "OH", "OW"])
@triton.heuristics(
    {
        "SAME_H": lambda args: args["OH"] == args["IH"],
        "SAME_W": lambda args: args["OW"] == args["IW"],
    }
)
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,
    ptr_i,
    sno,
    sco,
    sho,
    swo,
    sni,
    sci,
    shi,
    swi,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
):
    # Flat output indices covered by this program.
    flat = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Peel off the axes from fastest (w) to slowest (n).
    row = flat // OW
    plane = row // OH
    ow = flat % OW
    oh = row % OH
    c = plane % C
    n = plane // C % N

    # Source row: identity when heights match, otherwise scale and clamp.
    if SAME_H:
        ih = oh
    else:
        # tl.floor() cannot be found in 2.3.1, using int trunc
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)

    # Source column: same treatment along the width axis.
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)

    src = tl.load(ptr_i + (n * sni + c * sci + ih * shi + iw * swi))
    tl.store(ptr_o + (n * sno + c * sco + oh * sho + ow * swo), src)

72 

73 

def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Upsample a 4-D tensor with nearest-neighbour sampling.

    Args:
        input: Tensor with four dimensions, unpacked as (N, C, IH, IW). It
            must live on this backend's device.
        output_size: Target spatial size ``(OH, OW)``.
        scales_h: Optional height scale. When given, ``1 / scales_h`` maps
            output rows to input rows; otherwise ``IH / OH`` is used.
        scales_w: Optional width scale, handled like ``scales_h``.

    Returns:
        A newly allocated tensor of shape ``(N, C, OH, OW)`` with the same
        dtype and device as ``input``.

    Raises:
        AssertionError: If the device type, ``input.ndim`` or
            ``len(output_size)`` check fails.
    """
    # Use the module logger (not the root logger) so per-module logging
    # configuration applies to this message.
    logger.debug("GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape
    if scales_h is not None:
        reciprocal_scale_h = 1 / scales_h
    else:
        reciprocal_scale_h = IH / OH
    if scales_w is not None:
        reciprocal_scale_w = 1 / scales_w
    else:
        reciprocal_scale_w = IW / OW
    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    total_threads = N * C * OH * OW
    if total_threads == 0:
        # Nothing to copy (e.g. empty batch); skip the zero-size grid launch.
        return output
    sno, sco, sho, swo = output.stride()
    sni, sci, shi, swi = input.stride()
    # One program per BLOCK_SIZE output elements; BLOCK_SIZE is chosen by the
    # kernel's autotuner.
    grid = lambda META: (triton.cdiv(total_threads, META["BLOCK_SIZE"]),)
    upsample_nearest2d_kernel[grid](
        output,
        input,
        sno,
        sco,
        sho,
        swo,
        sni,
        sci,
        shi,
        swi,
        N,
        C,
        OH,
        OW,
        IH,
        IW,
        reciprocal_scale_h,
        reciprocal_scale_w,
    )
    return output