Coverage for src/flag_gems/runtime/backend/_hygon/ops/upsample_nearest2d.py: 0%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
# Rebind ``device`` from the imported runtime device object to its name string;
# the rest of this module compares it against ``tensor.device.type``.
device = device.name
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@triton.autotune(
    configs=runtime.get_tuned_config("upsample_nearest2d"), key=["N", "C", "OH", "OW"]
)
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,
    ptr_i,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
    USE_INT32_IDX: tl.constexpr,
):
    """Nearest-neighbour 2D upsampling over flat (N*C, H, W) buffers.

    Grid axis 0 tiles the OH*OW output plane in chunks of BLOCK_SIZE; each
    program on grid axis 1 starts at plane ``pid_y`` and walks the N*C planes
    in steps of the axis-1 grid size, reusing the same source/destination
    pixel mapping for every plane it visits.

    Both pointers are addressed as dense row-major buffers: plane ``k`` of the
    input starts at ``k * IH * IW`` and plane ``k`` of the output at
    ``k * OH * OW``.

    Args:
        ptr_o: Pointer to the output buffer (N*C planes of OH*OW elements).
        ptr_i: Pointer to the input buffer (N*C planes of IH*IW elements).
        N, C: Batch and channel counts; only their product is used.
        OH, OW: Output plane height and width.
        IH, IW: Input plane height and width.
        reciprocal_scale_h, reciprocal_scale_w: Factors that map an output
            row/column index to an input row/column index.
        BLOCK_SIZE: Number of output pixels handled per program on axis 0.
        SAME_H, SAME_W: When set, the row/column index is copied unchanged
            instead of being scaled (presumably chosen by the heuristics
            config when the sizes match -- confirm against that config).
        USE_INT32_IDX: When false, the in-plane index is computed in int64.
    """
    if USE_INT32_IDX:
        pid_x = tl.program_id(axis=0)
    else:
        pid_x = tl.program_id(axis=0).to(tl.int64)

    # Flat index of each output pixel inside one OH x OW plane.
    idx = pid_x * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row = idx // OW
    ow = idx % OW
    oh = row % OH

    if SAME_H:
        ih = oh
    else:
        # Truncate the scaled index and clamp it to the last input row.
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)

    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)

    # Equivalent to ``idx < OH * OW`` for non-negative idx, but avoids forming
    # the OH * OW product, which can wrap when both operands are 32-bit.
    mask = row < OH

    pid_y = tl.program_id(axis=1)
    num_pid_y = tl.num_programs(axis=1)

    nc_iter = pid_y
    total_nc = N * C

    # Widen to int64 *before* multiplying by the plane sizes: casting the
    # finished product would only widen an already-wrapped 32-bit value.
    pid_y_i64 = pid_y.to(tl.int64)
    num_pid_y_i64 = num_pid_y.to(tl.int64)

    # Pointer advance between two consecutive planes handled by this program.
    src_stride_step = num_pid_y_i64 * IH * IW
    dst_stride_step = num_pid_y_i64 * OH * OW

    current_ptr_i = ptr_i + pid_y_i64 * IH * IW + (ih * IW + iw)
    current_ptr_o = ptr_o + pid_y_i64 * OH * OW + (oh * OW + ow)

    while nc_iter < total_nc:
        val = tl.load(current_ptr_i, mask=mask)
        tl.store(current_ptr_o, val, mask=mask)
        nc_iter += num_pid_y
        current_ptr_i += src_stride_step
        current_ptr_o += dst_stride_step
def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Upsample a 4-D tensor to ``output_size`` with nearest-neighbour lookup.

    Args:
        input: Tensor of shape (N, C, IH, IW) on this backend's device.
        output_size: Two-element sequence ``(OH, OW)``.
        scales_h: Optional height scale factor. When given and positive, the
            source row is derived from ``1 / scales_h``; otherwise from
            ``IH / OH``.
        scales_w: Optional width scale factor, handled like ``scales_h``.

    Returns:
        A newly allocated tensor of shape (N, C, OH, OW) with the dtype and
        device of ``input``.

    Raises:
        AssertionError: If ``input`` is on another device type, is not 4-D,
            or ``output_size`` does not have exactly two entries.
        ZeroDivisionError: If an output dimension is 0 and the corresponding
            scale has to be derived from the sizes.
    """
    logger.debug("GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"

    OH, OW = output_size
    N, C, IH, IW = input.shape

    # Non-positive scales are treated as "not provided", mirroring ATen's
    # compute_scales_value; this also avoids dividing by zero and producing
    # negative source indices inside the kernel.
    if scales_h is not None and scales_h > 0:
        reciprocal_scale_h = 1 / scales_h
    else:
        reciprocal_scale_h = IH / OH
    if scales_w is not None and scales_w > 0:
        reciprocal_scale_w = 1 / scales_w
    else:
        reciprocal_scale_w = IW / OW

    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)

    # Nothing to copy: skip launching a kernel on an empty grid.
    if output.numel() == 0:
        return output

    # The kernel addresses the input as a dense row-major (N*C, IH, IW)
    # buffer, so strided / channels_last inputs must be materialised first.
    # This is a no-op for tensors that are already contiguous.
    input = input.contiguous()

    total_threads = OH * OW

    # 32-bit in-plane indexing is only used while the whole output fits in it.
    use_int32 = (N * C * OH * OW) < 2**31

    # Axis 0 tiles one output plane; axis 1 uses a quarter as many programs
    # as there are N*C planes, and the kernel loops over the remainder.
    grid = lambda META: (
        triton.cdiv(total_threads, META["BLOCK_SIZE"]),
        triton.cdiv(N * C, 4),
    )

    with torch_device_fn.device(input.device):
        upsample_nearest2d_kernel[grid](
            output,
            input,
            N,
            C,
            OH,
            OW,
            IH,
            IW,
            reciprocal_scale_h,
            reciprocal_scale_w,
            USE_INT32_IDX=use_int32,
        )
    return output