Coverage for src/flag_gems/experimental_ops/copy_.py: 0%
74 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import torch
2import triton
3import triton.language as tl
6def _tl_dtype_from_torch(dtype: torch.dtype):
7 # Map common torch dtypes to Triton dtypes
8 if dtype == torch.float16:
9 return tl.float16
10 if dtype == torch.bfloat16:
11 return tl.bfloat16
12 if dtype == torch.float32:
13 return tl.float32
14 if dtype == torch.float64:
15 return tl.float64
16 if dtype == torch.int8:
17 return tl.int8
18 if dtype == torch.int16:
19 return tl.int16
20 if dtype == torch.int32:
21 return tl.int32
22 if dtype == torch.int64:
23 return tl.int64
24 if dtype == torch.uint8:
25 return tl.uint8
26 raise NotImplementedError(f"Unsupported dtype for Triton copy_: {dtype}")
@triton.jit
def _copy_kernel(
    dst_ptr, src_ptr, n_elements, BLOCK_SIZE: tl.constexpr, DST_DTYPE: tl.constexpr
):
    """Elementwise copy of ``n_elements`` values from ``src_ptr`` to ``dst_ptr``.

    Each program instance handles one contiguous BLOCK_SIZE-wide chunk.
    Loaded values are cast to DST_DTYPE before the store, so source and
    destination may have different dtypes.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the final partial block: lanes past n_elements neither load nor store.
    mask = offsets < n_elements
    vals = tl.load(src_ptr + offsets, mask=mask)
    vals = tl.cast(vals, DST_DTYPE)
    tl.store(dst_ptr + offsets, vals, mask=mask)
@triton.jit
def _fill_kernel(
    dst_ptr, scalar_value, n_elements, BLOCK_SIZE: tl.constexpr, DST_DTYPE: tl.constexpr
):
    """Fill ``n_elements`` values at ``dst_ptr`` with ``scalar_value``.

    The scalar is cast to DST_DTYPE and broadcast into a BLOCK_SIZE vector,
    then written with a masked store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the final partial block: lanes past n_elements are not stored.
    mask = offsets < n_elements
    vals = tl.full((BLOCK_SIZE,), tl.cast(scalar_value, DST_DTYPE), DST_DTYPE)
    tl.store(dst_ptr + offsets, vals, mask=mask)
54def _launch_copy_tensor(dst: torch.Tensor, src: torch.Tensor):
55 assert dst.is_cuda and src.is_cuda, "Triton copy_ supports CUDA tensors only."
56 assert (
57 dst.is_contiguous() and src.is_contiguous()
58 ), "Only contiguous tensors are supported."
59 n_elements = dst.numel()
60 assert (
61 src.numel() == n_elements
62 ), "Source and destination must have the same number of elements."
63 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
64 DST_DTYPE = _tl_dtype_from_torch(dst.dtype)
65 _copy_kernel[grid](
66 dst,
67 src,
68 n_elements,
69 BLOCK_SIZE=1024,
70 DST_DTYPE=DST_DTYPE,
71 )
72 return dst
75def _launch_fill_scalar(dst: torch.Tensor, scalar):
76 assert dst.is_cuda, "Triton copy_ (scalar) supports CUDA tensors only."
77 assert dst.is_contiguous(), "Only contiguous tensors are supported."
78 n_elements = dst.numel()
79 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
80 DST_DTYPE = _tl_dtype_from_torch(dst.dtype)
81 # Convert scalar to a Python number for kernel argument
82 if dst.dtype.is_floating_point:
83 scalar_val = float(scalar)
84 else:
85 scalar_val = int(scalar)
86 _fill_kernel[grid](
87 dst,
88 scalar_val,
89 n_elements,
90 BLOCK_SIZE=1024,
91 DST_DTYPE=DST_DTYPE,
92 )
93 return dst
def copy_(self: torch.Tensor, src, non_blocking: bool = False):
    """In-place copy dispatching on the type of ``src``.

    Tensor sources go through the elementwise copy kernel; int/bool/float
    sources fill ``self`` with that scalar. ``non_blocking`` is accepted for
    signature compatibility with ``torch.Tensor.copy_`` but is not used by
    the launch helpers.

    Returns ``self``.

    Raises:
        TypeError: if ``src`` is none of Tensor, int, bool, or float.
    """
    if isinstance(src, torch.Tensor):
        return _launch_copy_tensor(self, src)
    if isinstance(src, (bool, int)):
        return _launch_fill_scalar(self, int(src))
    if isinstance(src, float):
        return _launch_fill_scalar(self, float(src))
    raise TypeError(f"Unsupported src type for copy_: {type(src)}")
def copy__Tensor(self: torch.Tensor, src: torch.Tensor, non_blocking: bool = False):
    """Tensor-source overload: copy ``src`` into ``self`` elementwise.

    Thin delegation to ``_launch_copy_tensor``; ``non_blocking`` is accepted
    for signature compatibility but ignored by the launch path.
    """
    return _launch_copy_tensor(self, src)
def copy__int(self: torch.Tensor, src: int, non_blocking: bool = False):
    """Int-source overload: fill ``self`` with the scalar ``src``.

    Thin delegation to ``_launch_fill_scalar``; ``non_blocking`` is accepted
    for signature compatibility but ignored by the launch path.
    """
    return _launch_fill_scalar(self, int(src))
def copy__float(self: torch.Tensor, src: float, non_blocking: bool = False):
    """Float-source overload: fill ``self`` with the scalar ``src``.

    Thin delegation to ``_launch_fill_scalar``; ``non_blocking`` is accepted
    for signature compatibility but ignored by the launch path.
    """
    return _launch_fill_scalar(self, float(src))