Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/upsample_nearest2d.py: 0%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger nested under the shared "flag_gems" root logger; lstrip(".")
# keeps the child name well-formed if __name__ carries a leading dot.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Rebind `device` from the imported device object to its name string so the
# `input.device.type == device` check below is a plain string comparison.
# NOTE: this intentionally shadows the `device` imported above.
device = device.name

14 

15 

def heur_block_size(args):
    """Pick BLOCK_SIZE so the flat output is split across 12 programs.

    The total element count N*C*OH*OW is divided by 12 (cluster_num of the
    target part, per the original comment) and rounded up to a power of two,
    as required by tl.arange.
    """
    total_elements = args["N"] * args["C"] * args["OH"] * args["OW"]
    per_cluster = triton.cdiv(total_elements, 12)  # cluster_num
    return triton.next_power_of_2(per_cluster)

20 

21 

# @triton.autotune(
#     configs=runtime.get_tuned_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
# )
@triton.heuristics(
    {
        "SAME_H": lambda args: args["OH"] == args["IH"],
        "SAME_W": lambda args: args["OW"] == args["IW"],
        "BLOCK_SIZE": heur_block_size,
    }
)
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,
    ptr_i,
    N: tl.constexpr,
    C: tl.constexpr,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
):
    # Each program handles BLOCK_SIZE flat output elements: decompose the flat
    # offset into (n, c, oh, ow), gather from the nearest input element, and
    # scatter to the output. Assumes contiguous NCHW layout for both tensors.
    block_start = tle.program_id(axis=0) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Peel coordinates off the flat index, fastest-varying first. The final
    # `% N` wraps any overhanging lanes back onto valid positions, so no
    # explicit mask is needed (overhanging lanes redo an in-range copy).
    ow = offsets % OW
    rest = offsets // OW
    oh = rest % OH
    rest = rest // OH
    c = rest % C
    n = rest // C % N
    # Map output coordinates back to input coordinates. Identity when the
    # sizes match; otherwise multiply by the precomputed reciprocal scale and
    # truncate via int cast (tl.floor is unavailable in Triton 2.3.1),
    # clamping to the last valid row/column.
    if SAME_H:
        ih = oh
    else:
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)
    # Contiguous NCHW linear offsets for the gather/scatter pair.
    src = ((n * C + c) * IH + ih) * IW + iw
    dst = ((n * C + c) * OH + oh) * OW + ow
    tl.store(ptr_o + dst, tl.load(ptr_i + src))

67 

68 

def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbor 2D upsampling of a 4D NCHW tensor.

    Args:
        input: tensor of shape (N, C, IH, IW) on the backend device.
        output_size: target spatial size (OH, OW).
        scales_h: optional explicit height scale; when given, the source row
            is ``trunc(oh * (1 / scales_h))``, otherwise the ratio IH / OH
            is used.
        scales_w: optional explicit width scale, analogous to ``scales_h``.

    Returns:
        A new tensor of shape (N, C, OH, OW) with input's dtype and device.

    Raises:
        AssertionError: if input is not a 4D tensor on the expected device,
            or output_size does not have exactly two entries.
    """
    logger.debug("GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape
    # BUGFIX: the kernel computes flat element offsets assuming a contiguous
    # NCHW layout, so a non-contiguous input (e.g. a permuted or sliced view)
    # would be read with wrong strides. Force contiguity here; this is a
    # no-op when the input is already contiguous.
    input = input.contiguous()
    # Precompute reciprocal scales on the host so the kernel multiplies and
    # truncates instead of dividing per element.
    if scales_h is not None:
        reciprocal_scale_h = 1 / scales_h
    else:
        reciprocal_scale_h = IH / OH
    if scales_w is not None:
        reciprocal_scale_w = 1 / scales_w
    else:
        reciprocal_scale_w = IW / OW
    # Allocate output and launch one program per BLOCK_SIZE output elements.
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    total_threads = N * C * OH * OW
    grid = lambda META: (triton.cdiv(total_threads, META["BLOCK_SIZE"]),)
    with torch_device_fn.device(input.device):
        upsample_nearest2d_kernel[grid](
            output, input, N, C, OH, OW, IH, IW, reciprocal_scale_h, reciprocal_scale_w
        )
    return output

97 return output