Coverage for src/flag_gems/ops/div.py: 60%
174 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic
from flag_gems.utils.triton_lang_extension import div_rn, div_rz, fmod, trunc
10logger = logging.getLogger(__name__)
# Element-wise true division for two tensor operands.
# INT_TO_FLOAT promotion: integer inputs are promoted to floating point so
# the result is a true (non-truncating) quotient.
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    return x / y
# True division where x is a tensor and y is a Python scalar
# (is_tensor=[True, False]); same INT_TO_FLOAT promotion as true_div_func.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    return x / y
# True division where x is a Python scalar and y is a tensor
# (is_tensor=[False, True]); same INT_TO_FLOAT promotion as true_div_func.
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    return x / y
def true_divide(A, B):
    """Return the element-wise true division A / B.

    Dispatches to the tensor/tensor, tensor/scalar, or scalar/tensor kernel
    depending on which operands are torch tensors; two Python scalars are
    divided in Python and wrapped in a tensor.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: plain Python division.
    return torch.tensor(A / B)
def true_divide_out(A, B, out):
    """Element-wise true division A / B written into ``out``.

    Same dispatch as ``true_divide`` but forwards ``out`` as the kernel's
    output buffer (``out0``). For two Python scalars the quotient is
    computed in Python and either wrapped in a new tensor (``out is None``)
    or filled into ``out``.
    """
    logger.debug("GEMS TRUE_DIVIDE OUT")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B, out0=out)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B, out0=out)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B, out0=out)
    # Both operands are Python scalars.
    if out is None:
        return torch.tensor(A / B)
    return out.fill_(A / B)
def true_divide_(A, B):
    """In-place true division: A /= B, writing the quotient back into A."""
    logger.debug("GEMS TRUE_DIVIDE_")
    if not isinstance(B, torch.Tensor):
        # B is a Python scalar.
        return true_div_func_tensor_scalar(A, B, out0=A)
    return true_div_func(A, B, out0=A)
# Division rounded toward zero for two tensor operands:
# div_rz divides with round-toward-zero, trunc drops any fractional part.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    return trunc(div_rz(x, y))
# Truncating division where x is a tensor and y is a Python scalar
# (is_tensor=[True, False]).
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    return trunc(div_rz(x, y))
# Truncating division where x is a Python scalar and y is a tensor
# (is_tensor=[False, True]).
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    return trunc(div_rz(x, y))
def trunc_divide(A, B):
    """Element-wise division A / B with the quotient rounded toward zero.

    Dispatches to the tensor/tensor, tensor/scalar, or scalar/tensor kernel
    depending on which operands are torch tensors; two Python scalars are
    handled in Python.
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return trunc_div_func(A, B)
    elif isinstance(A, torch.Tensor):
        return trunc_div_func_tensor_scalar(A, B)
    elif isinstance(B, torch.Tensor):
        return trunc_div_func_scalar_tensor(A, B)
    else:
        # Both scalar. BUGFIX: the previous code returned torch.tensor(A / B),
        # i.e. a *true* quotient, while the tensor paths (and
        # torch.div(..., rounding_mode="trunc")) round toward zero.
        quotient = math.trunc(A / B)
        if isinstance(A, int) and isinstance(B, int):
            # int // int stays integral, mirroring torch's integer result.
            return torch.tensor(quotient)
        # A float operand yields a floating-point result.
        return torch.tensor(float(quotient))
def trunc_divide_(A, B):
    """In-place truncating division: A = trunc(A / B), written back into A."""
    logger.debug("GEMS TRUNC_DIVIDE_")
    if not isinstance(B, torch.Tensor):
        # B is a Python scalar.
        return trunc_div_func_tensor_scalar(A, B, out0=A)
    return trunc_div_func(A, B, out0=A)
@triton.jit
def _int_floordiv(x, y):
    # Integer floor division with Python/PyTorch semantics.
    # TODO: request Triton to add an integer remainder builtin.
    # The semantics of Triton's floordiv differ from PyTorch/NumPy:
    # Triton floordiv equates to
    #     (x - np.fmod(x, y)) / y
    # whereas PyTorch floordiv is
    #     (x - np.remainder(x, y)) / y
    # The results show a one-off difference when
    #     C1) x and y have opposite signs,
    # and C2) x is not a multiple of y.
    # Apart from the above, there's the erroneous case x // 0: Triton returns
    # -1, whereas in PyTorch x // 0 returns -1 if x >= 0 and -2 if x < 0;
    # presumably that special case is meant to be coalesced into the C1/C2
    # check so no extra handling is needed — TODO(review): confirm.
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    # When the division is inexact and the signs differ, the truncated
    # quotient is one greater than the floored quotient.
    return tl.where(c1 & c2, x // y - 1, x // y)
# To be consistent with Python, NumPy and torch, we have to implement it in
# the following way.
127# CPython
128# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
129# numpy
130# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
131# torch
132# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # Floating-point floor division following the CPython/NumPy/torch scheme
    # referenced above: derive the quotient from the remainder, then round.
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # (x - remainder) is an exact multiple of y.
    # NOTE: we have to use div_rn explicitly here
    q = div_rn(x - remainder, y)
    # Floor semantics: an inexact quotient with opposite-sign operands must
    # be reduced by one relative to the truncated quotient.
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to an integral value; q may sit just below it due to rounding.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # Preserve the sign of a zero quotient: -0.0 when the signs differ.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Division by zero falls back to plain IEEE division (inf/nan propagate).
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out
# Floor division for two tensor operands: integer inputs use the exact
# integer path, anything else the floating-point path.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    # BUGFIX: the previous condition tested x's dtype twice and never y's,
    # so a mixed int/float pair would wrongly take the integer path.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
# Floor division where x is a tensor and y is a Python scalar
# (is_tensor=[True, False]).
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # BUGFIX: the previous condition tested x's dtype twice and never y's,
    # so a mixed int/float pair would wrongly take the integer path.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
# Floor division where x is a Python scalar and y is a tensor
# (is_tensor=[False, True]).
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # BUGFIX: the previous condition tested x's dtype twice and never y's,
    # so a mixed int/float pair would wrongly take the integer path.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Element-wise floor division A // B.

    Dispatches to the tensor/tensor, tensor/scalar, or scalar/tensor kernel
    depending on which operands are torch tensors; two Python scalars use
    Python's // operator.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return floor_div_func(A, B)
    if a_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: plain Python floor division.
    return torch.tensor(A // B)
def floor_divide_(A, B):
    """In-place floor division: A //= B, written back into A."""
    logger.debug("GEMS FLOOR_DIVIDE_")
    if not isinstance(B, torch.Tensor):
        # B is a Python scalar.
        return floor_div_func_tensor_scalar(A, B, out0=A)
    return floor_div_func(A, B, out0=A)
def div_mode(A, B, rounding_mode=None):
    """Division with an optional rounding mode, mirroring torch.div.

    rounding_mode=None gives true division, "trunc" rounds toward zero,
    "floor" rounds toward negative infinity; anything else raises ValueError.
    """
    logger.debug("GEMS DIV_MODE")
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)
def div_mode_(A, B, rounding_mode=None):
    """In-place variant of div_mode: the result is written back into A."""
    logger.debug("GEMS DIV_MODE_")
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)
@triton.jit
def _remainder(x, y):
    # Modulo with the sign of the divisor (Python/torch semantics):
    # when the raw remainder is nonzero and the operands' signs differ,
    # shift it by y — this converts a truncation-style remainder into a
    # floored one.
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    return tl.where(c1 & c2, r + y, r)
# Remainder for two tensor operands (see _remainder for the semantics).
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    return _remainder(x, y)
# Remainder where x is a tensor and y is a Python scalar
# (is_tensor=[True, False]).
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    return _remainder(x, y)
# Remainder where x is a Python scalar and y is a tensor
# (is_tensor=[False, True]).
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    return _remainder(x, y)
def remainder(A, B):
    """Element-wise remainder A % B with the sign of the divisor.

    Dispatches to the tensor/tensor, tensor/scalar, or scalar/tensor kernel
    depending on which operands are torch tensors; two Python scalars use
    Python's % operator.
    """
    logger.debug("GEMS REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return rem_tt(A, B)
    if a_is_tensor:
        return rem_ts(A, B)
    if b_is_tensor:
        return rem_st(A, B)
    # Neither operand is a tensor: plain Python modulo.
    return torch.tensor(A % B)
def remainder_(A, B):
    """In-place remainder: A %= B, written back into A."""
    logger.debug("GEMS REMAINDER_")
    if not isinstance(B, torch.Tensor):
        # B is a Python scalar.
        return rem_ts(A, B, out0=A)
    return rem_tt(A, B, out0=A)