Coverage for src/flag_gems/ops/upsample_nearest1d.py: 55%
51 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
# NOTE: this rebinds the imported `device` object to its `.name` string
# (e.g. the backend device type), which is what the wrapper below compares
# against `input.device.type`. The original object is no longer reachable
# under this name after import time.
device = device.name
# Module-level logger following the project-wide `getLogger(__name__)` convention.
logger = logging.getLogger(__name__)
@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest1d"), key=["N", "C", "OL"]
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest1d"))
@triton.jit
def upsample_nearest1d_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OL,
    IL,
    reciprocal_scale_l,
    BLOCK_SIZE: tl.constexpr,
    SAME_L: tl.constexpr,
    USE_INT32_IDX: tl.constexpr,
):
    """1D nearest-neighbor upsampling kernel.

    Grid axis 0 tiles the output length OL in BLOCK_SIZE chunks; grid axis 1
    is strided over the N*C (batch, channel) planes, so each program on axis 1
    walks every `nc_stride`-th plane in the while-loop below.

    Args:
        ptr_o: output pointer, logically (N, C, OL), contiguous.
        ptr_i: input pointer, logically (N, C, IL), contiguous.
        N, C: batch and channel counts.
        OL, IL: output and input lengths.
        reciprocal_scale_l: IL/OL as float32 (PyTorch-compatible rounding).
        BLOCK_SIZE: elements of OL handled per program (autotuned).
        SAME_L: compile-time flag, true when OL == IL (identity copy).
        USE_INT32_IDX: compile-time flag; when false, widen the program id to
            int64 so element offsets cannot overflow 32 bits on large tensors.
    """
    if USE_INT32_IDX:
        pid = tl.program_id(axis=0)
    else:
        pid = tl.program_id(axis=0).to(tl.int64)
    nc_stride = tl.num_programs(axis=1)
    NC = N * C
    nc_iter = tl.program_id(axis=1)
    # BUG FIX: removed an unconditional `pid = tl.program_id(axis=0)` that was
    # here; it clobbered the int64 widening chosen by USE_INT32_IDX above,
    # making the 64-bit indexing path dead code.
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Modulo wraps tail lanes back into [0, OL); out-of-range lanes just
    # redundantly rewrite valid positions instead of needing a mask.
    ol = idx % OL
    if SAME_L:
        il = ol
    else:
        # Nearest-neighbor source index: floor(ol * IL/OL), clamped to IL-1.
        il = tl.minimum(
            tl.math.floor(ol.to(tl.float32) * reciprocal_scale_l).to(tl.int32), IL - 1
        )

    offset_o = nc_iter * OL + ol
    offset_i = nc_iter * IL + il
    src_index_stride = nc_stride * IL
    dst_index_stride = nc_stride * OL

    # Grid-strided loop over the N*C planes: advance the base pointers by one
    # full stride of planes per iteration.
    while nc_iter < NC:
        data = tl.load(ptr_i + offset_i)
        tl.store(ptr_o + offset_o, data)
        ptr_i += src_index_stride
        ptr_o += dst_index_stride
        nc_iter += nc_stride
def upsample_nearest1d(
    input: torch.Tensor,
    output_size: Optional[Tuple[int]] = None,
    scales: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbor upsampling of a 3D (N, C, L) tensor along its last dim.

    Exactly one of `output_size` / `scales` must be provided (both is allowed;
    `output_size` then wins for the target length, while `scales` still drives
    the index mapping). Returns a new (N, C, OL) tensor on the same device and
    dtype as `input`.
    """
    logger.debug("GEMS UPSAMPLE NEAREST1D")
    assert input.device.type == device
    assert input.ndim == 3, "The ndim of input must be 3"
    assert (
        output_size is not None or scales is not None
    ), "Either output_size or scales should be defined."

    N, C, IL = input.shape
    if output_size is not None:
        OL = output_size[0]
    else:
        OL = int(input.shape[2] * scales)

    # Compute 1/scale in float32 (round-tripping through a float32 tensor) so
    # the nearest-index mapping matches PyTorch's behavior exactly.
    if scales is not None:
        reciprocal_scale_l = float(
            torch.tensor(1.0 / scales, dtype=torch.float32).item()
        )
    else:
        reciprocal_scale_l = float(
            (
                torch.tensor(IL, dtype=torch.float32)
                / torch.tensor(OL, dtype=torch.float32)
            ).item()
        )

    result = torch.empty((N, C, OL), device=input.device, dtype=input.dtype)

    # Axis 0 tiles the output length; axis 1 covers N*C planes in chunks of 4,
    # the kernel loops over the remaining planes with a stride.
    def grid(meta):
        return (
            triton.cdiv(OL, meta["BLOCK_SIZE"]),
            triton.cdiv(N * C, 4),
        )

    with torch_device_fn.device(input.device):
        upsample_nearest1d_kernel[grid](
            result,
            input,
            N,
            C,
            OL,
            IL,
            reciprocal_scale_l,
        )
    return result