Coverage for src/flag_gems/runtime/backend/_ascend/ops/upsample_nearest2d.py: 0%
49 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
13device = device.name
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
        triton.Config({"BLOCK_SIZE": 2048}),
        triton.Config({"BLOCK_SIZE": 4096}),
    ],
    key=["N", "C", "OH", "OW"],
)
# SAME_H / SAME_W are filled in by the project heuristic config; presumably
# they flag OH == IH (resp. OW == IW) so the scale math can be skipped —
# TODO confirm against runtime.get_heuristic_config("upsample_nearest2d").
@triton.heuristics(runtime.get_heuristic_config("upsample_nearest2d"))
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,  # output tensor pointer (N, C, OH, OW), contiguous
    ptr_i,  # input tensor pointer (N, C, IH, IW), contiguous
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,  # IH/OH (or 1/scales_h): maps an output row to an input row
    reciprocal_scale_w,  # IW/OW (or 1/scales_w): maps an output col to an input col
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,  # True when no resampling is needed along H
    SAME_W: tl.constexpr,  # True when no resampling is needed along W
):
    # One program handles BLOCK_SIZE flattened output elements.
    pid = tle.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Decompose the flat index into (n, c, oh, ow). Each component is reduced
    # modulo its extent, so lanes whose idx exceeds N*C*OH*OW wrap around and
    # redundantly rewrite valid elements instead of accessing out of bounds —
    # which is why no load/store mask is needed below.
    ow = idx % OW
    oh = idx // OW % OH
    c = idx // OW // OH % C
    n = idx // OW // OH // C % N
    if SAME_H:
        ih = oh
    else:
        # tl.floor() cannot be found in 2.3.1, using int trunc
        # (oh * reciprocal_scale_h >= 0, so truncation == floor here)
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)
    # Row-major (NCHW) flat offsets for the gather/scatter pair.
    offset_o = ((n * C + c) * OH + oh) * OW + ow
    offset_i = ((n * C + c) * IH + ih) * IW + iw
    data = tl.load(ptr_i + offset_i)
    tl.store(ptr_o + offset_o, data)
def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbor 2D upsampling of a contiguous NCHW tensor.

    Allocates an (N, C, OH, OW) output and fills it with the Triton kernel
    above. When an explicit scale is given its reciprocal is used directly;
    otherwise the reciprocal is derived from the input/output extents.
    """
    logger.debug("GEMS_ASCEND UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape
    # Kernel consumes reciprocal scales (output coord -> input coord).
    inv_h = (1 / scales_h) if scales_h is not None else IH / OH
    inv_w = (1 / scales_w) if scales_w is not None else IW / OW
    out = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    numel = N * C * OH * OW
    # One program per BLOCK_SIZE output elements; BLOCK_SIZE is autotuned.
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(input.device):
        upsample_nearest2d_kernel[grid](
            out, input, N, C, OH, OW, IH, IW, inv_h, inv_w
        )
    return out