Coverage for src/flag_gems/ops/remainder.py: 79%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

9logger = logging.getLogger(__name__) 

10 

11 

@triton.jit
def _remainder(x, y):
    """Elementwise remainder whose non-zero result takes the sign of ``y``.

    Starts from ``x % y`` and, when that value is non-zero while ``x`` and
    ``y`` have opposite signs, shifts it by ``y``. This fix-up presumes
    ``%`` here truncates toward zero (C-style, result signed like ``x``);
    under that assumption the output matches floor-style remainder
    (Python's ``%``). NOTE(review): confirm Triton's ``%`` semantics for
    the dtypes in use.
    """
    trunc_rem = x % y
    opposite_signs = (x < 0) ^ (y < 0)
    # Only non-zero remainders need to be moved into y's sign range.
    must_shift = opposite_signs & (trunc_rem != 0)
    return tl.where(must_shift, trunc_rem + y, trunc_rem)

18 

19 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    """Tensor/tensor remainder kernel body; delegates to ``_remainder``.

    Both arguments are tensors (``pointwise_dynamic`` default). Result
    dtype follows the "DEFAULT" promotion of args 0 and 1 as configured
    in ``pointwise_dynamic`` (presumably torch-style promotion; confirm
    in flag_gems.utils).
    """
    result = _remainder(x, y)
    return result

24 

25 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    """Tensor/scalar remainder kernel body; delegates to ``_remainder``.

    ``is_tensor=[True, False]`` marks ``x`` as a tensor and ``y`` as a
    scalar. Promotion of args 0 and 1 uses the "DEFAULT" method of
    ``pointwise_dynamic``.
    """
    result = _remainder(x, y)
    return result

30 

31 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    """Scalar/tensor remainder kernel body; delegates to ``_remainder``.

    ``is_tensor=[False, True]`` marks ``x`` as a scalar and ``y`` as a
    tensor. Promotion of args 0 and 1 uses the "DEFAULT" method of
    ``pointwise_dynamic``.
    """
    result = _remainder(x, y)
    return result

36 

37 

def remainder(A, B):
    """Out-of-place remainder of ``A`` by ``B``.

    Chooses the kernel variant by which operands are ``torch.Tensor``:
    tensor/tensor -> ``rem_tt``, tensor/scalar -> ``rem_ts``,
    scalar/tensor -> ``rem_st``. When neither is a tensor, the Python
    ``%`` result is wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        return rem_tt(A, B)
    if a_is_tensor:
        return rem_ts(A, B)
    if b_is_tensor:
        return rem_st(A, B)

    # Pure-scalar fallback: computed on the host by Python itself.
    return torch.tensor(A % B)

49 

50 

def remainder_(A, B):
    """In-place remainder: writes ``A % B`` (floor-style) back into ``A``.

    Uses ``rem_tt`` when ``B`` is a ``torch.Tensor`` and ``rem_ts``
    otherwise. In both cases ``A`` is passed as ``out0``, so the result
    is stored into ``A``. Returns whatever the chosen kernel returns.
    """
    logger.debug("GEMS REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)