Coverage for src/flag_gems/runtime/backend/_cambricon/ops/isfinite.py: 0%
24 statements
coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from ..utils.pointwise_dynamic import pointwise_dynamic
# Module-level logger, registered as a child of the package-wide "flag_gems"
# logger so it lives in the same logger hierarchy as the rest of FlagGems.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isfinite_func(x):
    """Elementwise finiteness test performed on the raw IEEE-754 bit pattern.

    ``x`` is reinterpreted (bitcast) as a signed integer of the same width and
    the sign bit is masked off. The remaining magnitude bits are compared
    against the bit pattern of +inf for that format: anything strictly below
    it is finite, while +/-inf and every NaN payload compare greater or equal.

    Only fp32, fp16, bf16 and fp64 inputs are handled; any other dtype matches
    none of the branches below, so callers must only pass these four.
    """
    if x.dtype.is_fp32():
        # fp32 +inf: exponent 0xFF, mantissa 0 -> 0x7F800000
        magnitude32 = x.to(tl.int32, bitcast=True) & 0x7FFFFFFF
        return magnitude32 < 0x7F800000
    elif x.dtype.is_fp16():
        # fp16 +inf: 5-bit exponent all ones, mantissa 0 -> 0x7C00
        magnitude16 = x.to(tl.int16, bitcast=True) & 0x7FFF
        return magnitude16 < 0x7C00
    elif x.dtype.is_bf16():
        # bf16 +inf: 8-bit exponent all ones, 7-bit mantissa 0 -> 0x7F80
        magnitude_bf16 = x.to(tl.int16, bitcast=True) & 0x7FFF
        return magnitude_bf16 < 0x7F80
    elif x.dtype.is_fp64():
        # fp64 +inf: 11-bit exponent all ones, mantissa 0 -> 0x7FF0000000000000
        magnitude64 = x.to(tl.int64, bitcast=True) & 0x7FFFFFFFFFFFFFFF
        return magnitude64 < 0x7FF0000000000000
def _isfinite_float(A: torch.Tensor) -> torch.Tensor:
    """Run the Triton finiteness kernel on a real floating tensor.

    Raises:
        AssertionError: if ``A.dtype`` is not float32, float16 or bfloat16,
            the only floating dtypes this backend's wrapper admits.
    """
    legal_dtype = [torch.float32, torch.float16, torch.bfloat16]
    if A.dtype not in legal_dtype:
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; the exception type and message are
        # unchanged for existing callers.
        raise AssertionError(
            f"isfinite input float dtype should in {str(legal_dtype)}, get {str(A.dtype)}"
        )
    return isfinite_func(A)


def isfinite(
    A: torch.Tensor,
) -> torch.Tensor:
    """Return a bool tensor marking which elements of ``A`` are finite.

    Follows ``torch.isfinite`` semantics:

    * real floating tensors are checked elementwise by the Triton kernel
      ``isfinite_func`` (float32 / float16 / bfloat16 only on this backend);
    * complex tensors are finite where BOTH the real and the imaginary part
      are finite, so each part is checked and the results are AND-ed;
    * all other dtypes (integers, bool) cannot represent inf/NaN, so the
      result is all ``True``.

    Args:
        A: Input tensor.

    Returns:
        A ``torch.bool`` tensor with the same shape as ``A``.

    Raises:
        AssertionError: if ``A`` is floating (or complex with floating
            components) of a dtype other than float32/float16/bfloat16.
    """
    logger.debug("GEMS_CAMBRICON ISFINITE")
    if A.is_complex():
        # Complex tensors are not `is_floating_point()`; without this branch
        # they would fall through to the all-True result even when they
        # contain inf/NaN components.
        return torch.logical_and(_isfinite_float(A.real), _isfinite_float(A.imag))
    if A.is_floating_point():
        return _isfinite_float(A)
    return torch.full(A.shape, True, dtype=torch.bool, device=A.device)