Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/upsample_nearest1d.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10 

# NOTE: this rebinds the imported `device` object to its plain name string
# (e.g. "cuda"/"xpu"), shadowing the module-level import; the wrapper below
# compares `input.device.type == device` against this string.
device = device.name

# Module-level logger, keyed to this module's dotted path per project convention.
logger = logging.getLogger(__name__)

13 

14 

@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest1d"), key=["N", "C", "OL"]
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest1d"))
@triton.jit
def upsample_nearest1d_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OL,
    IL,
    reciprocal_scale_l,
    BLOCK_SIZE: tl.constexpr,
    SAME_L: tl.constexpr,
    USE_INT32_IDX: tl.constexpr,
):
    """Nearest-neighbor 1D upsampling kernel.

    Axis 0 of the launch grid tiles the output length OL in BLOCK_SIZE chunks;
    axis 1 strides over the flattened (N, C) batch-channel pairs, each program
    looping until all NC pairs are written.

    Args:
        ptr_o: pointer to the (N, C, OL) contiguous output buffer.
        ptr_i: pointer to the (N, C, IL) contiguous input buffer.
        N, C: batch and channel sizes.
        OL, IL: output and input spatial lengths.
        reciprocal_scale_l: IL / OL as float32, used for index mapping.
        BLOCK_SIZE: elements handled per program along the length axis.
        SAME_L: compile-time flag, true when OL == IL (identity copy).
        USE_INT32_IDX: compile-time flag, true when int32 indexing is safe.
    """
    if USE_INT32_IDX:
        pid = tl.program_id(axis=0)
    else:
        # Widen to int64 so pid * BLOCK_SIZE below cannot overflow int32.
        pid = tl.program_id(axis=0).to(tl.int64)
    nc_stride = tl.num_programs(axis=1)
    NC = N * C
    nc_iter = tl.program_id(axis=1)
    # BUG FIX: the original re-assigned `pid = tl.program_id(axis=0)` here,
    # silently discarding the int64 widening above and making the
    # USE_INT32_IDX=False path ineffective.
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The modulo wraps lanes of the last (partial) block back into [0, OL),
    # so out-of-range lanes redundantly rewrite valid elements with the same
    # value instead of going out of bounds — no load/store mask is needed.
    ol = idx % OL
    if SAME_L:
        il = ol
    else:
        # Nearest-neighbor source index: floor(ol * IL / OL), clamped to IL-1
        # (float32 math matches PyTorch's index computation).
        il = tl.minimum(
            tl.math.floor(ol.to(tl.float32) * reciprocal_scale_l).to(tl.int32), IL - 1
        )

    offset_o = nc_iter * OL + ol
    offset_i = nc_iter * IL + il
    src_index_stride = nc_stride * IL
    dst_index_stride = nc_stride * OL

    # Grid-stride loop over the flattened (N, C) pairs along launch axis 1.
    while nc_iter < NC:
        data = tl.load(ptr_i + offset_i)
        tl.store(ptr_o + offset_o, data)
        ptr_i += src_index_stride
        ptr_o += dst_index_stride
        nc_iter += nc_stride

60 

61 

def upsample_nearest1d(
    input: torch.Tensor,
    output_size: Optional[Tuple[int]] = None,
    scales: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbor upsampling of a 3D (N, C, L) tensor along its last axis.

    Args:
        input: contiguous (N, C, IL) tensor on this backend's device.
        output_size: one-element tuple holding the target length OL; takes
            precedence over `scales` for determining the output length.
        scales: scale factor; when `output_size` is absent, OL = int(IL * scales).

    Returns:
        A new (N, C, OL) tensor with the same dtype/device as `input`.
    """
    logger.debug("GEMS UPSAMPLE NEAREST1D")
    assert input.device.type == device
    assert input.ndim == 3, "The ndim of input must be 3"
    assert (
        output_size is not None or scales is not None
    ), "Either output_size or scales should be defined."

    N, C, IL = input.shape
    if output_size is not None:
        OL = output_size[0]
    else:
        OL = int(IL * scales)

    # The kernel maps each output index via floor(ol * (IL/OL)); compute that
    # reciprocal in float32 to match PyTorch's behavior.
    if scales is not None:
        reciprocal_scale_l = float(
            torch.tensor(1.0 / scales, dtype=torch.float32).item()
        )
    else:
        numer = torch.tensor(IL, dtype=torch.float32)
        denom = torch.tensor(OL, dtype=torch.float32)
        reciprocal_scale_l = float((numer / denom).item())

    output = torch.empty((N, C, OL), device=input.device, dtype=input.dtype)

    # Axis 0 tiles the output length; axis 1 strides over (N*C)/4 programs,
    # each looping over its share of batch-channel pairs inside the kernel.
    def grid(meta):
        return (
            triton.cdiv(OL, meta["BLOCK_SIZE"]),
            triton.cdiv(N * C, 4),
        )

    with torch_device_fn.device(input.device):
        upsample_nearest1d_kernel[grid](
            output,
            input,
            N,
            C,
            OL,
            IL,
            reciprocal_scale_l,
        )
    return output