Coverage for src/flag_gems/ops/copy.py: 68%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
2from typing import Optional
4import torch
5import triton
7from flag_gems.utils import pointwise_dynamic
9logger = logging.getLogger(__name__)
# Dispatch keyset used to hand a copy back to PyTorch's reference
# implementation (CompositeExplicitAutograd) whenever the Triton fast
# path cannot handle the tensors involved.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _copy_kernel(src):
    # Identity kernel: ``pointwise_dynamic`` generates the load/store and
    # broadcasting plumbing, so returning ``src`` unchanged copies every
    # element into the output tensor supplied as ``out0`` by the caller.
    return src
22def _can_use_triton(dst: torch.Tensor, src: torch.Tensor) -> bool:
23 if dst.layout != torch.strided or src.layout != torch.strided:
24 return False
25 if dst.device != src.device:
26 return False
27 if dst.is_quantized or src.is_quantized:
28 return False
29 if src.is_complex() or dst.is_complex():
30 # Preserve PyTorch's behaviour of warning when casting complex to real
31 # by forcing the redispatch path, which issues the warning internally.
32 return False
33 return True
36def _expand_like(src: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
37 if src.shape == target_shape:
38 return src
39 return src.expand(target_shape)
def copy(
    template: torch.Tensor, src: torch.Tensor, *, non_blocking: Optional[bool] = False
):
    """Functional (out-of-place) copy.

    Allocates a fresh tensor with ``template``'s size/stride/dtype/device,
    fills it from ``src`` via ``copy_`` and returns it.
    """
    logger.debug("GEMS COPY (functional)")
    result = torch.empty_strided(
        template.size(),
        template.stride(),
        dtype=template.dtype,
        device=template.device,
    )
    copy_(result, src, non_blocking=bool(non_blocking))
    return result
def copy_(dst: torch.Tensor, src: torch.Tensor, non_blocking: bool = False):
    """In-place copy of ``src`` into ``dst`` (aten::copy_ semantics).

    Fast-paths supported cases through the Triton identity kernel and
    redispatches everything else (overlapping aliases, very large tensors,
    unsupported layouts/dtypes/devices, empty destinations) to PyTorch's
    CompositeExplicitAutograd implementation.

    Args:
        dst: destination tensor, modified in place and returned.
        src: source tensor, or a Python ``int``/``float``/``bool`` scalar.
        non_blocking: forwarded to the fallback redispatch path.

    Raises:
        TypeError: if ``src`` is neither a tensor nor a supported scalar.
        RuntimeError: if ``dst`` is a ZeroTensor, or ``src`` does not
            broadcast to ``dst``'s shape.
    """
    if isinstance(src, (int, float, bool)):
        src = torch.tensor(src, device=dst.device)
    elif not isinstance(src, torch.Tensor):
        # Single formatted message; the original passed two args to
        # TypeError, which rendered the message as a tuple.
        raise TypeError(f"unsupported src type for copy_: {type(src)}")

    # this is the same as PyTorch's check
    if dst._is_zerotensor():
        raise RuntimeError("ZeroTensors are immutable. Call clone() before copy_.")
    if src._is_zerotensor():
        return dst.zero_()

    if torch._C._is_alias_of(dst, src):
        # Align with PyTorch: if metadata fully matches, this is a no-op.
        if (
            dst.storage_offset() == src.storage_offset()
            and dst.stride() == src.stride()
            and dst.size() == src.size()
            and dst.dtype == src.dtype
            and dst.device == src.device
            and dst.is_conj() == src.is_conj()
            and dst.is_neg() == src.is_neg()
        ):
            return dst
        # Otherwise defer to PyTorch for well-defined semantics on overlapping writes.
        return torch.ops.aten.copy_.default.redispatch(
            _FALLBACK_KEYSET, dst, src, non_blocking
        )

    # Fall back beyond INT32_MAX elements — presumably the kernel uses
    # 32-bit indexing; confirm against pointwise_dynamic's codegen.
    if src.numel() > 2**31 - 1 or dst.numel() > 2**31 - 1:
        return torch.ops.aten.copy_.default.redispatch(
            _FALLBACK_KEYSET, dst, src, non_blocking
        )

    if not _can_use_triton(dst, src):
        return torch.ops.aten.copy_.default.redispatch(
            _FALLBACK_KEYSET, dst, src, non_blocking
        )

    if dst.numel() == 0:
        # Respect PyTorch behaviour: empty tensors should still validate broadcast.
        return torch.ops.aten.copy_.default.redispatch(
            _FALLBACK_KEYSET, dst, src, non_blocking
        )

    logger.debug("GEMS COPY_")

    # broadcast_shapes raises a descriptive RuntimeError on incompatible
    # shapes; the previous catch-and-rewrap with an identical message only
    # obscured the traceback, so let it propagate directly.
    broadcast_shape = torch.broadcast_shapes(dst.shape, src.shape)

    if torch.Size(broadcast_shape) != dst.shape:
        raise RuntimeError(
            f"The broadcast shape {broadcast_shape} does not match destination shape {tuple(dst.shape)}"
        )

    expanded_src = _expand_like(src, dst.shape)

    overload = _copy_kernel.instantiate(expanded_src.ndim)
    overload(expanded_src, out0=dst)
    return dst