Coverage for src/flag_gems/runtime/backend/_cambricon/ops/isfinite.py: 0%

24 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from ..utils.pointwise_dynamic import pointwise_dynamic 

8 

9logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

10 

11 

@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isfinite_func(x):
    """Element-wise finiteness test performed on the raw bit pattern of ``x``.

    For each supported floating dtype the value is bit-cast to a signed
    integer of the same width, the sign bit is masked off, and the remaining
    magnitude bits are compared against the bit pattern of +inf for that
    format. Anything strictly below that pattern is finite; +/-inf and NaN
    compare greater-or-equal and therefore yield False.

    Only fp64, fp32, fp16 and bf16 have a branch here; for any other dtype no
    branch is taken and nothing is returned.
    """
    if x.dtype.is_fp64():
        # 0x7FF0000000000000 is the float64 encoding of +inf.
        magnitude = x.to(tl.int64, bitcast=True) & 0x7FFFFFFFFFFFFFFF
        return magnitude < 0x7FF0000000000000
    elif x.dtype.is_fp32():
        # 0x7F800000 is the float32 encoding of +inf.
        magnitude = x.to(tl.int32, bitcast=True) & 0x7FFFFFFF
        return magnitude < 0x7F800000
    elif x.dtype.is_fp16():
        # 0x7C00 is the float16 encoding of +inf.
        magnitude = x.to(tl.int16, bitcast=True) & 0x7FFF
        return magnitude < 0x7C00
    elif x.dtype.is_bf16():
        # 0x7F80 is the bfloat16 encoding of +inf.
        magnitude = x.to(tl.int16, bitcast=True) & 0x7FFF
        return magnitude < 0x7F80

23 

24 

def isfinite(
    A: torch.Tensor,
) -> torch.Tensor:
    """Return a boolean tensor marking which elements of ``A`` are finite.

    Args:
        A: Input tensor.

    Returns:
        For float32 / float16 / bfloat16 inputs, the result of
        ``isfinite_func(A)``. For inputs where ``A.is_floating_point()`` is
        False, a ``torch.bool`` tensor of ``A.shape`` on ``A.device`` filled
        with ``True``.

    Raises:
        AssertionError: If ``A`` is a floating-point tensor whose dtype is not
            float32, float16 or bfloat16.
    """
    logger.debug("GEMS_CAMBRICON ISFINITE")

    if not A.is_floating_point():
        # NOTE(review): complex tensors also land here because
        # is_floating_point() is False for them -- confirm that reporting
        # every element as finite is intended for that case.
        return torch.full(A.shape, True, dtype=torch.bool, device=A.device)

    legal_dtype = [torch.float32, torch.float16, torch.bfloat16]
    if A.dtype not in legal_dtype:
        # Raised explicitly instead of via an `assert` statement so the check
        # is not stripped when Python runs with -O. The exception type and
        # message are the same ones the assert produced.
        raise AssertionError(
            f"isfinite input float dtype should in {str(legal_dtype)}, get {str(A.dtype)}"
        )
    return isfinite_func(A)