Coverage for src/flag_gems/runtime/backend/_cambricon/ops/eq.py: 0%
32 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
6import flag_gems
7from flag_gems.runtime import device
9from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the "flag_gems" logger named after this module
# (any leading dots are stripped from __name__ first).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Rebind `device` from the object imported from flag_gems.runtime to its
# `.name` attribute; eq() compares this value against `A.device.type`.
device = device.name
@pointwise_dynamic(
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def eq_func(x, y):
    """Return the result of comparing ``x`` and ``y`` with ``==``.

    NOTE(review): the ``(0, 1, "ALWAYS_BOOL")`` promotion entry presumably
    asks ``pointwise_dynamic`` for a boolean result derived from operands
    0 and 1 -- confirm against the ``pointwise_dynamic`` implementation.
    """
    is_equal = x == y
    return is_equal
def eq(A, B):
    """Compare ``A`` and ``B`` through ``eq_func`` after aligning devices.

    When the two operands report different devices, exactly one of them is
    moved: if ``A.device.type`` equals the module-level ``device`` name,
    ``B`` is moved to ``A``'s device; otherwise ``A`` is moved to ``B``'s.

    Args:
        A: Left operand; must expose ``.device`` and ``.to``.
        B: Right operand; must expose ``.device`` and ``.to``.

    Returns:
        Whatever ``eq_func`` returns for the (possibly moved) operands.
    """
    lhs, rhs = A, B
    if lhs.device != rhs.device:
        # Prefer keeping the operand that already lives on the backend
        # device in place and moving the other one to it.
        if lhs.device.type == device:
            rhs = rhs.to(lhs.device)
        else:
            lhs = lhs.to(rhs.device)
    logger.debug("GEMS_CAMBRICON EQ")
    return eq_func(lhs, rhs)
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def eq_func_scalar(x, y):
    """Return the result of comparing ``x`` and ``y`` with ``==``.

    NOTE(review): ``is_tensor=[True, False]`` presumably marks ``x`` as a
    tensor operand and ``y`` as a non-tensor scalar -- confirm against the
    ``pointwise_dynamic`` implementation.
    """
    is_equal = x == y
    return is_equal
def eq_scalar(A, B):
    """Log the call, then compare ``A`` against ``B`` via ``eq_func_scalar``.

    Args:
        A: First operand, forwarded unchanged to ``eq_func_scalar``.
        B: Second operand, forwarded unchanged to ``eq_func_scalar``.

    Returns:
        Whatever ``eq_func_scalar(A, B)`` returns.
    """
    logger.debug("GEMS_CAMBRICON EQ SCALAR")
    comparison = eq_func_scalar(A, B)
    return comparison
def equal(x: torch.Tensor, y: torch.Tensor) -> bool:
    """Report whether ``x`` and ``y`` have the same shape and equal contents.

    Args:
        x: First tensor.
        y: Second tensor.

    Returns:
        ``False`` immediately when the shapes differ; otherwise the Python
        ``bool`` of ``flag_gems.all`` applied to ``eq(x, y)``.
    """
    logger.debug("GEMS_CAMBRICON EQUAL")
    shapes_differ = x.shape != y.shape
    if shapes_differ:
        # Skip the comparison entirely when the shapes do not match.
        return False
    matches = eq(x, y)
    all_match = flag_gems.all(matches)
    return bool(all_match.item())