Coverage for src/flag_gems/runtime/backend/_hygon/ops/div.py: 0%
178 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic, tl_extra_shim
9logger = logging.getLogger(__name__)
10fmod = tl_extra_shim.fmod
11trunc = tl_extra_shim.trunc
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    # Elementwise true division of two tensors.  INT_TO_FLOAT promotion
    # ensures integer inputs produce a floating-point result.
    quotient = x / y
    return quotient
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    # Elementwise true division: tensor x divided by scalar y.
    result = x / y
    return result
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    # Elementwise true division: scalar x divided by tensor y.
    result = x / y
    return result
def true_divide(A, B):
    """Compute ``A / B`` (true division) for tensor and/or scalar operands.

    Dispatches to the kernel variant matching which operands are tensors;
    two plain Python scalars are divided on the host and wrapped in a tensor.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: plain Python true division.
    return torch.tensor(A / B)
def true_divide_(A, B):
    """In-place true division ``A /= B``; A must be a tensor.

    Writes the quotient back into A via ``out0`` and returns A.
    """
    logger.debug("GEMS TRUE_DIVIDE_")
    if isinstance(B, torch.Tensor):
        return true_div_func(A, B, out0=A)
    return true_div_func_tensor_scalar(A, B, out0=A)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    # Divide in float64, then truncate toward zero.  The widening cast
    # presumably guards against precision loss for large inputs — TODO
    # confirm why only this tensor-tensor variant widens and the
    # tensor/scalar variants below do not.
    wide_x = x.to(tl.float64)
    wide_y = y.to(tl.float64)
    return trunc(wide_x / wide_y)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    # tensor / scalar, rounded toward zero.
    quotient = x / y
    return trunc(quotient)
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    # scalar / tensor, rounded toward zero.
    quotient = x / y
    return trunc(quotient)
def trunc_divide(A, B):
    """Compute ``A / B`` rounded toward zero (torch.div rounding_mode="trunc").

    Dispatches to the kernel variant matching which operands are tensors;
    two plain Python scalars are handled on the host.
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return trunc_div_func(A, B)
    elif isinstance(A, torch.Tensor):
        return trunc_div_func_tensor_scalar(A, B)
    elif isinstance(B, torch.Tensor):
        return trunc_div_func_scalar_tensor(A, B)
    else:
        # Both scalar.  Fix: the previous code returned the untruncated
        # true quotient (torch.tensor(A / B)), diverging from the kernel
        # paths and from torch.div(..., rounding_mode="trunc").
        return torch.trunc(torch.tensor(A / B))
def trunc_divide_(A, B):
    """In-place truncating division of tensor A by B, returning A."""
    logger.debug("GEMS TRUNC_DIVIDE_")
    if isinstance(B, torch.Tensor):
        return trunc_div_func(A, B, out0=A)
    return trunc_div_func_tensor_scalar(A, B, out0=A)
@triton.jit
def _int_floordiv(x, y):
    # Integer floor division with PyTorch/NumPy semantics.
    #
    # Triton's `//` is C-style truncated division, equivalent to
    #     (x - fmod(x, y)) / y
    # whereas PyTorch/NumPy floor division is
    #     (x - remainder(x, y)) / y
    # The two differ by exactly one when
    #   C1) the remainder is nonzero, and
    #   C2) x and y have opposite signs,
    # so we subtract 1 from the truncated quotient in that case.
    #
    # NOTE(review): the original comment also claimed x // 0 behaves
    # differently here than in PyTorch (-1 vs -1/-2 depending on sign of
    # x) and that the C1/C2 correction coalesces that case; this is not
    # verifiable from the code alone — confirm against torch on this
    # backend.
    # TODO: request Triton to add an integer remainder builtin.
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    return tl.where(c1 & c2, x // y - 1, x // y)
# To be consistent with Python, NumPy, and PyTorch, we have to implement it
# in the following way.
117# CPython
118# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
119# numpy
120# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
121# torch
122# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # Float floor division mirroring CPython/NumPy/PyTorch: compute
    # q = (x - fmod(x, y)) / y, subtract one when the operands have
    # opposite signs and the remainder is nonzero, snap q to an integer,
    # then patch the signed-zero and division-by-zero special cases.
    # NOTE: fmod's sign is the same as the dividend.
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)
    # NOTE: we have to use div_rn explicitly here
    # q = div_rn(x - remainder, y)
    q = (x - remainder) / y
    q = tl.where(imperfect & different_sign, q - 1, q)
    # Round q down; if flooring moved us more than half a unit below q,
    # q was already (numerically) the next integer, so bump back up.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)
    # A zero quotient carries the sign implied by the operands:
    # -0.0 when signs differ, +0.0 otherwise (matches CPython).
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)
    # x // 0.0 falls back to true division (inf / nan), as torch does.
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    # Floor division dispatch: integer dtypes use the corrected integer
    # path, everything else the float path.
    #
    # Fix: the original condition was
    #     x.type.scalar.is_int() & x.type.scalar.is_int()
    # — a duplicated operand, likely a copy-paste of an intended check on
    # y.  Since `a & a == a`, and DEFAULT promotion gives both operands a
    # common dtype, the single test below is behaviorally identical.
    if x.type.scalar.is_int():
        # Widen 16-bit integers before dividing — presumably 16-bit
        # integer division is unreliable on this backend; TODO confirm.
        if x.type.scalar.is_int16():
            return _int_floordiv(x.to(tl.int32), y)
        elif x.type.scalar.is_uint16():
            return _int_floordiv(x.to(tl.uint32), y)
        else:
            return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # Floor division of tensor x by scalar y.
    #
    # Fix: the original condition was
    #     x.type.scalar.is_int() & x.type.scalar.is_int()
    # — a duplicated operand (copy-paste).  `a & a == a`, so the single
    # test below is behaviorally identical.
    if x.type.scalar.is_int():
        # Widen 16-bit integers before dividing — presumably 16-bit
        # integer division is unreliable on this backend; TODO confirm.
        if x.type.scalar.is_int16():
            return _int_floordiv(x.to(tl.int32), y)
        elif x.type.scalar.is_uint16():
            return _int_floordiv(x.to(tl.uint32), y)
        else:
            return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # Floor division of scalar x by tensor y.
    #
    # Fix: the original condition was
    #     x.type.scalar.is_int() & x.type.scalar.is_int()
    # — a duplicated operand; `a & a == a`, so the single test below is
    # behaviorally identical.
    # NOTE(review): here x is the *scalar* operand — dispatching on its
    # type (rather than tensor y's) may itself be part of the original
    # copy-paste; behavior is preserved as-is, but confirm the intent.
    if x.type.scalar.is_int():
        if x.type.scalar.is_int16():
            return _int_floordiv(x.to(tl.int32), y)
        elif x.type.scalar.is_uint16():
            return _int_floordiv(x.to(tl.uint32), y)
        else:
            return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Compute ``A // B`` (floor division) for tensor and/or scalar operands.

    Dispatches to the kernel variant matching which operands are tensors;
    two plain Python scalars use Python's ``//``, which already floors.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return floor_div_func(A, B)
    if a_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor.
    return torch.tensor(A // B)
def floor_divide_(A, B):
    """In-place floor division ``A //= B``; A must be a tensor."""
    logger.debug("GEMS FLOOR_DIVIDE_")
    if isinstance(B, torch.Tensor):
        return floor_div_func(A, B, out0=A)
    return floor_div_func_tensor_scalar(A, B, out0=A)
def div_mode(A, B, rounding_mode=None):
    """Dispatch division by rounding mode, mirroring ``torch.div``.

    rounding_mode None -> true division, "trunc" -> round toward zero,
    "floor" -> round toward negative infinity.  Raises ValueError for any
    other value.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)
def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of :func:`div_mode`; writes into tensor A.

    Raises ValueError for a rounding_mode other than None, "trunc", or
    "floor".
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)
@triton.jit
def _remainder(x, y):
    # Python/PyTorch remainder: the result takes the divisor's sign.
    # Start from the hardware remainder (which takes the dividend's sign)
    # and shift it by y when the operands disagree in sign and the
    # remainder is nonzero.
    raw = x % y
    needs_shift = (raw != 0) & ((x < 0) ^ (y < 0))
    return tl.where(needs_shift, raw + y, raw)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    # tensor % tensor with Python remainder semantics.
    out = _remainder(x, y)
    return out
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    # tensor % scalar with Python remainder semantics.
    out = _remainder(x, y)
    return out
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    # scalar % tensor with Python remainder semantics.
    out = _remainder(x, y)
    return out
def remainder(A, B):
    """Compute ``A % B`` with Python remainder semantics (sign of divisor).

    Dispatches to the kernel variant matching which operands are tensors;
    two plain Python scalars use Python's ``%`` on the host.
    """
    # Fix: the debug message previously said "GEMS FLOOR_DIVIDE" — a
    # copy-paste from floor_divide that mislabels this op in logs.
    logger.debug("GEMS REMAINDER")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return rem_tt(A, B)
    elif isinstance(A, torch.Tensor):
        return rem_ts(A, B)
    elif isinstance(B, torch.Tensor):
        return rem_st(A, B)
    else:
        # Both scalar
        return torch.tensor(A % B)
def remainder_(A, B):
    """In-place remainder ``A %= B``; A must be a tensor."""
    logger.debug("GEMS REMAINDER_")
    if isinstance(B, torch.Tensor):
        return rem_tt(A, B, out0=A)
    return rem_ts(A, B, out0=A)