Coverage for src/flag_gems/runtime/backend/_cambricon/ops/isclose.py: 0%

40 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from ..utils.pointwise_dynamic import pointwise_dynamic 

8from .all import all 

9 

# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module (any leading dots stripped from __name__).
logger = logging.getLogger("flag_gems")
logger = logger.getChild(__name__.lstrip("."))

11 

12 

# Element-wise closeness kernel: an element is close when x == y, or when
# |x - y| is finite and |x - y| <= atol + |rtol * y|.
#   x, y      -- tensor operands (is_tensor=True); the result is boolean
#                (promotion method "ALWAYS_BOOL").
#   rtol/atol -- scalar tolerances.
#   equal_nan -- constexpr; when true, NaN/NaN pairs also count as close.
#   zero_tol  -- constexpr; when true, the tolerance test is compiled out and
#                the result is plain equality (plus the equal_nan rule).
@pointwise_dynamic(
    is_tensor=[True, True, False, False, False, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def isclose_func(
    x,
    y,
    rtol,
    atol,
    equal_nan: tl.constexpr,
    zero_tol: tl.constexpr,
):
    # Work in float32, except float64 inputs, which keep full precision. Only
    # x's dtype is inspected; the Python wrapper rejects mismatched dtypes.
    cast_x = x if x.dtype.is_fp64() else x.to(tl.float32)
    cast_y = y if x.dtype.is_fp64() else y.to(tl.float32)
    if x.dtype.is_bf16():
        close = cast_x == cast_y
    else:
        close = x == y
    if equal_nan:
        # NaN is the only value that compares unequal to itself.
        close |= (cast_x != cast_x) & (cast_y != cast_y)
    if not zero_tol:
        allowed = atol + tl.abs(rtol * cast_y)
        actual = tl.abs(cast_x - cast_y)
        # An infinite or NaN difference must never count as close (for example,
        # inf - (-inf) = inf would otherwise satisfy inf <= inf).
        # Finiteness is tested on the raw bits: an IEEE-754 value is finite iff
        # its exponent field is not all ones.
        # Given the casts above, `actual` is float32 or float64. The float64
        # branch is required because bitcasting 64-bit data to a 16-bit integer
        # is rejected by Triton. The fp16/bf16 branches are currently unreachable
        # and only kept for completeness.
        if actual.dtype.is_fp64():
            actual_finite = (
                actual.to(tl.int64, bitcast=True) & 0x7FFFFFFFFFFFFFFF
            ) < 0x7FF0000000000000
        elif actual.dtype.is_fp32():
            actual_finite = (
                actual.to(tl.int32, bitcast=True) & 0x7FFFFFFF
            ) < 0x7F800000
        elif actual.dtype.is_fp16():
            actual_finite = (
                actual.to(tl.int16, bitcast=True) & 0x7FFF
            ) < 0x7C00
        else:
            # bfloat16
            actual_finite = (
                actual.to(tl.int16, bitcast=True) & 0x7FFF
            ) < 0x7F80
        close |= actual_finite.to(tl.int1) & (actual <= allowed)
    return close

50 

51 

def isclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> torch.Tensor:
    """Return a boolean tensor marking where ``A`` and ``B`` are element-wise close.

    Elements are close when they are equal, or when ``|A - B|`` is finite and
    ``|A - B| <= atol + |rtol * B|``. With ``equal_nan=True``, NaN/NaN pairs
    are also close.

    Args:
        A: First input tensor.
        B: Second input tensor; must have the same dtype as ``A``.
        rtol: Relative tolerance, must be >= 0.
        atol: Absolute tolerance, must be >= 0.
        equal_nan: Whether NaN compares close to NaN.

    Raises:
        RuntimeError: if the dtypes differ, either input is quantized, or a
            tolerance is negative.
    """
    logger.debug("GEMS_CAMBRICON ISCLOSE")
    # Validate before any dtype-specific shortcut. Otherwise bool inputs would
    # slip past these checks; torch.isclose also validates first.
    if A.dtype != B.dtype:
        raise RuntimeError("{} did not match {}".format(A.dtype, B.dtype))
    if A.is_quantized or B.is_quantized:
        raise RuntimeError("isclose is not supported for quantized inputs.")
    if rtol < 0:
        raise RuntimeError(
            "rtol must be greater than or equal to zero, but got {}".format(rtol)
        )
    if atol < 0:
        raise RuntimeError(
            "atol must be greater than or equal to zero, but got {}".format(atol)
        )
    # Note: `==` on int8 and bool values gives wrong results inside Triton JIT
    # functions here (a Triton issue still to be fixed). Bool inputs therefore
    # bypass isclose_func and use exact equality, ignoring rtol/atol.
    # NOTE(review): int8 is not special-cased here. Confirm that dispatch keeps
    # int8 away from this op.
    if A.dtype == torch.bool:
        return A == B
    # With both tolerances zero, the kernel compiles out the tolerance test.
    zero_tol = (rtol == 0) and (atol == 0)
    return isclose_func(A, B, rtol, atol, equal_nan, zero_tol)

78 

79 

def allclose(
    A: torch.Tensor,
    B: torch.Tensor,
    rtol=1e-05,
    atol=1e-08,
    equal_nan: bool = False,
) -> bool:
    """Report whether every element pair of ``A`` and ``B`` is close.

    Builds the element-wise mask with ``isclose``, which has the same
    tolerances and validation. The mask is reduced with the ``all`` op imported
    from this package's ``.all`` module, and the result is returned as a Python
    scalar via ``.item()``.
    """
    logger.debug("GEMS_CAMBRICON ALLCLOSE")
    close_mask = isclose(A, B, rtol, atol, equal_nan)
    every_close = all(close_mask)
    return every_close.item()