Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/isclose.py: 0%
48 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import tl_extra_shim
10from ..utils.pointwise_dynamic import pointwise_dynamic
11from .all import all
# Child of the shared "flag_gems" logger, named after this module
# (with any leading dots stripped from the module name).
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)

# Finiteness helpers from the shim: isclose_func calls the first for fp64
# inputs and the second for every other input dtype.
_isfinited, _finitef = tl_extra_shim.isfinited, tl_extra_shim.finitef
# Element-wise closeness kernel.
#
# x, y       : tensor operands (the first two entries of ``is_tensor``).
# rtol, atol : scalar relative / absolute tolerances.
# equal_nan  : compile-time flag; when set, positions where both operands
#              are NaN count as close.
# zero_tol   : compile-time flag; when set, the tolerance test is skipped and
#              only the equality (and optional NaN) result is returned.
#
# The output dtype is always bool (promotion method "ALWAYS_BOOL").
@pointwise_dynamic(
    is_tensor=[True, True, False, False, False, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def isclose_func(
    x,
    y,
    rtol,
    atol,
    equal_nan: tl.constexpr,
    zero_tol: tl.constexpr,
):
    # Compute in float32 unless the input is already fp64. Both casts are
    # keyed on x's dtype.
    x_wide = x if x.dtype.is_fp64() else x.to(tl.float32)
    y_wide = y if x.dtype.is_fp64() else y.to(tl.float32)

    # Exact-equality seed: bf16 compares the widened values, every other
    # dtype compares the raw operands.
    result = (x_wide == y_wide) if x.dtype.is_bf16() else (x == y)

    if equal_nan:
        # A value that differs from itself is NaN; accept positions where
        # that holds for both operands.
        both_nan = (x_wide != x_wide) & (y_wide != y_wide)
        result = result | both_nan

    if not zero_tol:
        diff = tl.abs(x_wide - y_wide)
        tolerance = atol + tl.abs(rtol * y_wide)
        # Select the finiteness helper that matches the compute precision.
        if x.dtype.is_fp64():
            diff_is_finite = _isfinited(diff)
        else:
            diff_is_finite = _finitef(diff)
        # Only a finite difference within tolerance counts as close.
        result = result | (diff_is_finite.to(tl.int1) & (diff <= tolerance))

    return result
def isclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> torch.Tensor:
    """Return a boolean tensor marking where ``A`` and ``B`` are close.

    Inputs are validated first, for every dtype including bool. Bool inputs
    are then compared with ``A == B``; all other dtypes are dispatched to
    ``isclose_func``.

    Side effect: the ``XPU_cmp_nan`` environment variable is set to ``"1"``
    when ``equal_nan`` is false and removed when it is true.

    Args:
        A: First tensor.
        B: Second tensor; must have the same dtype as ``A``.
        rtol: Relative tolerance, must be >= 0.
        atol: Absolute tolerance, must be >= 0.
        equal_nan: Whether two NaNs at the same position count as close.

    Raises:
        RuntimeError: If the dtypes differ, either input is quantized, or
            ``rtol`` / ``atol`` is negative.
    """
    logger.debug("GEMS ISCLOSE")

    # NOTE(review): presumably the Kunlunxin toolchain reads XPU_cmp_nan to
    # decide how NaN comparisons behave -- confirm against backend docs.
    if not equal_nan:
        os.environ["XPU_cmp_nan"] = "1"
    else:
        os.environ.pop("XPU_cmp_nan", None)

    # Validate before any shortcut so bool inputs get the same checks as
    # every other dtype instead of silently bypassing them.
    if A.dtype != B.dtype:
        raise RuntimeError("{} did not match {}".format(A.dtype, B.dtype))
    if A.is_quantized or B.is_quantized:
        raise RuntimeError("isclose is not supported for quantized inputs.")
    if rtol < 0:
        raise RuntimeError(
            "rtol must be greater than or equal to zero, but got {}".format(rtol)
        )
    if atol < 0:
        raise RuntimeError(
            "atol must be greater than or equal to zero, but got {}".format(atol)
        )

    # note: Int8 is not supported in isclose_func, because the result of int8 == int8 is wrong
    # in triton jit function, and needs to be fixed in triton. The same is true for bool.
    # NOTE(review): only bool is short-circuited here; int8 inputs still reach
    # isclose_func -- confirm whether that is intended.
    if A.dtype == torch.bool:
        return A == B

    # With both tolerances at zero the kernel skips the tolerance test.
    zero_tol = (rtol == 0) and (atol == 0)
    return isclose_func(A, B, rtol, atol, equal_nan, zero_tol)
def allclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> bool:
    """Report whether every element of ``A`` is close to ``B``.

    Runs ``isclose`` with the given tolerances, reduces the result with this
    package's ``all`` (imported from ``.all``, shadowing the builtin), and
    converts the reduced tensor to a Python scalar via ``.item()``.
    """
    logger.debug("GEMS ALLCLOSE")
    elementwise = isclose(A, B, rtol, atol, equal_nan)
    reduced = all(elementwise)
    return reduced.item()