Coverage for src/flag_gems/ops/upsample_nearest3d.py: 49%
63 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
# Rebind the imported runtime `device` object to its `name` attribute; the
# operator below compares `input.device.type` against this value.
device = device.name
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest3d"),
    key=["N", "C", "OD", "OH", "OW"],
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest3d"))
@triton.jit
def upsample_nearest3d_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OD,
    OH,
    OW,
    ID,
    IH,
    IW,
    reciprocal_scale_d,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_D: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
    USE_INT32_IDX: tl.constexpr,
):
    """Nearest-neighbour 3D upsampling: copy one input element per output element.

    Grid axis 0 tiles the flattened ``OD * OH * OW`` output positions of one
    (n, c) slice in chunks of ``BLOCK_SIZE``.  Grid axis 1 walks the ``N * C``
    slices: each program starts at its own axis-1 id and advances by the number
    of axis-1 programs until all slices are covered.

    Both tensors are addressed with flat offsets of the form
    ``slice_index * (D * H * W) + (d * H * W + h * W + w)``, i.e. the kernel
    treats ``ptr_o`` and ``ptr_i`` as densely packed (N*C, D, H, W) buffers.

    Args:
        ptr_o: Pointer to the output buffer.
        ptr_i: Pointer to the input buffer.
        N, C: Leading dimensions; only their product is used.
        OD, OH, OW: Output depth / height / width.
        ID, IH, IW: Input depth / height / width.
        reciprocal_scale_d/h/w: Multipliers mapping an output index to a
            source index (``src = floor(dst * reciprocal_scale)``).
        BLOCK_SIZE: Number of output positions handled per program on axis 0.
        SAME_D, SAME_H, SAME_W: When true, the source index along that axis is
            taken to be the output index unchanged (no scaling).  Presumably
            set by the heuristics config when input and output sizes match --
            confirm against ``runtime.get_heuristic_config``.
        USE_INT32_IDX: When false, index arithmetic is carried out in int64.
    """
    # Compile-time branch.  In the wide path *all* program-derived indices are
    # widened, not just the axis-0 id: the per-slice base offset
    # `nc_iter * OD * OH * OW` and the per-iteration strides would otherwise be
    # evaluated in int32 and could wrap before being added to the int64 `idx`.
    if USE_INT32_IDX:
        pid0 = tl.program_id(axis=0)
        nc_iter = tl.program_id(axis=1)
        nc_stride = tl.num_programs(axis=1)
    else:
        pid0 = tl.program_id(axis=0).to(tl.int64)
        nc_iter = tl.program_id(axis=1).to(tl.int64)
        nc_stride = tl.num_programs(axis=1).to(tl.int64)

    NC = N * C

    # Flattened output position inside one (n, c) slice.
    idx = pid0 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    total_spatial_size = OD * OH * OW
    mask = idx < total_spatial_size

    # Decompose the flat position into (od, oh, ow).
    ow = idx % OW
    oh = (idx // OW) % OH
    od = idx // (OW * OH)

    # Source index per axis: min(floor(dst * reciprocal_scale), size - 1).
    if SAME_D:
        i_d = od
    else:
        i_d = tl.minimum(
            tl.math.floor(od.to(tl.float32) * reciprocal_scale_d).to(tl.int32), ID - 1
        )

    if SAME_H:
        i_h = oh
    else:
        i_h = tl.minimum(
            tl.math.floor(oh.to(tl.float32) * reciprocal_scale_h).to(tl.int32), IH - 1
        )

    if SAME_W:
        i_w = ow
    else:
        i_w = tl.minimum(
            tl.math.floor(ow.to(tl.float32) * reciprocal_scale_w).to(tl.int32), IW - 1
        )

    # Products are written left-to-right starting from `nc_iter` / `nc_stride`
    # so that, in the int64 path, every intermediate is already 64-bit.  In the
    # int32 path the result is identical to the parenthesised form.
    offset_o = nc_iter * OD * OH * OW + idx
    offset_i = nc_iter * ID * IH * IW + (i_d * IH * IW + i_h * IW + i_w)

    src_nc_stride = nc_stride * ID * IH * IW
    dst_nc_stride = nc_stride * OD * OH * OW

    # Same spatial tile for every (n, c) slice assigned to this program.
    while nc_iter < NC:
        data = tl.load(ptr_i + offset_i, mask=mask)
        tl.store(ptr_o + offset_o, data, mask=mask)

        offset_i += src_nc_stride
        offset_o += dst_nc_stride
        nc_iter += nc_stride
def upsample_nearest3d(
    input: torch.Tensor,
    output_size: Tuple[int, int, int],
    scales_d: Optional[float] = None,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Upsample a 5-D tensor with nearest-neighbour interpolation.

    Args:
        input: Tensor of shape (N, C, ID, IH, IW) located on the device this
            module was configured for.
        output_size: Target spatial size ``(OD, OH, OW)``.
        scales_d, scales_h, scales_w: Optional per-axis scale factors.  When a
            positive value is given, its reciprocal is used to map output
            indices to source indices; otherwise (``None`` or non-positive)
            the ratio ``input_size / output_size`` is used.

    Returns:
        A newly allocated contiguous tensor of shape (N, C, OD, OH, OW) with
        the dtype and device of ``input``.

    Raises:
        AssertionError: If ``input`` is on a different device type, is not
            5-D, or has a zero-sized spatial dimension while the output is
            non-empty.
    """
    logger.debug("GEMS UPSAMPLE NEAREST3D")
    assert input.device.type == device
    assert input.ndim == 5, "The ndim of input must be 5"

    OD, OH, OW = output_size
    N, C, ID, IH, IW = input.shape

    output = torch.empty((N, C, OD, OH, OW), device=input.device, dtype=input.dtype)
    # Nothing to compute; also avoids launching a kernel with a zero-sized grid.
    if output.numel() == 0:
        return output

    # With a non-empty output every source index is clamped to `size - 1`,
    # which would be -1 (out of bounds) for an empty input axis.
    assert (
        ID > 0 and IH > 0 and IW > 0
    ), "Input spatial dimensions must be greater than 0"

    # The kernel addresses the input with flat contiguous offsets, so strided
    # or channels-last inputs must be packed first (no-op when already packed).
    input = input.contiguous()

    def calculate_scale(in_sz, out_sz, s):
        # Results are rounded through float32 so the value matches single
        # precision arithmetic.  A supplied scale is honoured only when it is
        # positive; zero/negative placeholders fall back to the size ratio
        # instead of dividing by zero or producing negative source indices.
        if s is not None and s > 0:
            return float(torch.tensor(1.0 / s, dtype=torch.float32).item())
        return float(
            (
                torch.tensor(in_sz, dtype=torch.float32)
                / torch.tensor(out_sz, dtype=torch.float32)
            ).item()
        )

    reciprocal_scale_d = calculate_scale(ID, OD, scales_d)
    reciprocal_scale_h = calculate_scale(IH, OH, scales_h)
    reciprocal_scale_w = calculate_scale(IW, OW, scales_w)

    total_threads = OD * OH * OW
    # Aim for roughly four (n, c) slices per axis-1 program, but never exceed
    # the 65535 limit CUDA places on the grid's y dimension.  The kernel loops
    # over slices with a stride equal to the axis-1 program count, so a capped
    # grid still covers every slice.
    max_grid_y = 65535
    nc_programs = min(triton.cdiv(N * C, 4), max_grid_y)
    grid = lambda meta: (
        triton.cdiv(total_threads, meta["BLOCK_SIZE"]),
        nc_programs,
    )

    with torch_device_fn.device(input.device):
        upsample_nearest3d_kernel[grid](
            output,
            input,
            N,
            C,
            OD,
            OH,
            OW,
            ID,
            IH,
            IW,
            reciprocal_scale_d,
            reciprocal_scale_h,
            reciprocal_scale_w,
        )
    return output