Coverage for src/flag_gems/experimental_ops/lift_fresh_copy.py: 0%
51 statements
« prev ^ index » next — generated by coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _copy_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Flat element-wise copy: out[i] = in[i] for every i < n_elements.

    Each program instance handles one BLOCK_SIZE-wide slice of the flat
    index space; the mask keeps the final partial block in bounds.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    tl.store(out_ptr + idx, tl.load(in_ptr + idx, mask=in_bounds), mask=in_bounds)
def lift_fresh_copy(*args, **kwargs):
    """Return a fresh contiguous copy of the first Tensor found in the arguments.

    Mirrors aten::lift_fresh_copy: the result is a new tensor holding the same
    values as the input, produced by a Triton copy kernel.

    Raises:
        ValueError: if no Tensor argument can be found, or the tensor is not
            on a CUDA device (the Triton kernel requires CUDA).
    """
    # Locate the input tensor: positional first, then the ATen-style "self"
    # kwarg, then any tensor anywhere in the remaining arguments.
    x = None
    if args and isinstance(args[0], torch.Tensor):
        x = args[0]
    elif isinstance(kwargs.get("self"), torch.Tensor):
        x = kwargs["self"]
    else:
        for candidate in list(args) + list(kwargs.values()):
            if isinstance(candidate, torch.Tensor):
                x = candidate
                break
    if x is None:
        raise ValueError("lift_fresh_copy expects a Tensor argument")

    if not x.is_cuda:
        raise ValueError("lift_fresh_copy Triton kernel requires a CUDA tensor")

    x_contig = x.contiguous()
    out = torch.empty_like(x_contig, memory_format=torch.contiguous_format)

    n_elements = x_contig.numel()
    # BUG FIX: skip the launch for empty tensors — there is nothing to copy,
    # and a zero-sized grid is rejected by some Triton versions.
    if n_elements > 0:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _copy_kernel[grid](x_contig, out, n_elements, BLOCK_SIZE=1024)

    return out.view_as(x_contig)
def lift_fresh_copy_out(x: torch.Tensor, out: torch.Tensor = None):
    """Copy *x* into *out* (allocating *out* if omitted) via the Triton kernel.

    Out-variant of :func:`lift_fresh_copy`: the caller's ``out`` tensor is
    resized to match ``x`` if needed and filled with ``x``'s values.

    Raises:
        ValueError: if *x* is not a Tensor, either tensor is not on CUDA,
            or *out* has a mismatched dtype.
    """
    if not isinstance(x, torch.Tensor):
        raise ValueError("lift_fresh_copy_out expects 'x' to be a Tensor")
    if not x.is_cuda:
        raise ValueError("lift_fresh_copy_out Triton kernel requires CUDA tensors")

    x_contig = x.contiguous()

    if out is None:
        out = torch.empty_like(x_contig, memory_format=torch.contiguous_format)
    else:
        if not out.is_cuda:
            raise ValueError("Output tensor 'out' must be on CUDA")
        if out.dtype != x_contig.dtype:
            raise ValueError("Output tensor 'out' must have the same dtype as input")
        # Resize to match the input's element count (resize_ keeps dtype/device).
        if out.numel() != x_contig.numel() or not out.is_contiguous():
            out.resize_(x_contig.shape)

    # The kernel writes flat contiguous memory. If `out` is still
    # non-contiguous, stage the copy through a temporary buffer.
    work = out if out.is_contiguous() else torch.empty_like(
        x_contig, memory_format=torch.contiguous_format
    )

    n_elements = x_contig.numel()
    # Skip the launch for empty tensors — nothing to copy, and a zero-sized
    # grid is rejected by some Triton versions.
    if n_elements > 0:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _copy_kernel[grid](x_contig, work, n_elements, BLOCK_SIZE=1024)

    if work is not out:
        # BUG FIX: previously `out = out.contiguous()` rebound the local to a
        # fresh tensor, so the caller's non-contiguous `out` was never written.
        # Copy the staged result back into the caller's tensor instead.
        out.copy_(work.view(out.shape))
        return out
    return out.view_as(x_contig)