Coverage for src/flag_gems/ops/eq.py: 82%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7import flag_gems 

8from flag_gems.runtime import device 

9from flag_gems.utils import pointwise_dynamic 

10 

11logger = logging.getLogger(__name__) 

12device = device.name 

13 

14 

15@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) 

16@triton.jit 

17def eq_func(x, y): 

18 return x.to(tl.float32) == y.to(tl.float32) 

19 

20 

21def eq(A, B): 

22 if A.device != B.device: 

23 if A.device.type == device: 

24 B = B.to(A.device) 

25 else: 

26 A = A.to(B.device) 

27 logger.debug("GEMS EQ") 

28 return eq_func(A, B) 

29 

30 

31@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) 

32@triton.jit 

33def eq_func_scalar(x, y): 

34 return x.to(tl.float32) == y.to(tl.float32) 

35 

36 

37def eq_scalar(A, B): 

38 logger.debug("GEMS EQ SCALAR") 

39 return eq_func_scalar(A, B) 

40 

41 

42def equal(x: torch.Tensor, y: torch.Tensor) -> bool: 

43 logger.debug("GEMS EQUAL") 

44 if x.shape != y.shape: 

45 return False 

46 eq_tensor = eq(x, y) 

47 return bool(flag_gems.all(eq_tensor).item())