Coverage for src/flag_gems/ops/upsample_linear1d.py: 54%
50 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
# Linear 1-D interpolation of contiguous [NC, W_in] rows into [NC, W_out] rows.
#
# Grid: axis 0 selects one row (pid_nc); axis 1 selects a chunk of BLOCK_SIZE
# output columns within that row.
#
# `align_corners` is an integer flag. When it is non-zero, the source coordinate
# of output column w is `w * scale_ac`; otherwise it is
# `(w + 0.5) * scale_nc - 0.5`. The coordinate is clamped to [0, W_in - 1].
# The two neighbouring input samples are blended in float32, and the result is
# cast back to the input element type.
# NOTE(review): float64 inputs are also blended in float32 here, so they lose
# precision relative to a double-precision reference.
@triton.jit
def upsample_linear1d_kernel(
    input_ptr,
    output_ptr,
    NC,
    W_in,
    W_out,
    align_corners,
    scale_ac,
    scale_nc,
    BLOCK_SIZE: tl.constexpr,
):
    pid_nc = tl.program_id(0)
    pid_w = tl.program_id(1)

    # Widen the row index before multiplying. pid_nc * W_in (or * W_out) can
    # exceed the int32 range on large tensors and would wrap to a bad address.
    row = pid_nc.to(tl.int64)
    base_in = row * W_in
    base_out = row * W_out

    offs_w = pid_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = (pid_nc < NC) & (offs_w < W_out)

    offs_w_f = offs_w.to(tl.float32)

    # Both coordinate mappings are evaluated; the flag selects one per lane.
    src = tl.where(
        align_corners != 0,
        offs_w_f * scale_ac,
        (offs_w_f + 0.5) * scale_nc - 0.5,
    )
    # Clamp so that `lower` and `upper` always index inside the row.
    src = tl.maximum(0.0, tl.minimum(src, W_in - 1.0))

    lower = tl.floor(src).to(tl.int32)
    # At the right edge, both neighbours collapse onto the last sample.
    upper = tl.minimum(lower + 1, W_in - 1)

    # Fractional distance from `lower`; the two weights sum to 1.
    t = src - lower.to(tl.float32)
    w0 = 1.0 - t
    w1 = t

    x0 = tl.load(input_ptr + base_in + lower, mask=mask)
    x1 = tl.load(input_ptr + base_in + upper, mask=mask)

    out = w0 * x0.to(tl.float32) + w1 * x1.to(tl.float32)
    out = out.to(x0.dtype)

    tl.store(output_ptr + base_out + offs_w, out, mask=mask)
def upsample_linear1d(
    self: torch.Tensor,
    output_size,
    align_corners: bool,
    scales: float = None,
):
    """Linearly interpolate a ``[N, C, W]`` tensor along its last dimension.

    Args:
        self: Input tensor of shape ``[N, C, W_in]``; must be a CUDA tensor.
        output_size: Target width, given as an int or a list/tuple whose first
            element is the width. When ``None``, the width is
            ``floor(W_in * scales)``.
        align_corners: If true, output column ``w`` maps to input coordinate
            ``w * (W_in - 1) / (W_out - 1)``, and ``scales`` is not used for
            the mapping.
        scales: Optional width scale factor. It is required when
            ``output_size`` is ``None``. When ``align_corners`` is false and
            the value is strictly positive, ``1 / scales`` is the coordinate
            step; otherwise the step is ``W_in / W_out``.

    Returns:
        A new tensor of shape ``[N, C, W_out]`` with the input's dtype and device.

    Raises:
        AssertionError: If the input is not 3-D or not on CUDA, if ``scales``
            is missing when ``output_size`` is ``None``, or if either width is
            not positive.
    """
    logger.debug("GEMS UPSAMPLE LINEAR1D")
    assert self.ndim == 3, "Input must be [N, C, W]"
    assert self.is_cuda

    N, C, W_in = self.shape
    NC = N * C

    if output_size is not None:
        W_out = int(
            output_size[0] if isinstance(output_size, (list, tuple)) else output_size
        )
    else:
        assert scales is not None
        W_out = int(math.floor(W_in * scales))

    # A zero W_out would divide by zero when computing the scale below. A zero
    # W_in would make the kernel load from an empty row.
    assert (
        W_in > 0 and W_out > 0
    ), "Input and output widths must be greater than 0"

    out = torch.empty((NC, W_out), device=self.device, dtype=self.dtype)
    if NC == 0:
        # Empty batch or channel dims: nothing to compute, so skip the launch.
        return out.view(N, C, W_out)

    inp = self.contiguous().view(NC, W_in)

    if align_corners:
        scale_ac = (W_in - 1) / (W_out - 1) if W_out > 1 else 0.0
        scale_nc = 0.0
    else:
        # Honor a user scale only when it is strictly positive; otherwise fall
        # back to the size ratio. A zero scale would raise ZeroDivisionError.
        use_user_scale = scales is not None and scales > 0
        scale_nc = 1.0 / scales if use_user_scale else W_in / W_out
        scale_ac = 0.0

    BLOCK_SIZE = 256
    # One program per (row, BLOCK_SIZE-wide chunk of output columns).
    grid = (NC, triton.cdiv(W_out, BLOCK_SIZE))

    upsample_linear1d_kernel[grid](
        inp,
        out,
        NC,
        W_in,
        W_out,
        int(align_corners),
        float(scale_ac),
        float(scale_nc),
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out.view(N, C, W_out)