Coverage for src/flag_gems/runtime/backend/_ascend/ops/upsample_nearest2d.py: 0%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-scoped logger named after this ops file
# (e.g. "flag_gems.runtime._ascend.ops.upsample_nearest2d").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
# NOTE(review): rebinds `device` from the imported runtime module object to its
# backend-name string; the wrapper below compares `input.device.type` against it.
device = device.name

14 

15 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
        triton.Config({"BLOCK_SIZE": 2048}),
        triton.Config({"BLOCK_SIZE": 4096}),
    ],
    key=["N", "C", "OH", "OW"],
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,  # output pointer, contiguous NCHW layout, shape (N, C, OH, OW)
    ptr_i,  # input pointer, contiguous NCHW layout, shape (N, C, IH, IW)
    N,
    C,
    OH,  # output height
    OW,  # output width
    IH,  # input height
    IW,  # input width
    reciprocal_scale_h,  # IH/OH (or 1/scales_h): maps output row -> input row
    reciprocal_scale_w,  # IW/OW (or 1/scales_w): maps output col -> input col
    BLOCK_SIZE: tl.constexpr,
    # SAME_H / SAME_W appear to be supplied by the project heuristic config
    # above (fast path when a dimension is not resized) — TODO confirm against
    # runtime.get_heuristic_config("upsample_nearest2d").
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
):
    # One program handles BLOCK_SIZE consecutive flat output elements.
    pid = tle.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mixed-radix decomposition of the flat index into (n, c, oh, ow).
    # Because every component is taken modulo its extent, any idx beyond
    # N*C*OH*OW wraps onto a valid element, so the trailing (partial) block
    # merely rewrites already-correct positions instead of going out of
    # bounds — which is why no load/store mask is needed.
    ow = idx % OW
    oh = idx // OW % OH
    c = idx // OW // OH % C
    n = idx // OW // OH // C % N
    if SAME_H:
        ih = oh
    else:
        # tl.floor() cannot be found in 2.3.1, using int trunc
        # (oh * reciprocal_scale_h is non-negative, so trunc == floor here);
        # clamp to IH-1 to guard against rounding past the last input row.
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)
    # Flat offsets for contiguous NCHW tensors.
    offset_o = ((n * C + c) * OH + oh) * OW + ow
    offset_i = ((n * C + c) * IH + ih) * IW + iw
    data = tl.load(ptr_i + offset_i)
    tl.store(ptr_o + offset_o, data)

63 

64 

def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbor 2D upsampling of an NCHW tensor via a Triton kernel.

    Args:
        input: 4-D tensor of shape (N, C, IH, IW) on this backend's device.
        output_size: target spatial size (OH, OW).
            (Annotation fixed from ``Tuple[int]`` — which denotes a 1-tuple —
            to ``Tuple[int, int]``, matching the length-2 assertion below.)
        scales_h: optional explicit height scale factor; when omitted the
            scale is derived from the input/output sizes.
        scales_w: optional explicit width scale factor, analogous to scales_h.

    Returns:
        A new tensor of shape (N, C, OH, OW) with input's dtype and device.

    Raises:
        AssertionError: if input is not 4-D, is on the wrong device, or
            output_size does not have exactly two elements.
    """
    logger.debug("GEMS_ASCEND UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape
    # The kernel consumes reciprocal scales (output coord * reciprocal -> input
    # coord). Prefer explicit scale factors; otherwise derive from the sizes.
    reciprocal_scale_h = 1 / scales_h if scales_h is not None else IH / OH
    reciprocal_scale_w = 1 / scales_w if scales_w is not None else IW / OW
    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    total_threads = N * C * OH * OW
    # One program per BLOCK_SIZE output elements; BLOCK_SIZE is autotuned.
    grid = lambda META: (triton.cdiv(total_threads, META["BLOCK_SIZE"]),)
    with torch_device_fn.device(input.device):
        upsample_nearest2d_kernel[grid](
            output, input, N, C, OH, OW, IH, IW, reciprocal_scale_h, reciprocal_scale_w
        )
    return output