Coverage for src/flag_gems/ops/isclose.py: 56%
45 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.all import all
8from flag_gems.utils import pointwise_dynamic, tl_extra_shim
logger = logging.getLogger(__name__)

# Optional finiteness helpers from the backend-specific Triton shim.
# isclose_func's tolerance branch calls _isfinited for fp64 inputs and
# _finitef otherwise.
try:
    _isfinited = tl_extra_shim.isfinited
    _finitef = tl_extra_shim.finitef
except Exception as err:
    # Best-effort: keep the module importable on backends whose shim lacks
    # these helpers, but report it instead of failing silently. isclose_func
    # references both names, so isclose/allclose with non-zero tolerances is
    # expected to fail on such backends.
    logger.warning(
        "tl_extra_shim does not provide isfinited/finitef (%s); "
        "isclose/allclose with non-zero tolerances will not work.",
        err,
    )
# Elementwise closeness kernel behind ``isclose``.
# Arguments 0 and 1 (x, y) are tensors; rtol, atol, equal_nan and zero_tol
# are passed as non-tensor scalars (see ``is_tensor``). The output dtype is
# always bool (promotion method "ALWAYS_BOOL").
@pointwise_dynamic(
    is_tensor=[True, True, False, False, False, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def isclose_func(
    x,
    y,
    rtol,
    atol,
    equal_nan: tl.constexpr,
    zero_tol: tl.constexpr,
):
    # Upcast to float32 for the comparison arithmetic unless the input is
    # already fp64, which is kept as-is.
    # NOTE(review): cast_y tests x.dtype rather than y.dtype; this is only
    # equivalent because isclose() rejects mismatched dtypes — confirm before
    # calling this kernel from anywhere else.
    cast_x = x if x.dtype.is_fp64() else x.to(tl.float32)
    cast_y = y if x.dtype.is_fp64() else y.to(tl.float32)
    # Exact-equality term: bf16 compares the upcast values, every other dtype
    # compares the raw inputs.
    # NOTE(review): presumably a workaround for bf16 comparison in Triton —
    # confirm.
    if x.dtype.is_bf16():
        close = cast_x == cast_y
    else:
        close = x == y
    # With equal_nan, positions where both inputs are NaN (v != v) are close.
    if equal_nan:
        close |= (cast_x != cast_x) & (cast_y != cast_y)
    # zero_tol is a constexpr, so this branch is resolved at compile time.
    # With zero tolerances, closeness is just the equality term above.
    if not zero_tol:
        # Tolerance test: |x - y| <= atol + |rtol * y|.
        allowed = atol + tl.abs(rtol * cast_y)
        actual = tl.abs(cast_x - cast_y)
        # _isfinited / _finitef come from tl_extra_shim (fp64 / non-fp64
        # variants); presumably finiteness predicates, so an infinite or NaN
        # difference never counts as close — verify against the shim.
        actual_finite = _isfinited(actual) if x.dtype.is_fp64() else _finitef(actual)
        close |= actual_finite.to(tl.int1) & (actual <= allowed)
    return close
def isclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> torch.Tensor:
    """Elementwise test of ``|A - B| <= atol + rtol * |B|``, like ``torch.isclose``.

    Args:
        A: First input tensor.
        B: Second input tensor; must have the same dtype as ``A``.
        rtol: Relative tolerance, must be >= 0.
        atol: Absolute tolerance, must be >= 0.
        equal_nan: If True, positions where both inputs are NaN count as close.

    Returns:
        A bool tensor marking which elements are close.

    Raises:
        RuntimeError: If the dtypes differ, either input is quantized, or a
            tolerance is negative.
    """
    logger.debug("GEMS ISCLOSE")
    # Validate before any dtype-specific shortcut, so every dtype (including
    # bool) gets the same argument checking.
    if A.dtype != B.dtype:
        raise RuntimeError("{} did not match {}".format(A.dtype, B.dtype))
    if A.is_quantized or B.is_quantized:
        raise RuntimeError("isclose is not supported for quantized inputs.")
    if rtol < 0:
        raise RuntimeError(
            "rtol must be greater than or equal to zero, but got {}".format(rtol)
        )
    if atol < 0:
        raise RuntimeError(
            "atol must be greater than or equal to zero, but got {}".format(atol)
        )
    zero_tol = (rtol == 0) and (atol == 0)
    # note: Int8 is not supported in isclose_func, because the result of int8 == int8 is wrong
    # in triton jit function, and needs to be fixed in triton. The same is true for bool.
    # NOTE(review): only bool is special-cased below; int8 inputs still reach
    # isclose_func — confirm whether that is intended.
    if A.dtype == torch.bool:
        if zero_tol:
            # Zero tolerances reduce to plain equality.
            return A == B
        # torch.isclose evaluates the tolerance inequality on bools after
        # casting them to the default floating dtype, so large tolerances
        # (e.g. atol >= 1) can make True and False "close". Run the float
        # kernel on the cast values to match, without comparing bools in Triton.
        float_dtype = torch.get_default_dtype()
        return isclose_func(
            A.to(float_dtype), B.to(float_dtype), rtol, atol, equal_nan, zero_tol
        )
    return isclose_func(A, B, rtol, atol, equal_nan, zero_tol)
def allclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> bool:
    """Return whether every element of ``A`` is close to ``B``.

    Builds the elementwise mask with :func:`isclose` (same tolerances and NaN
    handling). It then reduces that mask with the ``all`` imported from
    ``flag_gems.ops.all`` and unwraps the result with ``.item()``.
    """
    logger.debug("GEMS ALLCLOSE")
    close_mask = isclose(A, B, rtol, atol, equal_nan)
    every_close = all(close_mask)
    return every_close.item()