Coverage for src/flag_gems/ops/fmod_.py: 51%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import pointwise_dynamic 

9 

10logger = logging.getLogger(__name__) 

11 

12 

13@triton.jit 

14def _div_rz(x, y): 

15 """div_rz - round toward zero""" 

16 result = x / y 

17 return tl.where(result >= 0, tl.floor(result), tl.ceil(result)) 

18 

19 

20@triton.jit 

21def _fmod(x, y): 

22 """fmod using div_rz""" 

23 quotient = _div_rz(x, y) 

24 return x - y * quotient 

25 

26 

27@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) 

28@triton.jit 

29def fmod_tt(x, y): 

30 """fmod for tensor-tensor""" 

31 return _fmod(x, y) 

32 

33 

34@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) 

35@triton.jit 

36def fmod_ts(x, y): 

37 """fmod for tensor-scalar""" 

38 return _fmod(x, y) 

39 

40 

41@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")]) 

42@triton.jit 

43def fmod_st(x, y): 

44 """fmod for scalar-tensor""" 

45 return _fmod(x, y) 

46 

47 

48def fmod(A, B): 

49 logger.debug("GEMS FMOD") 

50 if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): 

51 return fmod_tt(A, B) 

52 elif isinstance(A, torch.Tensor): 

53 return fmod_ts(A, B) 

54 elif isinstance(B, torch.Tensor): 

55 return fmod_st(A, B) 

56 else: 

57 # Both scalar 

58 return torch.tensor(A % B) 

59 

60 

61def fmod_(A, B): 

62 logger.debug("GEMS FMOD_") 

63 assert A.dtype.is_floating_point, "fmod_ only supports floating point dtypes" 

64 if isinstance(B, torch.Tensor): 

65 return fmod_tt(A, B, out0=A) 

66 else: 

67 return fmod_ts(A, B, out0=A)