Coverage for src/flag_gems/runtime/backend/_hygon/ops/isclose.py: 0%
45 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic, tl_extra_shim
9from .all import all
logger = logging.getLogger(__name__)

# ``isclose_func`` uses these backend helpers from ``tl_extra_shim`` to test
# whether the absolute difference is finite. Binding them is best-effort so the
# module stays importable, but a missing helper is logged instead of silently
# ignored: without them the non-zero-tolerance branch of ``isclose_func`` is
# expected to fail to compile with an undefined-name error.
try:
    _isfinited = tl_extra_shim.isfinited
    _finitef = tl_extra_shim.finitef
except Exception as err:  # deliberate best-effort: do not break import
    logger.warning(
        "tl_extra_shim does not provide isfinited/finitef (%s); "
        "isclose/allclose with non-zero tolerances will not work.",
        err,
    )
# Elementwise closeness kernel used by ``isclose``.
# ``pointwise_dynamic`` treats ``x`` and ``y`` as tensor operands and the other
# four parameters as non-tensor scalars (``is_tensor``); the output dtype is
# always bool, derived from operands 0 and 1 (``ALWAYS_BOOL``).
@pointwise_dynamic(
    is_tensor=[True, True, False, False, False, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def isclose_func(
    x,
    y,
    rtol,
    atol,
    equal_nan: tl.constexpr,
    zero_tol: tl.constexpr,
):
    # Tolerance arithmetic runs in fp64 for fp64 inputs, otherwise in fp32.
    # Both casts test ``x.dtype``; this relies on the caller (``isclose``)
    # rejecting inputs whose dtypes differ.
    cast_x = x if x.dtype.is_fp64() else x.to(tl.float32)
    cast_y = y if x.dtype.is_fp64() else y.to(tl.float32)
    # Exact-equality part of the test. bf16 values are compared after the fp32
    # cast. NOTE(review): presumably a workaround for bf16 equality in Triton;
    # confirm.
    if x.dtype.is_bf16():
        close = cast_x == cast_y
    else:
        close = x == y
    # With equal_nan, positions where both values are NaN count as close
    # (NaN is the only value that compares unequal to itself).
    if equal_nan:
        close |= (cast_x != cast_x) & (cast_y != cast_y)
    # zero_tol is a compile-time constant. When both tolerances are zero, only
    # the equality test above applies and this tolerance check is skipped.
    if zero_tol == 0:
        # close  <=>  |x - y| <= atol + |rtol * y|
        allowed = atol + tl.abs(rtol * cast_y)
        actual = tl.abs(cast_x - cast_y)
        # Only a finite difference may pass the tolerance test. Infinities can
        # still be close through the exact-equality check above.
        # ``_isfinited``/``_finitef`` come from ``tl_extra_shim`` at module top.
        actual_finite = _isfinited(actual) if x.dtype.is_fp64() else _finitef(actual)
        close |= actual_finite.to(tl.int1) & (actual <= allowed)
    return close
def isclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> torch.Tensor:
    """Elementwise test of whether ``A`` and ``B`` are close (like ``torch.isclose``).

    An element is close when the two values are equal, when both are NaN and
    ``equal_nan`` is true, or when ``|A - B|`` is finite and
    ``|A - B| <= atol + |rtol * B|``.

    Args:
        A: First input tensor.
        B: Second input tensor; must have the same dtype as ``A``.
        rtol: Relative tolerance; must be non-negative.
        atol: Absolute tolerance; must be non-negative.
        equal_nan: Treat NaNs at the same position as close.

    Returns:
        A bool tensor holding the per-element result.

    Raises:
        RuntimeError: If the dtypes differ, either input is quantized, or a
            tolerance is negative.
    """
    logger.debug("GEMS ISCLOSE")
    # Validate arguments before any dtype-specific shortcut, so bool inputs
    # get the same error behavior as every other dtype.
    if A.dtype != B.dtype:
        raise RuntimeError("{} did not match {}".format(A.dtype, B.dtype))
    if A.is_quantized or B.is_quantized:
        raise RuntimeError("isclose is not supported for quantized inputs.")
    if rtol < 0:
        raise RuntimeError(
            "rtol must be greater than or equal to zero, but got {}".format(rtol)
        )
    if atol < 0:
        raise RuntimeError(
            "atol must be greater than or equal to zero, but got {}".format(atol)
        )
    # note: bool == bool (and likewise int8 == int8) gives wrong results inside
    # the triton jit function and needs to be fixed in triton, so bool inputs
    # are answered with plain tensor equality instead of the kernel.
    # NOTE(review): int8 is NOT special-cased here and still reaches the
    # kernel; confirm whether it needs a similar path.
    # NOTE(review): torch.isclose also applies the tolerance test to bool
    # inputs, so a large atol may treat True/False as close there; this
    # shortcut only checks equality.
    if A.dtype == torch.bool:
        return A == B
    # With both tolerances zero the kernel degenerates to an equality check.
    zero_tol = (rtol == 0) and (atol == 0)
    return isclose_func(A, B, rtol, atol, equal_nan, zero_tol)
def allclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> bool:
    """Reduce the elementwise ``isclose`` result for ``A`` and ``B`` to one value.

    ``rtol``, ``atol`` and ``equal_nan`` are forwarded unchanged to
    ``isclose``. The reduction uses the ``all`` op imported from ``.all``
    (not the builtin), and the result is converted with ``.item()``.
    """
    logger.debug("GEMS ALLCLOSE")
    per_element = isclose(A, B, rtol, atol, equal_nan)
    reduced = all(per_element)
    return reduced.item()