Coverage for src/flag_gems/ops/upsample_nearest2d.py: 52%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Rebind `device` from the imported runtime object to its `name` attribute;
# the wrapper below compares it against `input.device.type`.
device = device.name

13 

14 

@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
    USE_INT32_IDX: tl.constexpr,
):
    """Copy nearest-neighbour source pixels into an upsampled output.

    Grid axis 0 tiles the ``OH * OW`` output plane in chunks of ``BLOCK_SIZE``;
    grid axis 1 strides over the ``N * C`` planes. Offsets are computed as
    ``(nc * H + h) * W + w``, i.e. both buffers are addressed as dense
    row-major (N, C, H, W) storage.

    Args:
        ptr_o: Pointer to the output buffer.
        ptr_i: Pointer to the input buffer.
        N, C: Batch and channel counts; only their product is used.
        OH, OW: Output height and width.
        IH, IW: Input height and width.
        reciprocal_scale_h, reciprocal_scale_w: Factors that map an output
            coordinate to a (fractional) input coordinate.
        BLOCK_SIZE: Number of output pixels handled per program on axis 0.
        SAME_H, SAME_W: When set, the corresponding coordinate is copied
            as-is instead of being rescaled (presumably chosen by the
            heuristics when input and output extents match -- confirm there).
        USE_INT32_IDX: When false, index arithmetic is widened to int64
            (presumably chosen by the heuristics for large tensors -- confirm).
    """
    # Pick the index width once. The original re-assigned `pid` from
    # tl.program_id() after this branch, which silently threw away the int64
    # cast; the plane index/stride are widened too so that every term of the
    # offset and pointer-stride products is 64-bit on that path.
    if USE_INT32_IDX:
        pid = tl.program_id(axis=0)
        nc_iter = tl.program_id(axis=1)
        nc_stride = tl.num_programs(axis=1)
    else:
        pid = tl.program_id(axis=0).to(tl.int64)
        nc_iter = tl.program_id(axis=1).to(tl.int64)
        nc_stride = tl.num_programs(axis=1).to(tl.int64)
    NC = N * C
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # No mask is used on the load/store below: the modulo on `oh` wraps lanes
    # past OH * OW back into the plane, so every lane addresses a valid pixel.
    ow = idx % OW
    oh = idx // OW % OH
    if SAME_H:
        ih = oh
    else:
        # tl.floor() cannot be found in 2.3.1, using int trunc
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)

    offset_o = (nc_iter * OH + oh) * OW + ow
    offset_i = (nc_iter * IH + ih) * IW + iw
    src_index_stride = nc_stride * IH * IW
    dst_index_stride = nc_stride * OH * OW
    # Each program handles planes nc_iter, nc_iter + nc_stride, ... by
    # advancing the base pointers instead of recomputing the offsets.
    while nc_iter < NC:
        data = tl.load(ptr_i + offset_i)
        tl.store(ptr_o + offset_o, data)
        ptr_i += src_index_stride
        ptr_o += dst_index_stride
        nc_iter += nc_stride

67 

68 

def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Upsample a 4-D tensor to ``output_size`` with nearest-neighbour sampling.

    Args:
        input: Tensor of shape (N, C, IH, IW) on the active device.
        output_size: Target ``(OH, OW)``.
        scales_h: Optional height scale factor. When given and positive, its
            reciprocal maps output rows to input rows; otherwise ``IH / OH``
            is used.
        scales_w: Optional width scale factor, handled like ``scales_h``.

    Returns:
        A new tensor of shape (N, C, OH, OW) with the dtype and device of
        ``input``.

    Raises:
        AssertionError: If ``input`` is on a different device type, is not
            4-D, or ``output_size`` does not have exactly two entries.
    """
    logger.debug("GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape
    # A non-positive scale is treated as "not provided" (this mirrors ATen's
    # compute_scales_value) instead of dividing by zero or producing a
    # negative coordinate ratio.
    if scales_h is not None and scales_h > 0:
        reciprocal_scale_h = 1 / scales_h
    else:
        reciprocal_scale_h = IH / OH
    if scales_w is not None and scales_w > 0:
        reciprocal_scale_w = 1 / scales_w
    else:
        reciprocal_scale_w = IW / OW
    # The kernel addresses the input as dense row-major NCHW storage, so a
    # strided view (e.g. channels_last or a slice) must be materialised first.
    # This is a no-op for tensors that are already contiguous.
    input = input.contiguous()
    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    if output.numel() == 0:
        # Nothing to write; avoid launching with a zero-sized grid.
        return output
    total_threads = OH * OW
    # Axis 0 tiles one output plane; axis 1 launches one program per 4 (N*C)
    # planes, and the kernel strides over the remaining planes.
    grid = lambda META: (
        triton.cdiv(total_threads, META["BLOCK_SIZE"]),
        triton.cdiv(N * C, 4),
    )

    with torch_device_fn.device(input.device):
        upsample_nearest2d_kernel[grid](
            output,
            input,
            N,
            C,
            OH,
            OW,
            IH,
            IW,
            reciprocal_scale_h,
            reciprocal_scale_w,
        )
    return output