Coverage for src/flag_gems/runtime/backend/_ascend/ops/resolve_neg.py: 0%
33 statements
Generated by coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
7import triton.runtime.driver as driver
# Module-level logger named after this module's basename (e.g.
# "flag_gems.runtime._ascend.ops.resolve_neg") so log output can be
# filtered per backend op.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
def get_npu_properties():
    """Return the device-property dict of the currently active NPU.

    Looks the properties up through Triton's active driver utilities for
    the device that ``torch.npu`` reports as current.
    """
    utils = driver.active.utils
    return utils.get_device_properties(torch.npu.current_device())
# Element-wise negation kernel: out[i] = -inp[i] for all i < data_len.
# Each program instance owns one contiguous BLOCK_SIZE-sized slice of the
# flattened data and walks it in chunks of at most MAX_DATA_SIZE elements
# (presumably sized to fit the NPU vector core's local buffer — TODO confirm).
@triton.jit
def resolve_neg_kernel(
    inp,
    out,
    data_len: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    MAX_DATA_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    # Number of MAX_DATA_SIZE chunks needed to cover this program's block.
    iter_num = tl.cdiv(BLOCK_SIZE, MAX_DATA_SIZE)

    for idx in tl.range(0, iter_num):
        # Flat offsets of the current chunk within this program's block.
        offsets = pid * BLOCK_SIZE + idx * MAX_DATA_SIZE + tl.arange(0, MAX_DATA_SIZE)
        mask = offsets < data_len  # guard the tail chunk past the end of the data
        inp_val = tl.load(inp + offsets, mask=mask)
        out_val = -inp_val
        tl.store(out + offsets, out_val, mask=mask)
def resolve_neg(A: torch.Tensor):
    """Materialize the negation of a tensor whose negative bit is set.

    Mirrors ``torch.Tensor.resolve_neg``: if ``A.is_neg()`` is true, return
    a new tensor of the same shape/dtype/device with the negation physically
    applied via a Triton kernel; otherwise return ``A`` unchanged.

    NOTE(review): the kernel indexes A's storage linearly, which looks like
    it assumes A is contiguous — confirm callers guarantee this.
    """
    logger.debug("GEMS_ASCEND RESOLVE_NEG")
    if not A.is_neg():
        # Negative bit not set: nothing to resolve.
        return A

    data_len = A.numel()
    out = torch.empty(A.numel(), dtype=A.dtype, device=A.device)

    # Guard the empty tensor: with data_len == 0, BLOCK_SIZE would be 0 and
    # the grid computation (cdiv by BLOCK_SIZE) would divide by zero.
    if data_len == 0:
        return out.view(A.shape)

    CORE_NUM = get_npu_properties()["num_vectorcore"]
    # Split the flat data evenly across the vector cores; each program
    # instance processes one BLOCK_SIZE slice.
    BLOCK_SIZE = math.ceil(data_len / CORE_NUM)
    MAX_DATA_SIZE = 20 * 1024  # per-iteration chunk size inside the kernel

    grid = lambda meta: (triton.cdiv(data_len, meta["BLOCK_SIZE"]),)
    resolve_neg_kernel[grid](
        A,
        out,
        data_len,
        BLOCK_SIZE,
        MAX_DATA_SIZE,
    )
    return out.view(A.shape)