Coverage for src/flag_gems/runtime/backend/_cambricon/ops/isclose.py: 0%
40 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from ..utils.pointwise_dynamic import pointwise_dynamic
8from .all import all
10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Elementwise closeness kernel used by `isclose`.
#
# `x` and `y` are the two operands declared as tensors in `is_tensor`; `rtol`,
# `atol`, `equal_nan` and `zero_tol` are declared as non-tensor arguments.
# An element of the result is true when any of the following holds:
#   * x == y;
#   * `equal_nan` is set and both x and y are NaN;
#   * `zero_tol` is not set, |x - y| is finite and |x - y| <= atol + |rtol * y|.
@pointwise_dynamic(
    is_tensor=[True, True, False, False, False, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def isclose_func(
    x,
    y,
    rtol,
    atol,
    equal_nan: tl.constexpr,
    zero_tol: tl.constexpr,
):
    # fp64 inputs are used as they are; everything else is computed in fp32.
    # Both casts key off x's dtype (the `isclose` wrapper rejects mismatched
    # dtypes before calling this function).
    cast_x = x if x.dtype.is_fp64() else x.to(tl.float32)
    cast_y = y if x.dtype.is_fp64() else y.to(tl.float32)
    if x.dtype.is_bf16():
        close = cast_x == cast_y
    else:
        close = x == y
    if equal_nan:
        # NaN is the only value that compares unequal to itself.
        close |= (cast_x != cast_x) & (cast_y != cast_y)
    if not zero_tol:
        allowed = atol + tl.abs(rtol * cast_y)
        actual = tl.abs(cast_x - cast_y)
        # Bit-level stand-in for `_isfinited(actual)` / `_finitef(actual)`:
        # drop the sign bit and require the remaining bits to be below the
        # "exponent all ones, mantissa zero" pattern (i.e. below +inf).
        # `actual` is fp64 when x is fp64 and fp32 otherwise, so the fp64 case
        # needs its own 64-bit branch: a bitcast must keep the bit width, and
        # without this branch fp64 fell through to the 16-bit bitcast below.
        actual_finite = (
            (
                (actual.to(tl.int64, bitcast=True) & 0x7FFFFFFFFFFFFFFF)
                < 0x7FF0000000000000
            )
            if actual.dtype.is_fp64()
            else (
                ((actual.to(tl.int32, bitcast=True) & 0x7FFFFFFF) < 0x7F800000)
                if actual.dtype.is_fp32()
                else (
                    ((actual.to(tl.int16, bitcast=True) & 0x7FFF) < 0x7C00)
                    if actual.dtype.is_fp16()
                    else ((actual.to(tl.int16, bitcast=True) & 0x7FFF) < 0x7F80)
                )
            )
        )

        close |= actual_finite.to(tl.int1) & (actual <= allowed)
    return close
def isclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> torch.Tensor:
    """Return an elementwise mask telling where ``A`` and ``B`` are close.

    Boolean ``A`` is answered directly with ``A == B``. Otherwise the inputs
    are validated and the comparison is delegated to ``isclose_func``.

    Args:
        A: First operand.
        B: Second operand; must have the same dtype as ``A`` unless ``A`` is
            boolean.
        rtol: Relative tolerance, must be >= 0.
        atol: Absolute tolerance, must be >= 0.
        equal_nan: Forwarded to ``isclose_func``, where it lets two NaNs
            count as close.

    Raises:
        RuntimeError: If the dtypes differ, either input is quantized, or a
            tolerance is negative.
    """
    logger.debug("GEMS_CAMBRICON ISCLOSE")

    # NOTE: Int8 is not supported in isclose_func because int8 == int8 gives a
    # wrong result inside a triton jit function (to be fixed in triton). The
    # same holds for bool, which is why bool is handled here instead.
    if A.dtype == torch.bool:
        return A == B

    lhs_dtype, rhs_dtype = A.dtype, B.dtype
    if lhs_dtype != rhs_dtype:
        mismatch_msg = "{} did not match {}".format(lhs_dtype, rhs_dtype)
        raise RuntimeError(mismatch_msg)

    if A.is_quantized or B.is_quantized:
        raise RuntimeError("isclose is not supported for quantized inputs.")

    if rtol < 0:
        rtol_msg = "rtol must be greater than or equal to zero, but got {}".format(
            rtol
        )
        raise RuntimeError(rtol_msg)

    if atol < 0:
        atol_msg = "atol must be greater than or equal to zero, but got {}".format(
            atol
        )
        raise RuntimeError(atol_msg)

    # With both tolerances at zero the kernel skips its tolerance branch.
    zero_tol = rtol == 0 and atol == 0
    return isclose_func(A, B, rtol, atol, equal_nan, zero_tol)
def allclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> bool:
    """Reduce the ``isclose`` mask of ``A`` and ``B`` to a single value.

    The arguments are forwarded unchanged to ``isclose``; the resulting mask
    is reduced with this module's ``all`` (imported from ``.all``, not the
    Python builtin) and extracted with ``.item()``.
    """
    logger.debug("GEMS_CAMBRICON ALLCLOSE")
    elementwise_mask = isclose(A, B, rtol, atol, equal_nan)
    reduced = all(elementwise_mask)
    return reduced.item()