Coverage for src/flag_gems/runtime/backend/_sunrise/ops/div.py: 0%
367 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
8from flag_gems.utils.pointwise_dynamic import CodeGenConfig, ComplexMode
9from flag_gems.utils.triton_lang_extension import div_rn, div_rz, fmod, trunc
# Module-level logger; every public entry point below emits a debug line naming its op.
logger = logging.getLogger(__name__)
@pointwise_dynamic(
    is_tensor=[True, True, True, True],
    num_outputs=2,
    promotion_methods=[
        (0, 1, 2, 3, "INT_TO_FLOAT"),
        (0, 1, 2, 3, "INT_TO_FLOAT"),
    ],
)
@triton.jit
def div_complex_kernel(ar, ai, br, bi):
    """Elementwise complex division ``(ar + i*ai) / (br + i*bi)``.

    The two operands arrive as four real-valued lanes (real and imaginary
    parts of the dividend and divisor). The kernel returns two outputs,
    ``(real, imag)``, which are the lanes of the quotient.
    """
    # Smith's method: avoid overflow by dividing by the larger component
    abs_br = tl.abs(br)
    abs_bi = tl.abs(bi)
    use_br = abs_br >= abs_bi

    # When |br| >= |bi|: ratio = bi/br, denom = br + bi*ratio
    # The br == 0 guard matters only when br == bi == 0 (the branch condition
    # forces |bi| <= 0). The ratio is then 0, so denom1 == 0 and the final
    # divisions produce inf/nan.
    ratio1 = tl.where(br == 0, 0.0, bi / br)
    denom1 = br + bi * ratio1
    real1 = (ar + ai * ratio1) / denom1
    imag1 = (ai - ar * ratio1) / denom1

    # When |bi| > |br|: ratio = br/bi, denom = bi + br*ratio
    ratio2 = tl.where(bi == 0, 0.0, br / bi)
    denom2 = bi + br * ratio2
    real2 = (ar * ratio2 + ai) / denom2
    imag2 = (ai * ratio2 - ar) / denom2

    # Both branches are evaluated for every element; select per element.
    real = tl.where(use_br, real1, real2)
    imag = tl.where(use_br, imag1, imag2)
    return real, imag
# Launch-grid ceiling handed to the pointwise code generator (presumably one
# limit per grid axis).
MAX_GRID_SIZES = (65535,) * 3

# Shared code-generation settings for most pointwise kernels in this module.
config = CodeGenConfig(
    prefer_1d_tile=True,
    prefer_block_pointer=True,
    max_num_warps_per_cta=32,
    max_grid_size=MAX_GRID_SIZES,
    max_tile_size=1024,
)
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")], config=config)
@triton.jit
def true_div_func(x, y):
    """Elementwise true division ``tensor / tensor`` (INT_TO_FLOAT promotion)."""
    quotient = x / y
    return quotient
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")], config=config
)
@triton.jit
def true_div_func_tensor_scalar(x, y):
    """Elementwise true division ``tensor / scalar`` (INT_TO_FLOAT promotion)."""
    quotient = x / y
    return quotient
@pointwise_dynamic(
    is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")], config=config
)
@triton.jit
def true_div_func_scalar_tensor(x, y):
    """Elementwise true division ``scalar / tensor`` (INT_TO_FLOAT promotion)."""
    quotient = x / y
    return quotient
# Register complex support.
# The tensor-tensor variant uses div_complex_kernel as its cross kernel.
# Both scalar variants are registered with tensorize_scalars=True and
# fall back to the tensor-tensor variant.
true_div_func.register_complex(mode=ComplexMode.CROSS, cross_kernel=div_complex_kernel)
for _scalar_variant in (true_div_func_tensor_scalar, true_div_func_scalar_tensor):
    _scalar_variant.register_complex(
        mode=ComplexMode.CROSS, tensorize_scalars=True, fallback_target=true_div_func
    )
del _scalar_variant
88# [sunrise fix]
89def _view_as_real_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
90 """`torch.view_as_real(x)` with a CPU bounce when x is on PTPU.
92 [sunrise fix] PTPU lacks `aten::view_as_real`. For complex div we only need
93 a transient read-only decomposition into real/imag lanes before launching
94 the PTPU-native `div_complex_kernel`, so breaking alias/view semantics here
95 is acceptable. Keep the fallback local to this op instead of monkey-patching
96 the aliasing primitive globally.
97 """
98 try:
99 return torch.view_as_real(x)
100 except NotImplementedError:
101 if x.device.type != "ptpu":
102 raise
103 return torch.view_as_real(x.cpu()).to(x.device)
106# [sunrise fix]
107def _view_as_complex_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
108 """`torch.view_as_complex(x)` with a CPU bounce when x is on PTPU."""
109 try:
110 return torch.view_as_complex(x)
111 except NotImplementedError:
112 if x.device.type != "ptpu":
113 raise
114 return torch.view_as_complex(x.cpu()).to(x.device)
117# [sunrise fix]
118def _scalar_complex_as_real_ptpu_safe(
119 scalar, complex_dtype: torch.dtype, target_shape, device: torch.device
120) -> torch.Tensor:
121 """Broadcast a python scalar to a `view_as_real`-shaped tensor on `device`."""
122 cpu_scalar = torch.tensor(scalar, dtype=complex_dtype, device="cpu").expand(
123 target_shape
124 )
125 cpu_real = torch.view_as_real(cpu_scalar).contiguous()
126 if device.type == "cpu":
127 return cpu_real
128 return cpu_real.to(device)
# [sunrise fix]
def _operand_as_real_ptpu_safe(
    value, complex_dtype: torch.dtype, target_shape, device: torch.device
) -> torch.Tensor:
    """Turn one complex-div operand into its real/imag lane tensor.

    Python scalars are broadcast to ``target_shape`` on ``device``. Real
    tensors are first converted to ``complex_dtype``. Complex tensors keep
    their own dtype and shape.
    """
    if not isinstance(value, torch.Tensor):
        return _scalar_complex_as_real_ptpu_safe(value, complex_dtype, target_shape, device)
    as_complex = value if value.is_complex() else value.to(complex_dtype)
    return _view_as_real_ptpu_safe(as_complex)
141# [sunrise fix]
142def _to_cpu_complex_div_reference_operand(value):
143 if not isinstance(value, torch.Tensor):
144 return value
146 cpu_value = value.cpu()
147 if cpu_value.is_complex():
148 if cpu_value.dtype == torch.complex32:
149 return cpu_value.to(torch.complex64)
150 return cpu_value
151 return cpu_value.to(torch.float32)
# [sunrise fix]
def _complex_div_cpu_fallback(A, B):
    """Compute complex ``A / B`` on the CPU and move the quotient back.

    [sunrise fix] At zero divisors, the CPU tensor kernel and the PTPU
    cross-kernel path disagree in a few large-tensor cases (``nan+nanj`` vs
    ``inf``). The tests compare against CPU ``torch.div`` on upcast reference
    inputs. For that narrow corner this helper therefore reproduces the
    reference directly, instead of re-encoding the CPU kernel's zero-divisor
    behavior in Triton.

    The result goes back to ``A``'s device when ``A`` is a tensor, otherwise
    to ``B``'s device.
    """
    quotient = torch.div(
        _to_cpu_complex_div_reference_operand(A),
        _to_cpu_complex_div_reference_operand(B),
    )
    if not isinstance(quotient, torch.Tensor):
        return quotient
    home = A if isinstance(A, torch.Tensor) else B
    return quotient.to(home.device)
175# [sunrise fix]
176def _tensor_has_zero_divisor(x: torch.Tensor) -> bool:
177 if x.is_complex():
178 return bool(torch.any((x.cpu().real == 0) & (x.cpu().imag == 0)).item())
179 return bool(torch.any(x == 0).item())
182# [sunrise fix]
183def _should_cpu_fallback_complex_div(A, B) -> bool:
184 if not isinstance(B, torch.Tensor):
185 return False
186 if B.device.type != "ptpu":
187 return False
188 if not _tensor_has_zero_divisor(B):
189 return False
190 return True
# [sunrise fix]
def _complex_true_divide(A, B):
    """Complex true division where at least one operand is complex.

    If ``B`` is a PTPU tensor that contains a zero divisor, the quotient is
    computed on the CPU via ``_complex_div_cpu_fallback``. Otherwise:

    1. Both operands are split into real/imag lanes.
    2. The lanes are promoted to a common real dtype.
    3. ``div_complex_kernel`` computes the quotient lanes.
    4. The output lanes are re-packed into a complex tensor of
       ``torch.result_type(A, B)``.
    """
    result_dtype = torch.result_type(A, B)
    if _should_cpu_fallback_complex_div(A, B):
        return _complex_div_cpu_fallback(A, B).to(result_dtype)

    def _shape_of(operand):
        return operand.shape if isinstance(operand, torch.Tensor) else torch.Size([])

    target_shape = torch.broadcast_shapes(_shape_of(A), _shape_of(B))
    anchor = A if isinstance(A, torch.Tensor) else B

    lhs_lanes = _operand_as_real_ptpu_safe(A, result_dtype, target_shape, anchor.device)
    rhs_lanes = _operand_as_real_ptpu_safe(B, result_dtype, target_shape, anchor.device)

    lane_dtype = torch.promote_types(lhs_lanes.dtype, rhs_lanes.dtype)
    ar, ai, br, bi = (
        lane.to(lane_dtype)
        for lane in (
            lhs_lanes[..., 0],
            lhs_lanes[..., 1],
            rhs_lanes[..., 0],
            rhs_lanes[..., 1],
        )
    )

    real, imag = div_complex_kernel(ar, ai, br, bi)
    packed = torch.stack((real, imag), dim=-1).contiguous()
    return _view_as_complex_ptpu_safe(packed).to(result_dtype)
def true_divide(A, B):
    """Elementwise true division ``A / B`` for any mix of tensors and scalars.

    - Two Python scalars (real or complex) are divided in Python and wrapped
      in a tensor.
    - Any complex operand goes through ``_complex_true_divide``.
    - Otherwise the call is dispatched to the tensor/tensor, tensor/scalar or
      scalar/tensor pointwise kernel.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if not (lhs_is_tensor or rhs_is_tensor):
        # Both scalar
        return torch.tensor(A / B)

    lhs_complex = A.is_complex() if lhs_is_tensor else isinstance(A, complex)
    rhs_complex = B.is_complex() if rhs_is_tensor else isinstance(B, complex)
    if lhs_complex or rhs_complex:
        return _complex_true_divide(A, B)

    if lhs_is_tensor and rhs_is_tensor:
        return true_div_func(A, B)
    if lhs_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    return true_div_func_scalar_tensor(A, B)
def true_divide_out(A, B, out):
    """True division ``A / B`` written into ``out``.

    - Complex operands are divided with ``true_divide`` and copied into
      ``out``. The fresh quotient is returned when ``out`` is None.
    - Real operands are handed to the pointwise kernels with ``out0=out``.
    - Two real Python scalars fill ``out``, or produce a new tensor when
      ``out`` is None.
    """
    logger.debug("GEMS TRUE_DIVIDE OUT")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    # [sunrise fix]
    lhs_complex = A.is_complex() if lhs_is_tensor else isinstance(A, complex)
    rhs_complex = B.is_complex() if rhs_is_tensor else isinstance(B, complex)
    if lhs_complex or rhs_complex:
        quotient = true_divide(A, B)
        if out is not None:
            out.copy_(quotient)
            return out
        return quotient

    if lhs_is_tensor and rhs_is_tensor:
        return true_div_func(A, B, out0=out)
    if lhs_is_tensor:
        return true_div_func_tensor_scalar(A, B, out0=out)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B, out0=out)

    # Both scalar
    if out is None:
        return torch.tensor(A / B)
    return out.fill_(A / B)
def true_divide_(A, B):
    """In-place true division ``A /= B``; returns ``A``.

    Complex cases (complex ``A`` or complex ``B``) are computed with
    ``true_divide`` and copied back into ``A``. Real cases are written
    directly by the pointwise kernels via ``out0=A``.

    Raises:
        RuntimeError: if the promoted result dtype cannot be cast into
            ``A.dtype``, e.g. a real ``A`` divided by a complex ``B``. This
            matches PyTorch's in-place rule instead of silently discarding
            the imaginary part through ``copy_``.
    """
    logger.debug("GEMS TRUE_DIVIDE_")
    # [sunrise fix]
    A_is_complex = isinstance(A, torch.Tensor) and A.is_complex()
    B_is_complex = (isinstance(B, torch.Tensor) and B.is_complex()) or isinstance(
        B, complex
    )
    if A_is_complex or B_is_complex:
        result_dtype = torch.result_type(A, B)
        # copy_ would drop the imaginary part when A is real; reject like torch.
        if not torch.can_cast(result_dtype, A.dtype):
            raise RuntimeError(
                f"result type {result_dtype} can't be cast to the desired "
                f"output type {A.dtype}"
            )
        A.copy_(true_divide(A, B))
        return A
    if isinstance(B, torch.Tensor):
        return true_div_func(A, B, out0=A)
    return true_div_func_tensor_scalar(A, B, out0=A)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def trunc_div_func(x, y):
    """Truncating division of two tensors: ``div_rz`` followed by ``trunc``."""
    rz_quotient = div_rz(x, y)
    return trunc(rz_quotient)
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    """Truncating ``tensor / scalar``; the scalar is first cast to the tensor's dtype."""
    divisor = tl.cast(y, x.dtype)
    return trunc(div_rz(x, divisor))
@pointwise_dynamic(
    is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    """Truncating ``scalar / tensor``; the scalar is first cast to the tensor's dtype."""
    dividend = tl.cast(x, y.dtype)
    return trunc(div_rz(dividend, y))
# Integer truncation division: Triton's // on integers is C-style (truncates toward zero)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func(x, y):
    """Integer ``tensor // tensor`` rounded toward zero."""
    truncated = x // y
    return truncated
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func_tensor_scalar(x, y):
    """Integer ``tensor // scalar`` rounded toward zero."""
    truncated = x // y
    return truncated
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func_scalar_tensor(x, y):
    """Integer ``scalar // tensor`` rounded toward zero."""
    truncated = x // y
    return truncated
def trunc_divide(A, B):
    """Elementwise division rounded toward zero (``rounding_mode="trunc"``).

    Integral operands use the dedicated integer kernels, since Triton's
    integer ``//`` already truncates. Any pair involving a floating-point
    tensor uses the ``div_rz``-based float kernels. Two Python scalars are
    divided in Python.
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        # Checking a single operand is not enough. A float/int pair must not
        # be routed to the scalar-tensor kernel, which would put a tensor in
        # its scalar slot. An int/float pair must not reach the integer `//`
        # kernel either. Use the integer kernel only when BOTH are integral.
        if A.is_floating_point() or B.is_floating_point():
            return trunc_div_func(A, B)
        return trunc_div_int_func(A, B)

    if a_is_tensor:
        # The float tensor-scalar kernel casts the scalar to the tensor's
        # dtype, so routing stays keyed on the tensor operand here.
        if A.is_floating_point():
            return trunc_div_func_tensor_scalar(A, B)
        return trunc_div_int_func_tensor_scalar(A, B)

    if b_is_tensor:
        if B.is_floating_point():
            return trunc_div_func_scalar_tensor(A, B)
        return trunc_div_int_func_scalar_tensor(A, B)

    # Both scalar
    return torch.tensor(type(A)(int(A / B)))
def trunc_divide_(A, B):
    """In-place truncating division ``A = trunc(A / B)``, written via ``out0=A``.

    Non-floating ``A`` uses the integer kernels (Triton's integer ``//``
    truncates toward zero). Floating ``A`` uses the ``div_rz``-based kernels.
    """
    logger.debug("GEMS TRUNC_DIVIDE_")
    rhs_is_tensor = isinstance(B, torch.Tensor)
    if A.is_floating_point():
        kernel = trunc_div_func if rhs_is_tensor else trunc_div_func_tensor_scalar
    else:
        kernel = (
            trunc_div_int_func if rhs_is_tensor else trunc_div_int_func_tensor_scalar
        )
    return kernel(A, B, out0=A)
@triton.jit
def _int_floordiv(x, y):
    """Integer floor division (rounding toward negative infinity), PyTorch-style.

    Starts from Triton's truncating integer quotient. It subtracts one when
    the division is inexact and the operands have opposite signs. When both
    inputs are int16, the arithmetic runs in int32 and the result is cast
    back to int16.
    """
    # TODO: request Triton to add an integer remainder builtin
    # The semantic of Triton floordiv differs from Pytorch/Numpy
    # Triton floordiv equates to
    #     (x - np.fmod(x, y)) / y
    # whereas Pytorch floordiv is
    #     (x - np.remainder(x, y)) / y
    # The results show a one off difference when
    #     C1) x and y have opposite signs
    # and C2) x is not multiples of y.
    # Apart from the above, there's an erroneous case: x // 0 returns -1,
    # whereas in Pytorch x // 0 returns -1 if x >= 0 and -2 if x < 0.
    # This special case is coalesced into the c1 and c2 checks below, so
    # there's no extra handling.
    # [sunrise fix] On PTPU, lowering `%` in this kernel can clobber the RHS
    # input buffer for int32 floor_divide. Avoid `%` entirely and infer whether
    # there is a remainder from the truncating quotient:
    #     q = trunc(x / y)
    #     remainder exists iff q * y != x
    if x.dtype == tl.int16 and y.dtype == tl.int16:
        # NOTE(review): the reason for widening int16 to int32 is not stated
        # here; presumably another PTPU lowering workaround — confirm.
        x32 = x.to(tl.int32)
        y32 = y.to(tl.int32)
        q32 = x32 // y32
        c1 = (q32 * y32) != x32  # inexact: the truncated quotient does not reproduce x
        c2 = (x32 < 0) ^ (y32 < 0)  # operands have opposite signs
        return (q32 - (c1 & c2)).to(tl.int16)

    q = x // y
    c1 = (q * y) != x  # inexact: the truncated quotient does not reproduce x
    c2 = (x < 0) ^ (y < 0)  # operands have opposite signs
    return q - (c1 & c2)
# To be consistent with python, numpy and torch, we have to implement it in the
# following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    """Floating-point floor division with Python/NumPy/PyTorch semantics.

    Where ``y == 0`` it returns the plain quotient ``x / y`` (inf or nan).
    A zero result is signed: -0.0 when ``x`` and ``y`` have different signs,
    +0.0 otherwise. Every other element gets the floor of ``x / y``.
    """
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # NOTE: we have to use div_rn explicitly here
    # x - remainder is a multiple of y, so q is the truncated quotient.
    # Step down by one when the division is inexact and the signs differ.
    q = div_rn(x - remainder, y)
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to the nearest integer, with halfway values going down, as the
    # CPython reference above does.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # Give a zero quotient the sign implied by the operands.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Division by zero: return the plain quotient (inf or nan).
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_int_func(x, y):
    """Integer floor division ``tensor // tensor`` via ``_int_floordiv``."""
    floored = _int_floordiv(x, y)
    return floored
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_int_func_tensor_scalar(x, y):
    """Integer floor division ``tensor // scalar`` via ``_int_floordiv``."""
    floored = _int_floordiv(x, y)
    return floored
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_int_func_scalar_tensor(x, y):
    """Integer floor division ``scalar // tensor`` via ``_int_floordiv``."""
    floored = _int_floordiv(x, y)
    return floored
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def floor_div_func(x, y):
    """Elementwise floor division ``tensor // tensor``.

    Uses ``_int_floordiv`` when both element types are integers and
    ``_float_floordiv`` otherwise.
    """
    # NOTE(review): the condition depends only on element types, so it
    # presumably resolves while tracing and only one helper is emitted. Keep
    # it inline rather than binding it to a local, which Triton may turn into
    # a runtime value — confirm before refactoring.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    """Elementwise floor division ``tensor // scalar``.

    Uses ``_int_floordiv`` when both element types are integers and
    ``_float_floordiv`` otherwise.
    """
    # NOTE(review): type-only condition, presumably resolved while tracing;
    # keep it inline (see floor_div_func).
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(
    is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    """Elementwise floor division ``scalar // tensor``.

    Uses ``_int_floordiv`` when both element types are integers and
    ``_float_floordiv`` otherwise.
    """
    # NOTE(review): type-only condition, presumably resolved while tracing;
    # keep it inline (see floor_div_func).
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Elementwise floor division (``rounding_mode="floor"``).

    - Pairs whose promoted result dtype is not floating point use the
      dedicated integer kernels (built on ``_int_floordiv``).
    - All other pairs go through the general ``floor_div_func*`` kernels,
      which pick the integer or float helper from the element types.
    - Two Python scalars are divided in Python.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if not a_is_tensor and not b_is_tensor:
        # Both scalar
        return torch.tensor(A // B)

    # Route on the promoted dtype of BOTH operands, not on one operand alone.
    # A mixed int/float pair must not reach the integer-only kernels, since
    # _int_floordiv is written for integer operands. A tensor must also never
    # be passed into a kernel's scalar slot.
    integral = not torch.result_type(A, B).is_floating_point

    if a_is_tensor and b_is_tensor:
        if integral:
            return floor_div_int_func(A, B)
        return floor_div_func(A, B)
    if a_is_tensor:
        if integral:
            return floor_div_int_func_tensor_scalar(A, B)
        return floor_div_func_tensor_scalar(A, B)
    if integral:
        return floor_div_int_func_scalar_tensor(A, B)
    return floor_div_func_scalar_tensor(A, B)
def floor_divide_(A, B):
    """In-place floor division ``A //= B``, written via ``out0=A``.

    Non-floating ``A`` uses the integer kernels; floating ``A`` uses the
    general floor-division kernels.
    """
    logger.debug("GEMS FLOOR_DIVIDE_")
    rhs_is_tensor = isinstance(B, torch.Tensor)
    if A.is_floating_point():
        kernel = floor_div_func if rhs_is_tensor else floor_div_func_tensor_scalar
    else:
        kernel = (
            floor_div_int_func if rhs_is_tensor else floor_div_int_func_tensor_scalar
        )
    return kernel(A, B, out0=A)
def div_mode(A, B, rounding_mode=None):
    """Dispatch ``div(A, B, rounding_mode=...)`` to the matching implementation.

    Raises:
        ValueError: if ``rounding_mode`` is not None, ``"trunc"`` or ``"floor"``.
    """
    logger.debug("GEMS DIV_MODE")
    if rounding_mode is None:
        return true_divide(A, B)
    for mode_name, impl in (("trunc", trunc_divide), ("floor", floor_divide)):
        if rounding_mode == mode_name:
            return impl(A, B)
    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)
def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of ``div_mode``.

    Raises:
        ValueError: if ``rounding_mode`` is not None, ``"trunc"`` or ``"floor"``.
    """
    logger.debug("GEMS DIV_MODE_")
    if rounding_mode is None:
        return true_divide_(A, B)
    for mode_name, impl in (("trunc", trunc_divide_), ("floor", floor_divide_)):
        if rounding_mode == mode_name:
            return impl(A, B)
    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)
@triton.jit
def _remainder(x, y):
    """Python-style remainder built on Triton's ``%``.

    Triton's ``%`` result is presumably C-style (it takes the sign of the
    dividend). When it is nonzero and ``x`` and ``y`` have opposite signs,
    ``y`` is added so the result follows the divisor's sign, as in
    Python/PyTorch.
    """
    rem = x % y
    sign_differs = (x < 0) ^ (y < 0)
    nonzero = rem != 0
    return tl.where(nonzero & sign_differs, rem + y, rem)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def rem_tt(x, y):
    """Python-style remainder ``tensor % tensor``."""
    result = _remainder(x, y)
    return result
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def rem_ts(x, y):
    """Python-style remainder ``tensor % scalar``."""
    result = _remainder(x, y)
    return result
@pointwise_dynamic(
    is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def rem_st(x, y):
    """Python-style remainder ``scalar % tensor``."""
    result = _remainder(x, y)
    return result
# More conservative code-generation settings for the scalar remainder kernels
# (smaller tiles and fewer warps than `config`). See the [sunrise fix] notes
# on the scalar remainder device-path helpers.
remainder_scalar_config = CodeGenConfig(
    prefer_1d_tile=True,
    prefer_block_pointer=True,
    max_num_warps_per_cta=16,
    max_grid_size=MAX_GRID_SIZES,
    max_tile_size=128,
)
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=remainder_scalar_config,
)
@triton.jit
def rem_ts_scalar_safe(x, y):
    """``tensor % scalar`` compiled with the conservative remainder config."""
    result = _remainder(x, y)
    return result
@pointwise_dynamic(
    is_tensor=[False, True],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=remainder_scalar_config,
)
@triton.jit
def rem_st_scalar_safe(x, y):
    """``scalar % tensor`` compiled with the conservative remainder config."""
    result = _remainder(x, y)
    return result
590def _scalar_tensor_value(value):
591 if isinstance(value, torch.Tensor) and value.ndim == 0:
592 return value.cpu().item() if value.device.type != "cpu" else value.item()
593 return value
def _scalar_left_remainder_device_path(value, tensor):
    """Compute ``value % tensor`` on device, with ``value`` a scalar or 0-dim tensor."""
    # [sunrise fix] The default scalar remainder lowering on Sunrise/PTPU can
    # hit the same backend/codegen issue that used to zero the first hardware
    # block for large integer shapes. Scalar cases therefore use a separate
    # kernel with smaller tiles, which keeps the op on device while avoiding
    # that unstable launch configuration.
    return rem_st_scalar_safe(_scalar_tensor_value(value), tensor)
def _tensor_scalar_remainder_device_path(tensor, value):
    """Compute ``tensor % value`` on device, with ``value`` a scalar or 0-dim tensor."""
    # [sunrise fix] `tensor % scalar` is intentionally lowered through a more
    # conservative kernel config than tensor-tensor remainder. The math is
    # the same; the smaller tile avoids the shape/config combination that
    # corrupted the first block on Sunrise/PTPU.
    return rem_ts_scalar_safe(tensor, _scalar_tensor_value(value))
def remainder(A, B):
    """Python-style remainder ``A % B`` for tensors, 0-dim tensors and scalars.

    - Two tensors with ``ndim > 0`` go through ``rem_tt`` on clones.
    - A dimensioned tensor paired with a scalar (or 0-dim tensor) uses the
      conservative scalar device paths.
    - Two scalar-like operands are reduced in Python and wrapped in a tensor
      of ``torch.result_type(A, B)``. That tensor lives on the first tensor
      operand's device, or the CPU if there is none.
    """
    logger.debug("GEMS REMAINDER")
    lhs_is_array = isinstance(A, torch.Tensor) and A.ndim > 0
    rhs_is_array = isinstance(B, torch.Tensor) and B.ndim > 0

    if lhs_is_array and rhs_is_array:
        # Sunrise/PTPU's integer remainder kernel may reuse its tensor operands as
        # scratch buffers even for the non-inplace API. Protect both inputs so
        # follow-up ops observe the original values of `A` and `B`.
        return rem_tt(A.clone(), B.clone())
    if lhs_is_array:
        return _tensor_scalar_remainder_device_path(A, B)
    if rhs_is_array:
        return _scalar_left_remainder_device_path(A, B)

    # Both scalar
    result_dtype = torch.result_type(A, B)
    result_device = next(
        (operand.device for operand in (A, B) if isinstance(operand, torch.Tensor)),
        "cpu",
    )
    return torch.tensor(
        _scalar_tensor_value(A) % _scalar_tensor_value(B),
        dtype=result_dtype,
        device=result_device,
    )
def remainder_(A, B):
    """In-place Python-style remainder ``A %= B``, written via ``out0=A``.

    A dimensioned tensor ``B`` is cloned before the kernel runs. A scalar or
    0-dim ``B`` is materialized as a full tensor shaped like ``A``, with dtype
    ``torch.result_type(A, B)``, so the tensor-tensor kernel can be used.
    """
    logger.debug("GEMS REMAINDER_")
    if not (isinstance(B, torch.Tensor) and B.ndim > 0):
        fill_value = _scalar_tensor_value(B)
        dense_rhs = torch.full(
            A.shape, fill_value, dtype=torch.result_type(A, B), device=A.device
        )
        return rem_tt(A, dense_rhs, out0=A)
    return rem_tt(A, B.clone(), out0=A)