Coverage for src/flag_gems/runtime/backend/_hygon/ops/isclose.py: 0%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

8 

9from .all import all 

10 

logger = logging.getLogger(__name__)

# ``isclose_func`` looks up these finiteness helpers as module globals when
# the Triton kernel is compiled. They come from the backend shim, which may
# not provide them on every setup.
try:
    _isfinited = tl_extra_shim.isfinited
    _finitef = tl_extra_shim.finitef
except Exception:
    # Best-effort: keep this module importable, but record why the helpers
    # are missing. They are only referenced inside the kernel's non-zero
    # tolerance branch, which would otherwise fail later at compile time
    # with no hint about the root cause.
    logger.debug(
        "tl_extra_shim does not provide isfinited/finitef; "
        "isclose with non-zero tolerances will be unavailable",
        exc_info=True,
    )

17 

18 

# Elementwise closeness kernel used by ``isclose``.
#
# ``pointwise_dynamic`` marks ``x`` and ``y`` as tensor operands and the other
# four arguments as non-tensor values (the ``is_tensor`` mask). The
# "ALWAYS_BOOL" promotion rule on args 0 and 1 presumably gives a boolean
# output tensor -- confirm against flag_gems.utils.pointwise_dynamic.
#
# Args:
#   x, y: input elements. The visible caller (``isclose``) rejects mismatched
#     dtypes, which is why ``cast_y`` below branches on ``x.dtype``.
#   rtol, atol: tolerances; ``isclose`` rejects negative values.
#   equal_nan: if true, two NaN elements count as close.
#   zero_tol: true when rtol == atol == 0. The tolerance test is then pruned
#     at compile time and closeness reduces to equality (plus the NaN rule).
# Returns:
#   (x == y), OR'd with the NaN rule when ``equal_nan`` is set, and (unless
#   ``zero_tol``) OR'd with isfinite(|x - y|) & (|x - y| <= atol + |rtol * y|).
@pointwise_dynamic(
    is_tensor=[True, True, False, False, False, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def isclose_func(
    x,
    y,
    rtol,
    atol,
    equal_nan: tl.constexpr,
    zero_tol: tl.constexpr,
):
    # Do the arithmetic in float32 unless the input is already float64.
    cast_x = x if x.dtype.is_fp64() else x.to(tl.float32)
    cast_y = y if x.dtype.is_fp64() else y.to(tl.float32)
    # NOTE(review): bf16 is compared after upcasting while every other dtype
    # compares raw values -- presumably a workaround for bf16 equality on
    # this backend; confirm.
    if x.dtype.is_bf16():
        close = cast_x == cast_y
    else:
        close = x == y
    if equal_nan:
        # v != v holds only for NaN, so this marks positions where both are NaN.
        close |= (cast_x != cast_x) & (cast_y != cast_y)
    if zero_tol == 0:
        # zero_tol is constexpr, so this branch is compiled out entirely when
        # both tolerances are zero.
        allowed = atol + tl.abs(rtol * cast_y)
        actual = tl.abs(cast_x - cast_y)
        # Require a finite difference: otherwise an infinite |x - y| could
        # pass just because ``allowed`` is also infinite (e.g. y == inf).
        # _isfinited / _finitef are the module-level helpers from tl_extra_shim.
        actual_finite = _isfinited(actual) if x.dtype.is_fp64() else _finitef(actual)
        close |= actual_finite.to(tl.int1) & (actual <= allowed)
    return close

46 

47 

def isclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> torch.Tensor:
    """Return an elementwise mask of where ``A`` and ``B`` are close.

    Non-bool inputs are evaluated by ``isclose_func``. An element is close when
    ``A == B`` or ``|A - B| <= atol + rtol * |B|`` with a finite difference.
    When ``equal_nan`` is true, two NaNs also count as close.

    Raises:
        RuntimeError: if the dtypes of ``A`` and ``B`` differ, either input is
            quantized, or ``rtol`` / ``atol`` is negative.
    """
    logger.debug("GEMS ISCLOSE")
    # Validate every argument before any dtype-specific shortcut. Previously
    # the bool path ran first, so a bool A with a non-bool B (or a negative
    # tolerance) was silently accepted.
    if A.dtype != B.dtype:
        raise RuntimeError("{} did not match {}".format(A.dtype, B.dtype))
    if A.is_quantized or B.is_quantized:
        raise RuntimeError("isclose is not supported for quantized inputs.")
    if rtol < 0:
        raise RuntimeError(
            "rtol must be greater than or equal to zero, but got {}".format(rtol)
        )
    if atol < 0:
        raise RuntimeError(
            "atol must be greater than or equal to zero, but got {}".format(atol)
        )
    # note: Int8 is not supported in isclose_func, because the result of
    # int8 == int8 is wrong in the triton jit function and needs to be fixed in
    # triton. The same is true for bool, so bool is handled on the host here.
    # NOTE(review): this shortcut ignores rtol/atol. Confirm that matches the
    # reference semantics for tolerances large enough to span |True - False|.
    if A.dtype == torch.bool:
        return A == B
    zero_tol = (rtol == 0) and (atol == 0)
    return isclose_func(A, B, rtol, atol, equal_nan, zero_tol)

74 

75 

def allclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> bool:
    """Report whether every element of ``A`` is close to its counterpart in ``B``.

    Builds the elementwise mask with ``isclose`` (same tolerance and NaN
    arguments), reduces it with the backend ``all`` op imported from ``.all``
    (not the Python builtin), and returns the result as a Python scalar.
    """
    logger.debug("GEMS ALLCLOSE")
    close_mask = isclose(A, B, rtol, atol, equal_nan)
    reduced = all(close_mask)
    return reduced.item()