Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/copy.py: 0%
64 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
2from typing import Optional
4import torch
5import triton
7from ..utils.codegen_config_utils import CodeGenConfig
8from ..utils.pointwise_dynamic import pointwise_dynamic
# Child logger under the shared "flag_gems" root; lstrip(".") guards against a
# leading dot if __name__ is ever a relative-style name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Dispatch keyset used to redispatch aten.copy_ back to PyTorch's own
# implementation whenever the Triton path below cannot handle the inputs.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)
# Code-generation config for the scatter/slice copy kernel (copy_slice).
# NOTE(review): the positional arguments follow CodeGenConfig's signature in
# ..utils.codegen_config_utils (not visible here) — presumably a tile/block
# size, per-axis grid limits, a warp/thread count, and a boolean flag; confirm
# against the CodeGenConfig definition before changing any value.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,  # favor 1-D tiling for slice-shaped copies
    is_scatter_slice=True,  # enable the scatter-slice code path in codegen
)
26# @pointwise_dynamic(is_tensor=(True,), promotion_methods=[(0, "DEFAULT")])
27# @triton.jit
28# def copy(src):
29# return src
# Identity pointwise kernel specialized with the scatter-slice codegen config
# above; the pointwise_dynamic wrapper generates the launch code that writes
# `src` element-wise into the destination slice.
@pointwise_dynamic(
    is_tensor=(True,), promotion_methods=[(0, "DEFAULT")], config=config_
)
@triton.jit
def copy_slice(src):
    return src
# Identity pointwise kernel used by copy_ below: the generated wrapper reads
# each element of `src` and stores it into the `out0` destination tensor.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _copy_kernel(src):
    return src
46def _can_use_triton(dst: torch.Tensor, src: torch.Tensor) -> bool:
47 if dst.layout != torch.strided or src.layout != torch.strided:
48 return False
49 if dst.device != src.device:
50 return False
51 if dst.is_quantized or src.is_quantized:
52 return False
53 if src.is_complex() and not dst.is_complex():
54 # Preserve PyTorch's behaviour of warning when casting complex to real
55 # by forcing the redispatch path, which issues the warning internally.
56 return False
57 if not src.is_contiguous():
58 return False
59 return True
62def _expand_like(src: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
63 if src.shape == target_shape:
64 return src
65 return src.expand(target_shape)
def copy(
    template: torch.Tensor, src: torch.Tensor, *, non_blocking: Optional[bool] = False
):
    """Functional (out-of-place) copy: allocate a tensor with *template*'s
    size/stride/dtype/device and fill it from *src* via the in-place copy_.
    """
    logger.debug("GEMS COPY (functional)")
    result = torch.empty_strided(
        template.size(),
        template.stride(),
        dtype=template.dtype,
        device=template.device,
    )
    copy_(result, src, non_blocking=bool(non_blocking))
    return result
def copy_(dst: torch.Tensor, src: torch.Tensor, non_blocking: bool = False):
    """In-place copy of *src* into *dst*, dispatching to a Triton kernel when
    possible and redispatching to PyTorch's CompositeExplicitAutograd
    implementation otherwise.

    Mirrors aten.copy_ semantics: validates ZeroTensors, short-circuits exact
    self-copies, and validates broadcastability of src onto dst's shape.
    Returns *dst*.  Raises TypeError for a non-Tensor src and RuntimeError for
    immutable ZeroTensor destinations or incompatible broadcast shapes.
    """
    if not isinstance(src, torch.Tensor):
        raise TypeError("src must be a Tensor")

    # this is the same as PyTorch's check
    if dst._is_zerotensor():
        raise RuntimeError("ZeroTensors are immutable. Call clone() before copy_.")
    if src._is_zerotensor():
        # Copying from a ZeroTensor is equivalent to zero-filling dst.
        return dst.zero_()

    if torch._C._is_alias_of(dst, src):
        # Align with PyTorch: if metadata fully matches, this is a no-op.
        if (
            dst.storage_offset() == src.storage_offset()
            and dst.stride() == src.stride()
            and dst.size() == src.size()
            and dst.dtype == src.dtype
            and dst.device == src.device
            and dst.is_conj() == src.is_conj()
            and dst.is_neg() == src.is_neg()
        ):
            return dst
        # Otherwise defer to PyTorch for well-defined semantics on overlapping writes.
        return torch.ops.aten.copy_.default.redispatch(
            _FALLBACK_KEYSET, dst, src, non_blocking
        )

    # Unsupported layouts/devices/dtypes (see _can_use_triton) fall back too.
    if not _can_use_triton(dst, src):
        return torch.ops.aten.copy_.default.redispatch(
            _FALLBACK_KEYSET, dst, src, non_blocking
        )

    if dst.numel() == 0:
        # Respect PyTorch behaviour: empty tensors should still validate broadcast.
        return torch.ops.aten.copy_.default.redispatch(
            _FALLBACK_KEYSET, dst, src, non_blocking
        )

    logger.debug("GEMS COPY_")

    try:
        broadcast_shape = torch.broadcast_shapes(dst.shape, src.shape)
    except RuntimeError as exc:
        # Re-raise with the original message so callers see PyTorch's wording.
        raise RuntimeError(str(exc)) from exc

    # copy_ only broadcasts src up to dst's shape; a broadcast result that
    # differs from dst.shape means dst itself would need to grow — an error.
    if torch.Size(broadcast_shape) != dst.shape:
        raise RuntimeError(
            f"The broadcast shape {broadcast_shape} does not match destination shape {tuple(dst.shape)}"
        )

    expanded_src = _expand_like(src, dst.shape)

    # Instantiate the generated pointwise overload for this rank and let it
    # write expanded_src element-wise into dst.
    overload = _copy_kernel.instantiate(expanded_src.ndim)
    overload(expanded_src, out0=dst)
    return dst