Coverage for src/flag_gems/runtime/backend/_cambricon/ops/div.py: 0%
182 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import tl_extra_shim
9from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger, nested under the shared "flag_gems" logger; lstrip(".")
# drops any leading dot that a relative __name__ would carry.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Division/rounding primitives supplied by the backend shim layer.
# Names follow CUDA-style rounding suffixes: _rn presumably rounds to
# nearest and _rz toward zero; fmod keeps the dividend's sign and trunc
# drops the fractional part — TODO confirm against the shim implementation.
# NOTE(review): div_rz is bound but not used anywhere in this chunk; it may
# be consumed by code outside this view.
div_rn = tl_extra_shim.div_rn
div_rz = tl_extra_shim.div_rz
fmod = tl_extra_shim.fmod
trunc = tl_extra_shim.trunc
# Elementwise true division of two tensors. The INT_TO_FLOAT promotion rule
# passed to pointwise_dynamic makes integer inputs produce a floating result.
@pointwise_dynamic(
    is_tensor=[True, True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]
)
@triton.jit
def true_div_func(x, y, inplace):
    # `inplace` is a non-tensor flag threaded through by the Python-level
    # dispatchers; it is not read inside the kernel body.
    return x / y
# True division of a tensor by a non-tensor scalar.
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]
)
@triton.jit
def true_div_func_tensor_scalar(x, y, inplace):
    # Cast the scalar to the tensor's dtype before dividing so the result
    # dtype is driven by x. `inplace` is an unused pass-through flag.
    y = y.to(x.dtype)
    return x / y
# True division of a non-tensor scalar by a tensor.
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]
)
@triton.jit
def true_div_func_scalar_tensor(x, y, inplace):
    # Cast the scalar to the tensor's dtype before dividing so the result
    # dtype is driven by y. `inplace` is an unused pass-through flag.
    x = x.to(y.dtype)
    return x / y
def true_divide(A, B):
    """Elementwise true (floating) division with tensor/scalar dispatch.

    Selects the kernel variant matching which of A and B are tensors;
    when neither is a tensor, falls back to Python division wrapped in a
    0-d tensor.
    """
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B, False)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B, False)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B, False)
    # Neither operand is a tensor.
    return torch.tensor(A / B)
def true_divide_out(A, B, out):
    """True division with an explicit output tensor (torch.div.out form).

    Dispatches on which operands are tensors and writes into `out` via the
    kernel's out0 argument; for two Python scalars the quotient is computed
    on the host and either wrapped in a new tensor or filled into `out`.
    """
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE OUT")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B, False, out0=out)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B, False, out0=out)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B, False, out0=out)
    # Scalar / scalar: compute once, then wrap or fill depending on `out`.
    quotient = A / B
    return torch.tensor(quotient) if out is None else out.fill_(quotient)
def true_divide_(A, B):
    """In-place true division: the quotient is written back into tensor A."""
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE_")
    kernel = (
        true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    )
    return kernel(A, B, True, out0=A)
# Tensor-tensor division rounded toward zero: divide with round-to-nearest
# (div_rn), then truncate the fractional part.
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y, inplace):
    # `inplace` is an unused pass-through flag.
    return trunc(div_rn(x, y))
# Truncating division of a tensor by a non-tensor scalar.
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_func_tensor_scalar(x, y, inplace):
    # `inplace` is an unused pass-through flag.
    return trunc(div_rn(x, y))
# Truncating division of a non-tensor scalar by a tensor.
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_func_scalar_tensor(x, y, inplace):
    # `inplace` is an unused pass-through flag.
    return trunc(div_rn(x, y))
def trunc_divide(A, B):
    """Division rounded toward zero (torch.div with rounding_mode='trunc')."""
    logger.debug("GEMS_CAMBRICON TRUNC_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return trunc_div_func(A, B, False)
    if a_is_tensor:
        return trunc_div_func_tensor_scalar(A, B, False)
    if b_is_tensor:
        return trunc_div_func_scalar_tensor(A, B, False)
    # NOTE(review): the scalar/scalar fallback returns the TRUE quotient
    # (A / B), not a truncated one — kept as-is to preserve the original
    # behavior; this path is likely unreachable through torch dispatch.
    return torch.tensor(A / B)
def trunc_divide_(A, B):
    """In-place truncating division; the result is written into tensor A."""
    logger.debug("GEMS_CAMBRICON TRUNC_DIVIDE_")
    if isinstance(B, torch.Tensor):
        kernel = trunc_div_func
    else:
        kernel = trunc_div_func_tensor_scalar
    return kernel(A, B, True, out0=A)
@triton.jit
def _int_floordiv(x, y):
    # Integer floor division with PyTorch/NumPy semantics.
    # TODO: request Triton to add an integer remainder builtin.
    #
    # Triton's integer `//` truncates toward zero, i.e. it computes
    #     (x - fmod(x, y)) / y
    # whereas PyTorch/NumPy floor division computes
    #     (x - remainder(x, y)) / y
    # The two differ by exactly one when
    #   C1) the remainder is non-zero, and
    #   C2) x and y have opposite signs,
    # so subtracting c = c1 & c2 converts truncation into flooring.
    #
    # Division by zero is a further discrepancy: per the original author's
    # note, Triton's x // 0 yields -1 while PyTorch yields -1 for x >= 0
    # and -2 for x < 0 — TODO confirm on the target hardware. The c3 term
    # below compensates for the negative-x case, since x < 0 with y == 0
    # is not fully captured by the C1/C2 correction.
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    c3 = (x < 0) & (y == 0)
    c = c1 & c2
    if x.dtype == tl.int16 and y.dtype == tl.int16:
        # int16 operands are widened to int32 for the division and the
        # quotient narrowed back — presumably working around an int16
        # division issue on this backend (TODO confirm).
        return (x.to(tl.int32) // y.to(tl.int32)).cast(tl.int16) - c - c3
    return x // y - c - c3
# To be consistent with Python, NumPy and PyTorch, floor division has to be
# implemented the following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # Floating-point floor division matching CPython/NumPy/PyTorch: compute
    # q = (x - fmod(x, y)) / y, subtract one when the division is inexact
    # and the operands have opposite signs, then snap q to an integral
    # value and patch up signed zeros and division by zero.
    # NOTE: fmod's sign is the same as the dividend.
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # NOTE: we have to use div_rn explicitly here.
    q = div_rn(x - remainder, y)
    q = tl.where(imperfect & different_sign, q - 1, q)

    # q should already be integral; the floor + half-step correction guards
    # against it landing just below an integer due to rounding error.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # Preserve IEEE signed zero: a zero quotient becomes -0.0 when the
    # operands' signs differ.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # For y == 0.0, fall back to plain division so the result is the IEEE
    # inf/nan that plain x / y produces.
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out
# Floor division of two tensors. The dtype test routes integer inputs to the
# exact integer algorithm and everything else to the float algorithm;
# presumably the branch is resolved at Triton compile time (TODO confirm).
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y, inplace):
    # `inplace` is an unused pass-through flag.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
# Floor division of a tensor by a non-tensor scalar; same int/float routing
# as floor_div_func.
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def floor_div_func_tensor_scalar(x, y, inplace):
    # `inplace` is an unused pass-through flag.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
# Floor division of a non-tensor scalar by a tensor; same int/float routing
# as floor_div_func.
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def floor_div_func_scalar_tensor(x, y, inplace):
    # `inplace` is an unused pass-through flag.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Floor division (torch.div with rounding_mode='floor', the // op)."""
    logger.debug("GEMS_CAMBRICON FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return floor_div_func(A, B, False)
    if a_is_tensor:
        return floor_div_func_tensor_scalar(A, B, False)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B, False)
    # Both operands are Python scalars; Python's // already floors.
    return torch.tensor(A // B)
def floor_divide_(A, B):
    """In-place floor division; the result is written into tensor A."""
    logger.debug("GEMS_CAMBRICON FLOOR_DIVIDE_")
    kernel = (
        floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    )
    return kernel(A, B, True, out0=A)
def div_mode(A, B, rounding_mode=None):
    """torch.div with an optional rounding mode.

    rounding_mode=None performs true division, 'trunc' rounds toward zero,
    and 'floor' rounds toward negative infinity. Any other value raises
    ValueError with torch's wording.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
def div_mode_(A, B, rounding_mode=None):
    """In-place variant of div_mode; the result is written into tensor A.

    rounding_mode=None performs true division, 'trunc' rounds toward zero,
    and 'floor' rounds toward negative infinity. Any other value raises
    ValueError with torch's wording.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
@triton.jit
def _remainder(x, y):
    # Python/PyTorch-style remainder: Triton's `%` appears to follow the
    # dividend's sign, so when the raw remainder is non-zero and the
    # operands have opposite signs, shift it by y to give the result the
    # divisor's sign.
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    return tl.where(c1 & c2, r + y, r)
# Remainder of two tensors.
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y, inplace):
    # `inplace` is an unused pass-through flag.
    return _remainder(x, y)
# Remainder of a tensor by a non-tensor scalar divisor.
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def rem_ts(x, y, inplace):
    # `inplace` is an unused pass-through flag.
    return _remainder(x, y)
# Remainder of a non-tensor scalar dividend by a tensor divisor.
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def rem_st(x, y, inplace):
    # `inplace` is an unused pass-through flag.
    return _remainder(x, y)
def remainder(A, B):
    """Python-style remainder (the result takes the divisor's sign)."""
    logger.debug("GEMS_CAMBRICON REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return rem_tt(A, B, False)
    if a_is_tensor:
        return rem_ts(A, B, False)
    if b_is_tensor:
        return rem_st(A, B, False)
    # Pure-scalar fallback: Python's % already has divisor-sign semantics.
    return torch.tensor(A % B)
def remainder_(A, B):
    """In-place remainder: A %= B, result written back into tensor A."""
    logger.debug("GEMS_CAMBRICON REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, True, out0=A)