Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/to.py: 0%
67 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
2import os
3from typing import Optional
5import torch
6import triton
7from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
9from ..utils.pointwise_dynamic import pointwise_dynamic
# Child logger under the shared "flag_gems" root logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Dispatch keyset used to redispatch back into PyTorch's reference
# CompositeExplicitAutograd implementation when this backend falls back.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)
@pointwise_dynamic(
    is_tensor=[
        True,
    ],
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def _to_copy_func(x):
    # Identity kernel: the element values pass through unchanged; any
    # dtype/layout conversion happens when the pointwise_dynamic wrapper
    # stores the result into the caller-provided `out0` tensor.
    return x
# Codegen configuration used for bfloat16 inputs (see to_copy below), with
# 1-D tiling preferred and interleaving closed (isCloseInterleave=True).
# NOTE(review): the positional arguments are presumably tile size, per-axis
# grid limits, and launch parameters — confirm against CodeGenConfig's
# signature in codegen_config_utils before relying on this reading.
close_interleave_config = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    isCloseInterleave=True,
)
@pointwise_dynamic(
    is_tensor=[
        True,
    ],
    promotion_methods=[(0, "DEFAULT")],
    config=close_interleave_config,
)
@triton.jit
def _to_copy_func_close_interleave(x):
    # Same identity kernel as _to_copy_func, but compiled under the
    # close-interleave codegen configuration (used for bfloat16 inputs).
    return x
51def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype:
52 if dtype is None:
53 return x.dtype
54 if isinstance(dtype, torch.dtype):
55 return dtype
56 raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}")
59def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device:
60 if device is None:
61 return x.device
62 return torch.device(device)
65def _normalize_memory_format(
66 memory_format: Optional[torch.memory_format],
67) -> torch.memory_format:
68 if memory_format is None:
69 return torch.preserve_format
70 return memory_format
73def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor:
74 """Recreate tensor storage while honoring preserve_format semantics."""
75 if torch.ops.aten.is_non_overlapping_and_dense(x):
76 return torch.empty_strided(x.size(), x.stride(), **empty_kwargs)
77 # Fall back to PyTorch's best-effort layout suggestion when stride replication is unsafe.
78 return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs)
# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#                bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """Kunlunxin implementation of aten._to_copy for dense strided tensors.

    Device transfers (and pure-CPU copies) are redispatched to PyTorch's
    CompositeExplicitAutograd implementation; non-strided layouts,
    pin_memory and quantized inputs raise NotImplementedError.
    """
    # bfloat16 inputs use the kernel compiled under the close-interleave
    # codegen configuration.
    if x.dtype == torch.bfloat16:
        to_dtype_fn = _to_copy_func_close_interleave
    else:
        to_dtype_fn = _to_copy_func

    # We only implement the dense strided kernel today; all other layouts fall back to PyTorch.
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    if pin_memory is not None:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    if target_device != x.device or (
        x.device.type == "cpu" and target_device.type == "cpu"
    ):
        # Device transfer (d2h/h2d etc.) relies on PyTorch's implementation.
        return torch.ops.aten._to_copy.default.redispatch(
            _FALLBACK_KEYSET,
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS _TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)
    # BUG FIX: a stray `out = torch.empty_like(x, dtype=dtype, memory_format=memory_format)`
    # used to clobber the allocation above, discarding the preserve_format /
    # stride-replication handling and the resolved dtype. It has been removed.

    # try/finally guarantees the TRITONXPU_* environment toggles are always
    # removed, even if the kernel launch raises.
    if out.element_size() == 8:
        os.environ["TRITONXPU_ELEMBYTES"] = "8"
        os.environ["TRITONXPU_BF16_FAST"] = "1"
        try:
            res = to_dtype_fn(x, out0=out)
        finally:
            del os.environ["TRITONXPU_ELEMBYTES"]
            del os.environ["TRITONXPU_BF16_FAST"]
    else:
        os.environ["TRITONXPU_BF16_FAST"] = "1"
        try:
            res = to_dtype_fn(x, out0=out)
        finally:
            del os.environ["TRITONXPU_BF16_FAST"]
    return res