Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/isclose.py: 0%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import tl_extra_shim 

9 

10from ..utils.pointwise_dynamic import pointwise_dynamic 

11from .all import all 

12 

# Child of the shared "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Finiteness helpers from the backend Triton shim. `isclose_func` uses
# `_isfinited` for fp64 values and `_finitef` for every other dtype.
_isfinited, _finitef = tl_extra_shim.isfinited, tl_extra_shim.finitef

16 

17 

# Element-wise closeness test for one pair of values:
#     close = (x == y)
#             | (equal_nan and both are NaN)
#             | (|x - y| is finite and |x - y| <= atol + |rtol * y|)
# `pointwise_dynamic` lifts it to a tensor op.
# - Per `is_tensor`, only x and y are tensors. rtol/atol are passed as
#   non-tensor scalars.
# - equal_nan/zero_tol are `tl.constexpr`, so the `if` statements on them are
#   resolved when the kernel is specialized.
# - The output dtype comes from the "ALWAYS_BOOL" promotion rule applied to
#   args 0 and 1. Presumably this is a boolean mask; confirm against
#   pointwise_dynamic.
@pointwise_dynamic(
    is_tensor=[True, True, False, False, False, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def isclose_func(
    x,
    y,
    rtol,
    atol,
    equal_nan: tl.constexpr,
    zero_tol: tl.constexpr,
):
    # fp64 stays at full precision; every other dtype is computed in float32.
    # Both lines test only x.dtype. This relies on the host wrapper `isclose`
    # rejecting inputs whose dtypes differ.
    cast_x = x if x.dtype.is_fp64() else x.to(tl.float32)
    cast_y = y if x.dtype.is_fp64() else y.to(tl.float32)
    # Exact-equality part. bf16 is compared after the float32 upcast; other
    # dtypes are compared on their raw values.
    # NOTE(review): presumably raw bf16 `==` is unreliable on this backend; confirm.
    if x.dtype.is_bf16():
        close = cast_x == cast_y
    else:
        close = x == y
    if equal_nan:
        # A float is unequal to itself only when it is NaN, so this marks the
        # positions where both operands are NaN.
        close |= (cast_x != cast_x) & (cast_y != cast_y)
    if not zero_tol:
        # With rtol == atol == 0 the test reduces to the equality check above,
        # so this tolerance path is compiled out in that case.
        allowed = atol + tl.abs(rtol * cast_y)
        actual = tl.abs(cast_x - cast_y)
        # Only a finite error may count as "close". An inf/NaN difference
        # never satisfies the tolerance test.
        actual_finite = _isfinited(actual) if x.dtype.is_fp64() else _finitef(actual)
        # The shim result is cast to int1 before being combined with the
        # boolean comparison. Presumably the shim does not return int1
        # already; confirm.
        close |= actual_finite.to(tl.int1) & (actual <= allowed)
    return close

45 

46 

def isclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> torch.Tensor:
    """Element-wise test ``|A - B| <= atol + |rtol * B|``, mirroring ``torch.isclose``.

    Args:
        A: First input tensor.
        B: Second input tensor. It must have the same dtype as ``A``.
        rtol: Relative tolerance. Must be >= 0.
        atol: Absolute tolerance. Must be >= 0.
        equal_nan: If True, NaNs at the same position in ``A`` and ``B``
            count as close.

    Returns:
        For bool inputs, ``A == B``. Otherwise, the result of ``isclose_func``,
        whose output dtype follows its "ALWAYS_BOOL" promotion rule
        (presumably a bool tensor).

    Raises:
        RuntimeError: If the dtypes differ, either input is quantized, or
            ``rtol``/``atol`` is negative. All checks run before any shortcut,
            so bool inputs are validated too, as in ``torch.isclose``.
    """
    logger.debug("GEMS ISCLOSE")
    # Validate before touching global state or taking the bool shortcut.
    # Previously a bool `A` returned early, which silently accepted
    # mismatched dtypes and negative tolerances.
    if A.dtype != B.dtype:
        raise RuntimeError("{} did not match {}".format(A.dtype, B.dtype))
    if A.is_quantized or B.is_quantized:
        raise RuntimeError("isclose is not supported for quantized inputs.")
    if rtol < 0:
        raise RuntimeError(
            "rtol must be greater than or equal to zero, but got {}".format(rtol)
        )
    if atol < 0:
        raise RuntimeError(
            "atol must be greater than or equal to zero, but got {}".format(atol)
        )
    # NOTE(review): XPU_cmp_nan is a process-wide environment variable. It is
    # set here (equal_nan=False) or removed (equal_nan=True) and never
    # restored, so the setting stays in effect after this call returns.
    # Presumably it selects NaN comparison semantics in the XPU toolchain;
    # confirm before scoping or restoring it.
    if not equal_nan:
        os.environ["XPU_cmp_nan"] = "1"
    else:
        os.environ.pop("XPU_cmp_nan", None)
    # note: bool (and, per the original note, int8) comparisons with `==`
    # inside a triton jit function give wrong results and need a triton fix,
    # so bool inputs bypass isclose_func.
    # NOTE(review): int8 is not special-cased here; confirm whether it should be.
    # NOTE(review): torch.isclose upcasts bool inputs and still applies the
    # tolerances, so with atol (or atol + rtol) >= 1 it can report True vs
    # False as close. This shortcut ignores rtol/atol; confirm whether that
    # matters.
    if A.dtype == torch.bool:
        return A == B
    # Zero tolerances reduce closeness to exact equality. The kernel uses
    # this compile-time flag to skip the tolerance path.
    zero_tol = (rtol == 0) and (atol == 0)
    return isclose_func(A, B, rtol, atol, equal_nan, zero_tol)

78 

79 

def allclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> bool:
    """Return whether every element of ``A`` is close to its counterpart in ``B``.

    This computes the element-wise mask with ``isclose`` (same arguments, same
    validation errors). It then reduces the mask with this module's ``all``,
    which is imported from ``.all`` and shadows the builtin. Finally it
    converts the reduced result to a Python scalar with ``.item()``.
    """
    logger.debug("GEMS ALLCLOSE")
    closeness_mask = isclose(A, B, rtol, atol, equal_nan)
    every_close = all(closeness_mask)
    return every_close.item()