# Coverage report header (coverage.py v7.6.9, created at 2026-03-07 22:33 +0800):
# src/flag_gems/runtime/backend/_iluvatar/ops/div.py: 0% of 166 statements covered.
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic, tl_extra_shim
# Module-level logger for the GEMS div kernels.
logger = logging.getLogger(__name__)

# Hardware-specific math shims for the Iluvatar backend.  Names follow IEEE
# rounding suffixes: div_rn presumably divides rounding to nearest and div_rz
# rounding toward zero — TODO confirm against the tl_extra_shim implementation.
# fmod is the floating-point remainder whose sign follows the dividend (see the
# NOTE inside _float_floordiv below); trunc truncates toward zero.
div_rn = tl_extra_shim.div_rn
div_rz = tl_extra_shim.div_rz
fmod = tl_extra_shim.fmod
trunc = tl_extra_shim.trunc
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    # Elementwise true division; integer operands are promoted to float.
    quotient = x / y
    return quotient
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    # tensor / scalar true division; integer operands are promoted to float.
    quotient = x / y
    return quotient
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    # scalar / tensor true division; integer operands are promoted to float.
    quotient = x / y
    return quotient
def true_divide(A, B):
    """Elementwise true division A / B; A and B may each be a tensor or a scalar."""
    logger.debug("GEMS TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: plain Python division, wrapped as a tensor.
    return torch.tensor(A / B)
def true_divide_(A, B):
    """In-place true division: writes A / B back into tensor A and returns it."""
    logger.debug("GEMS TRUE_DIVIDE_")
    if isinstance(B, torch.Tensor):
        return true_div_func(A, B, out0=A)
    return true_div_func_tensor_scalar(A, B, out0=A)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    # Elementwise division truncated toward zero (tensor/tensor variant).
    # Fix: use round-toward-zero division (div_rz) before truncating, for
    # consistency with the tensor/scalar variants below.  Plain `x / y` rounds
    # to nearest, which can round a quotient sitting just below an integer up
    # across the boundary and make trunc() one too large in magnitude.
    return trunc(div_rz(x, y))
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    # tensor / scalar division truncated toward zero; div_rz keeps the
    # intermediate quotient rounded toward zero before trunc.
    rz_quotient = div_rz(x, y)
    return trunc(rz_quotient)
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    # scalar / tensor division truncated toward zero; div_rz keeps the
    # intermediate quotient rounded toward zero before trunc.
    rz_quotient = div_rz(x, y)
    return trunc(rz_quotient)
def trunc_divide(A, B):
    """Elementwise division rounded toward zero (torch.div rounding_mode='trunc').

    A and B may each be a tensor or a scalar.
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return trunc_div_func(A, B)
    elif isinstance(A, torch.Tensor):
        return trunc_div_func_tensor_scalar(A, B)
    elif isinstance(B, torch.Tensor):
        return trunc_div_func_scalar_tensor(A, B)
    else:
        # Both scalar.  Fix: the original returned torch.tensor(A / B), i.e.
        # true division (7, 2 -> 3.5), which contradicts this function's
        # contract.  Delegate to torch.div with rounding_mode="trunc" so the
        # result and dtype match the tensor paths.
        return torch.div(torch.tensor(A), torch.tensor(B), rounding_mode="trunc")
def trunc_divide_(A, B):
    """In-place truncated division: writes trunc(A / B) back into tensor A."""
    logger.debug("GEMS TRUNC_DIVIDE_")
    if isinstance(B, torch.Tensor):
        return trunc_div_func(A, B, out0=A)
    return trunc_div_func_tensor_scalar(A, B, out0=A)
@triton.jit
def _int_floordiv(x, y):
    # Integer floor division with Python/NumPy/PyTorch semantics.
    #
    # Triton's `//` truncates toward zero (C semantics), i.e. it equates to
    #   (x - fmod(x, y)) / y
    # whereas Python/NumPy/PyTorch floor toward -inf, i.e.
    #   (x - remainder(x, y)) / y
    # The two differ by exactly one when
    #   C1) the division is inexact (x is not a multiple of y), and
    #   C2) x and y have opposite signs,
    # so we subtract 1 from the truncated quotient in that case.
    #
    # NOTE(review): the original comment also described an x // 0 special case
    # ("Triton returns -1, whereas PyTorch returns -1 if x >= 0 and -2 if
    # x < 0") being coalesced into the C1/C2 check; division by zero here
    # relies on backend-defined behavior of % and // — confirm if it matters.
    # TODO: request Triton to add an integer remainder builtin
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    return tl.where(c1 & c2, x // y - 1, x // y)
# To be consistent with Python, NumPy and PyTorch, we have to implement it in
# the following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # Floor division for floats, matching CPython/NumPy/PyTorch semantics
    # (see the reference links above this function).
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # Quotient of the exactly-divisible part.  When the division is inexact
    # and the operand signs differ, truncation went toward zero, so step down
    # by one to floor toward -inf.
    # NOTE: we have to use div_rn explicitly here
    q = div_rn(x - remainder, y)
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to an integral float: take floor(q), then round up when the
    # fractional part exceeds 0.5 (guards against q landing just below an
    # integer due to rounding in the division above).
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # A zero quotient carries the sign of the mathematical result:
    # -0.0 when the operand signs differ, +0.0 otherwise.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Division by zero falls back to IEEE float division (inf/-inf/nan).
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    # Dispatch to the integer or float floor-division helper by element type.
    # Fix: the original tested x's type twice (`x... & x...`); the intent is
    # to check both operands.  After DEFAULT promotion the operands share a
    # dtype, so the typo was latent rather than observable.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # Dispatch to the integer or float floor-division helper by element type.
    # Fix: the original tested x's type twice (`x... & x...`); the intent is
    # to check both operands.  After DEFAULT promotion the operands share a
    # dtype, so the typo was latent rather than observable.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # Dispatch to the integer or float floor-division helper by element type.
    # Fix: the original tested x's type twice (`x... & x...`); the intent is
    # to check both operands.  After DEFAULT promotion the operands share a
    # dtype, so the typo was latent rather than observable.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Elementwise floor division (rounds toward -inf); tensor or scalar operands."""
    logger.debug("GEMS FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return floor_div_func(A, B)
    if a_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: Python's // already floors.
    return torch.tensor(A // B)
def floor_divide_(A, B):
    """In-place floor division: writes A // B back into tensor A and returns it."""
    logger.debug("GEMS FLOOR_DIVIDE_")
    if isinstance(B, torch.Tensor):
        return floor_div_func(A, B, out0=A)
    return floor_div_func_tensor_scalar(A, B, out0=A)
def div_mode(A, B, rounding_mode=None):
    """Dispatch torch.div-style division by rounding_mode (None, 'trunc', or 'floor')."""
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)
def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of div_mode: dispatch by rounding_mode (None/'trunc'/'floor')."""
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)
@triton.jit
def _remainder(x, y):
    # Python-style modulo: when the truncating remainder is nonzero and the
    # operand signs differ, shift it by y so the result's sign follows the
    # divisor (matching Python/NumPy/PyTorch remainder semantics).
    raw = x % y
    needs_shift = (raw != 0) & ((x < 0) ^ (y < 0))
    return tl.where(needs_shift, raw + y, raw)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    # tensor % tensor with Python sign semantics.
    result = _remainder(x, y)
    return result
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    # tensor % scalar with Python sign semantics.
    result = _remainder(x, y)
    return result
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    # scalar % tensor with Python sign semantics.
    result = _remainder(x, y)
    return result
def remainder(A, B):
    """Elementwise Python-style remainder (result's sign follows the divisor).

    A and B may each be a tensor or a scalar.
    """
    # Fix: the original logged "GEMS FLOOR_DIVIDE" here (copy-paste error).
    logger.debug("GEMS REMAINDER")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return rem_tt(A, B)
    elif isinstance(A, torch.Tensor):
        return rem_ts(A, B)
    elif isinstance(B, torch.Tensor):
        return rem_st(A, B)
    else:
        # Both scalar: Python's % already uses divisor-sign semantics.
        return torch.tensor(A % B)
def remainder_(A, B):
    """In-place remainder: writes A % B back into tensor A and returns it."""
    logger.debug("GEMS REMAINDER_")
    if isinstance(B, torch.Tensor):
        return rem_tt(A, B, out0=A)
    return rem_ts(A, B, out0=A)