Coverage for src/flag_gems/ops/upsample_linear1d.py: 54%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

import logging
import math
from typing import Optional

import torch
import triton
import triton.language as tl

7 

# Module-level logger named after this module, per the stdlib logging convention.
logger = logging.getLogger(__name__)

9 

10 

@triton.jit
def upsample_linear1d_kernel(
    input_ptr,   # contiguous input buffer viewed as (NC, W_in)
    output_ptr,  # contiguous output buffer viewed as (NC, W_out)
    NC,          # number of rows (N * C)
    W_in,        # input width
    W_out,       # output width
    align_corners,  # int flag: nonzero selects the align_corners mapping
    scale_ac,    # precomputed (W_in-1)/(W_out-1) scale for align_corners mode
    scale_nc,    # precomputed source/target scale for half-pixel mode
    BLOCK_SIZE: tl.constexpr,
):
    # 1D linear interpolation kernel: grid axis 0 enumerates (N*C) rows,
    # grid axis 1 tiles the output width in chunks of BLOCK_SIZE.
    pid_nc = tl.program_id(0)
    pid_w = tl.program_id(1)

    # Flat offsets of this row within the (NC, W) input/output buffers.
    base_in = pid_nc * W_in
    base_out = pid_nc * W_out

    offs_w = pid_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # pid_nc < NC always holds when the launch grid's first dim equals NC;
    # kept as a defensive guard on the load/store mask.
    mask = (pid_nc < NC) & (offs_w < W_out)

    offs_w_f = offs_w.to(tl.float32)

    # Map output index i to a fractional source coordinate:
    #   align_corners: src = i * scale_ac            (endpoints aligned)
    #   otherwise:     src = (i + 0.5) * scale_nc - 0.5  (half-pixel centers)
    src = tl.where(
        align_corners != 0,
        offs_w_f * scale_ac,
        (offs_w_f + 0.5) * scale_nc - 0.5,
    )

    # Clamp into the valid source range [0, W_in - 1]; the half-pixel formula
    # can produce slightly negative coordinates at the left edge.
    src = tl.maximum(0.0, tl.minimum(src, W_in - 1.0))

    # Two neighboring sample indices; upper is clamped at the right boundary.
    lower = tl.floor(src).to(tl.int32)
    upper = tl.minimum(lower + 1, W_in - 1)

    # Linear weights from the fractional part t in [0, 1).
    t = src - lower.to(tl.float32)
    w0 = 1.0 - t
    w1 = t

    x0 = tl.load(input_ptr + base_in + lower, mask=mask)
    x1 = tl.load(input_ptr + base_in + upper, mask=mask)

    # Interpolate in float32 for precision regardless of the storage dtype.
    x0_f = x0.to(tl.float32)
    x1_f = x1.to(tl.float32)

    out = w0 * x0_f + w1 * x1_f

    # Cast back to the input dtype before storing.
    out = out.to(x0.dtype)

    tl.store(output_ptr + base_out + offs_w, out, mask=mask)

60 

61 

def upsample_linear1d(
    self: torch.Tensor,
    output_size,
    align_corners: bool,
    scales: Optional[float] = None,
):
    """Linearly resample a [N, C, W] CUDA tensor along its last dimension.

    Args:
        self: input tensor of shape [N, C, W]; must be on a CUDA device.
        output_size: target width as an int or a 1-element list/tuple.
            If None, ``scales`` must be given and W_out = floor(W_in * scales).
        align_corners: if True, the first and last input/output samples are
            aligned; otherwise the half-pixel-center convention is used.
        scales: optional scale factor. In half-pixel mode the source scale is
            1/scales when provided, else the size ratio W_in / W_out
            (matching PyTorch's coordinate-scale convention).

    Returns:
        Tensor of shape [N, C, W_out] with the same dtype and device as
        ``self``.
    """
    logger.debug("GEMS UPSAMPLE LINEAR1D")
    assert self.ndim == 3, "Input must be [N, C, W]"
    assert self.is_cuda

    N, C, W_in = self.shape
    NC = N * C

    if output_size is not None:
        W_out = int(
            output_size[0] if isinstance(output_size, (list, tuple)) else output_size
        )
    else:
        assert scales is not None
        W_out = int(math.floor(W_in * scales))

    # Empty output: nothing to compute. This early return also prevents a
    # ZeroDivisionError in the W_in / W_out fallback below when W_out == 0.
    if W_out == 0 or NC == 0:
        return torch.empty((N, C, W_out), device=self.device, dtype=self.dtype)

    # Flatten batch and channel dims so the kernel indexes 2D (NC, W) buffers.
    inp = self.contiguous().view(NC, W_in)
    out = torch.empty((NC, W_out), device=self.device, dtype=self.dtype)

    if align_corners:
        # Output i maps to i * (W_in-1)/(W_out-1); a single output column
        # degenerates to sampling coordinate 0.
        scale_ac = (W_in - 1) / (W_out - 1) if W_out > 1 else 0.0
        scale_nc = 0.0
    else:
        # Half-pixel mode: prefer the caller-supplied scale factor, else fall
        # back to the size ratio.
        scale_nc = 1.0 / scales if scales is not None else W_in / W_out
        scale_ac = 0.0

    BLOCK_SIZE = 256
    grid = (NC, triton.cdiv(W_out, BLOCK_SIZE))

    upsample_linear1d_kernel[grid](
        inp,
        out,
        NC,
        W_in,
        W_out,
        int(align_corners),
        float(scale_ac),
        float(scale_nc),
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out.view(N, C, W_out)