Coverage for src/flag_gems/runtime/backend/_metax/ops/upsample_nearest2d.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils import triton_lang_extension as tle 

11 

logger = logging.getLogger("flag_gems." + __name__)
# NOTE(review): rebinds `device` from the imported runtime module object to its
# name string (e.g. the backend's device-type name), shadowing the import;
# `upsample_nearest2d` below compares `input.device.type` against this string.
device = device.name

14 

15 

@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
    USE_INT32_IDX: tl.constexpr,
):
    # Nearest-neighbor 2D upsample: each lane handles one flat output element.
    # USE_INT32_IDX picks 32-bit program ids when the tensor is small enough;
    # otherwise tle.program_id supplies a wider index.
    if USE_INT32_IDX:
        block_start = tl.program_id(axis=0) * BLOCK_SIZE
    else:
        block_start = tle.program_id(axis=0) * BLOCK_SIZE
    flat = block_start + tl.arange(0, BLOCK_SIZE)

    # Decompose the flat index into (n, c, oh, ow) via a running quotient.
    # Every component is reduced modulo its dimension, so flat indices past
    # N*C*OH*OW wrap onto valid coordinates — duplicate but correct writes,
    # which is why no load/store mask is needed.
    ow = flat % OW
    rest = flat // OW
    oh = rest % OH
    rest = rest // OH
    c = rest % C
    n = rest // C % N

    if SAME_H:
        ih = oh
    else:
        # tl.floor() cannot be found in 2.3.1, using int trunc
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)

    src = ((n * C + c) * IH + ih) * IW + iw
    dst = ((n * C + c) * OH + oh) * OW + ow
    tl.store(ptr_o + dst, tl.load(ptr_i + src))

59 

60 

def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbor upsampling of a 4D NCHW tensor to `output_size`.

    When an explicit scale is supplied, its reciprocal maps output
    coordinates back to input coordinates; otherwise the ratio of input
    extent to output extent is used.

    Args:
        input: 4D tensor (N, C, IH, IW) on this backend's device.
        output_size: pair (OH, OW) of target spatial dimensions.
        scales_h: optional explicit height scale factor.
        scales_w: optional explicit width scale factor.

    Returns:
        A new (N, C, OH, OW) tensor with the same dtype/device as `input`.
    """
    logger.debug("METAX GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape
    reciprocal_scale_h = 1 / scales_h if scales_h is not None else IH / OH
    reciprocal_scale_w = 1 / scales_w if scales_w is not None else IW / OW
    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    n_elements = N * C * OH * OW

    # One program per BLOCK_SIZE output elements; BLOCK_SIZE comes from autotune.
    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(input.device):
        upsample_nearest2d_kernel[grid](
            output, input, N, C, OH, OW, IH, IW, reciprocal_scale_h, reciprocal_scale_w
        )
    return output