Coverage for src/flag_gems/ops/fmod_.py: 51%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic
# Module-level logger named after this module; fmod and fmod_ emit a DEBUG
# record through it on every call.
logger: logging.Logger = logging.getLogger(__name__)
@triton.jit
def _div_rz(x, y):
    """Return ``x / y`` rounded toward zero (truncating division)."""
    q = x / y
    # Negative quotients round up toward zero; everything else (including
    # +/-0 and NaN, which fail the comparison) goes through floor.
    return tl.where(q < 0, tl.ceil(q), tl.floor(q))
@triton.jit
def _fmod(x, y):
    """C-style fmod: ``x - trunc(x / y) * y``; the result has the sign of x.

    Beyond the plain formula, two special cases are patched so the result
    agrees with C ``fmod`` / ``torch.fmod``:

    * truncated quotient of 0 (|x| < |y|, including finite x with infinite
      y): the result is exactly ``x``. The plain formula would evaluate
      ``inf * 0`` and return NaN for an infinite ``y``.
    * zero remainder: the zero takes the sign of ``x``
      (``fmod(-4.0, 2.0) == -0.0``), not the sign left by the subtraction.

    NaN propagation for floating-point inputs is unchanged: a zero divisor,
    an infinite dividend, or a NaN operand still yields NaN.
    """
    quotient = _div_rz(x, y)
    remainder = x - y * quotient
    # For finite x, x * 0 is a zero carrying the sign of x; a non-finite x
    # never produces a zero remainder, so this branch never selects inf * 0.
    remainder = tl.where(remainder == 0, x * 0, remainder)
    # A zero quotient means x is its own remainder (avoids y * 0 for y = inf).
    return tl.where(quotient == 0, x, remainder)
@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def fmod_tt(x, y):
    """Elementwise fmod of tensor ``x`` by tensor ``y`` (delegates to _fmod)."""
    remainder = _fmod(x, y)
    return remainder
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def fmod_ts(x, y):
    """Elementwise fmod of tensor ``x`` by scalar ``y`` (delegates to _fmod)."""
    remainder = _fmod(x, y)
    return remainder
@pointwise_dynamic(
    is_tensor=[False, True],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def fmod_st(x, y):
    """Elementwise fmod of scalar ``x`` by tensor ``y`` (delegates to _fmod)."""
    remainder = _fmod(x, y)
    return remainder
def _scalar_fmod(a, b):
    """Return C-style ``fmod(a, b)`` for two Python scalars.

    Python's ``%`` is floor modulo (a non-zero result takes the sign of
    ``b``); ``fmod`` truncates the quotient, so the result takes the sign
    of ``a``: ``fmod(-5, 3) == -2`` whereas ``-5 % 3 == 1``.

    For non-negative operands floor modulo and fmod coincide exactly, so the
    magnitude is computed as ``abs(a) % abs(b)`` and the sign of ``a`` is
    re-applied. Integers stay exact integers; if either operand is a float
    the result is a float. Infinite divisors return ``a``; infinite
    dividends and NaN operands give NaN.

    Raises:
        ZeroDivisionError: if ``b`` is zero (same as Python's ``%``).
    """
    remainder = abs(a) % abs(b)  # raises ZeroDivisionError when b == 0
    if isinstance(remainder, float):
        # copysign also gives -0.0 for a == -0.0 or an exact negative multiple.
        return math.copysign(remainder, a)
    return -remainder if a < 0 else remainder


def fmod(A, B):
    """Elementwise C-style remainder ``A - trunc(A / B) * B`` (``torch.fmod``).

    Dispatches on which operands are tensors: tensor/tensor -> ``fmod_tt``,
    tensor/scalar -> ``fmod_ts``, scalar/tensor -> ``fmod_st``. When both
    operands are scalars the remainder is computed on the host with
    truncation semantics and wrapped with ``torch.tensor`` (int64 for two
    integers, the default float dtype otherwise).

    Raises:
        ZeroDivisionError: if both operands are scalars and ``B`` is zero.
    """
    logger.debug("GEMS FMOD")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return fmod_tt(A, B)
    if isinstance(A, torch.Tensor):
        return fmod_ts(A, B)
    if isinstance(B, torch.Tensor):
        return fmod_st(A, B)
    # Both scalars: Python's % would give floor-modulo signs, not fmod's.
    return torch.tensor(_scalar_fmod(A, B))
def fmod_(A, B):
    """In-place C-style remainder: compute ``fmod(A, B)`` into ``A``.

    ``A`` must be a floating-point tensor. ``B`` may be a tensor (routed to
    ``fmod_tt``) or a scalar (routed to ``fmod_ts``). Either way ``A`` is
    passed as ``out0`` so the result is written into it. The return value is
    whatever the kernel wrapper returns (presumably ``A`` itself; TODO
    confirm against pointwise_dynamic).

    Raises:
        AssertionError: if ``A`` is not floating point. It is raised
            explicitly rather than via ``assert`` so the check survives
            ``python -O``, while keeping the exception type callers saw.
    """
    logger.debug("GEMS FMOD_")
    if not A.dtype.is_floating_point:
        raise AssertionError("fmod_ only supports floating point dtypes")
    if isinstance(B, torch.Tensor):
        return fmod_tt(A, B, out0=A)
    return fmod_ts(A, B, out0=A)