Coverage for src/flag_gems/runtime/backend/_cambricon/ops/resolve_conj.py: 0%
17 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the "flag_gems" logger, suffixed with this module's
# ``__name__`` (leading dots stripped). Used for the debug trace in resolve_conj.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def conj_func(x):
    # Elementwise XOR with -(1 << 63), which as a two's-complement int64 is
    # 0x8000000000000000: only bit 63 is toggled, every other bit is kept.
    # resolve_conj calls this on an int64 reinterpretation of (real, imag)
    # float32 pairs; bit 63 is then the sign bit of the imaginary part,
    # assuming a little-endian layout (NOTE(review): confirm for the target
    # device), so the result is the complex conjugate.
    return x ^ -(1 << 63)
def resolve_conj(A: torch.Tensor):
    """Return ``A`` with any pending (lazy) conjugation physically applied.

    Cambricon backend implementation of ``torch.resolve_conj``.

    Args:
        A: Input tensor. When its conjugate bit is set it must be
            ``torch.cfloat``.

    Returns:
        ``A`` itself when ``A.is_conj()`` is False. Otherwise a ``torch.cfloat``
        tensor built from ``conj_func``'s output, holding the conjugated
        values.

    Raises:
        AssertionError: if ``A.is_conj()`` is True and ``A.dtype`` is not
            ``torch.cfloat``.
    """
    logger.debug("GEMS_CAMBRICON RESOLVE_CONJ")
    if not A.is_conj():
        # Nothing pending: hand back the input unchanged.
        return A

    if A.dtype != torch.cfloat:
        # Explicit raise instead of ``assert``: an assert is stripped under
        # ``python -O``, which would let other complex dtypes reach the
        # cfloat-specific bit trick below and silently return garbage.
        # AssertionError is kept so existing callers/tests see the same type.
        raise AssertionError(
            "The `resolve_conj` operation in FlagGems currently only supports the `torch.cfloat` type"
        )

    # ``A.conj()`` clears the lazy conj bit so ``view_as_real`` exposes the
    # stored values. Each (real, imag) float32 pair is then reinterpreted as
    # one int64 so the kernel can flip the imaginary sign bit in one op.
    typed_view = torch.view_as_real(A.conj()).view(torch.int64)
    # conj_func XORs bit 63 of every element (the imaginary part's sign bit on
    # a little-endian layout -- NOTE(review): confirm for the target device).
    out = conj_func(typed_view)
    # Reinterpret the int64 results back as float32 pairs, then as complex.
    return torch.view_as_complex(out.view(torch.float))