Coverage for src/flag_gems/ops/isclose.py: 56%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.all import all 

8from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

9 

logger = logging.getLogger(__name__)

# Finite-check helpers used by the tolerance branch of ``isclose_func``. They
# are looked up on the backend shim; if the shim does not provide them the
# module still imports, but ``isclose_func`` presumably fails to compile as
# soon as its ``not zero_tol`` branch is used.
try:
    _isfinited = tl_extra_shim.isfinited
    _finitef = tl_extra_shim.finitef
except Exception:
    # Deliberately best-effort (keep the module importable), but leave a trace
    # so a later undefined-name error for these helpers can be explained.
    logger.debug(
        "tl_extra_shim does not provide isfinited/finitef; "
        "isclose with non-zero tolerances will not be available",
        exc_info=True,
    )

16 

17 

@pointwise_dynamic(
    # x and y are tensor arguments; rtol, atol, equal_nan and zero_tol are
    # passed through as non-tensor (scalar) arguments.
    is_tensor=[True, True, False, False, False, False],
    # Output dtype derived from arguments 0 and 1 under the "ALWAYS_BOOL"
    # rule -- presumably a boolean mask regardless of input dtype.
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def isclose_func(
    x,
    y,
    rtol,
    atol,
    equal_nan: tl.constexpr,
    zero_tol: tl.constexpr,
):
    """Elementwise closeness kernel backing ``isclose``.

    An element is close when ``x == y``; or, if ``equal_nan``, when both
    elements are NaN; or, unless ``zero_tol``, when ``|x - y|`` is finite and
    ``|x - y| <= atol + |rtol * y|``.

    Args:
        x, y: Input elements. Tests on ``x.dtype`` are resolved when the
            kernel is compiled. ``y`` is cast according to ``x.dtype``, so
            both inputs are assumed to share a dtype (the ``isclose`` wrapper
            raises on a dtype mismatch).
        rtol, atol: Relative / absolute tolerances.
            NOTE(review): Triton typically specializes Python float scalars
            as fp32, which may round rtol/atol for fp64 inputs -- confirm.
        equal_nan: Compile-time flag; NaN at the same position counts as close.
        zero_tol: Compile-time flag; when True the tolerance test is skipped
            and only equality (plus optional NaN matching) is evaluated.

    Returns:
        The elementwise closeness mask.
    """
    # fp64 stays as-is; every other dtype (fp16, bf16, integers) is computed
    # in float32. NOTE(review): integers also go through float32, so large
    # int64 values may round together in the tolerance path -- confirm
    # whether an fp64 path is needed for int64.
    cast_x = x if x.dtype.is_fp64() else x.to(tl.float32)
    cast_y = y if x.dtype.is_fp64() else y.to(tl.float32)
    # bf16 equality is compared on the float32 copies; all other dtypes are
    # compared natively.
    if x.dtype.is_bf16():
        close = cast_x == cast_y
    else:
        close = x == y
    if equal_nan:
        # NaN is the only value that is not equal to itself.
        close |= (cast_x != cast_x) & (cast_y != cast_y)
    if not zero_tol:
        allowed = atol + tl.abs(rtol * cast_y)
        actual = tl.abs(cast_x - cast_y)
        # The error must be finite: e.g. x=-inf, y=inf with rtol > 0 gives
        # actual=inf and allowed=inf, and inf <= inf would otherwise pass.
        # _isfinited/_finitef are the fp64/fp32 finite checks obtained from
        # tl_extra_shim (behavior presumed from their names -- verify).
        actual_finite = _isfinited(actual) if x.dtype.is_fp64() else _finitef(actual)
        # Cast to int1 so the finite flag combines with the boolean compare.
        close |= actual_finite.to(tl.int1) & (actual <= allowed)
    return close

45 

46 

def isclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> torch.Tensor:
    """Elementwise closeness test in the style of ``torch.isclose``.

    An element is close when ``A == B`` or ``|A - B| <= atol + |rtol * B|``
    (with a finite error); with ``equal_nan`` two NaNs also count as close.

    Args:
        A: First input tensor.
        B: Second input tensor; must have the same dtype as ``A``.
        rtol: Relative tolerance, must be >= 0.
        atol: Absolute tolerance, must be >= 0.
        equal_nan: Treat NaNs at the same position as close.

    Returns:
        Boolean tensor of elementwise closeness results.

    Raises:
        RuntimeError: If the dtypes differ, either input is quantized, or a
            tolerance is negative.
    """
    logger.debug("GEMS ISCLOSE")
    # Validate before any dtype-specific shortcut so that every dtype,
    # including bool, is subject to the same argument checks.
    if A.dtype != B.dtype:
        raise RuntimeError("{} did not match {}".format(A.dtype, B.dtype))
    if A.is_quantized or B.is_quantized:
        raise RuntimeError("isclose is not supported for quantized inputs.")
    if rtol < 0:
        raise RuntimeError(
            "rtol must be greater than or equal to zero, but got {}".format(rtol)
        )
    if atol < 0:
        raise RuntimeError(
            "atol must be greater than or equal to zero, but got {}".format(atol)
        )
    # With both tolerances zero the test degenerates to an equality check.
    zero_tol = (rtol == 0) and (atol == 0)

    # note: bool (like int8) is not supported in isclose_func, because the
    # result of bool == bool / int8 == int8 is wrong in triton jit functions
    # and needs to be fixed in triton.
    # NOTE(review): no int8 guard exists here -- confirm int8 is excluded
    # elsewhere (e.g. at op registration).
    if A.dtype == torch.bool:
        if zero_tol:
            return A == B
        # Tolerances still apply to bool inputs (e.g. atol >= 1 makes True
        # and False close), so promote to the default float dtype, as
        # torch.isclose does, and let the float kernel compute the result.
        float_dtype = torch.get_default_dtype()
        A = A.to(float_dtype)
        B = B.to(float_dtype)
    return isclose_func(A, B, rtol, atol, equal_nan, zero_tol)

73 

74 

def allclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> bool:
    """Report whether every element of ``A`` is close to its counterpart in ``B``.

    Closeness is decided elementwise by ``isclose`` using the same ``rtol``,
    ``atol`` and ``equal_nan`` arguments (so the same argument errors are
    raised). The resulting mask is reduced with ``all`` and unwrapped with
    ``.item()``.
    """
    logger.debug("GEMS ALLCLOSE")
    mask = isclose(A, B, rtol, atol, equal_nan)
    every_close = all(mask)
    return every_close.item()