# Coverage report header (coverage.py v7.6.9, created 2026-03-29 04:01 +0800):
# src/flag_gems/experimental_ops/lift.py — 0% of 38 statements covered.
import torch
import triton
import triton.language as tl
@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Copy ``n_elements`` values from ``src_ptr`` to ``dst_ptr``.

    Each program instance handles one contiguous BLOCK_SIZE-wide chunk of
    the flat element range; the mask guards the ragged final block.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    tl.store(dst_ptr + idx, tl.load(src_ptr + idx, mask=in_bounds), mask=in_bounds)
def lift(x: torch.Tensor):
    """Return a fresh GPU copy of ``x`` made with the Triton copy kernel.

    Args:
        x: CUDA tensor to copy.

    Returns:
        A new tensor with the same shape and dtype as ``x`` holding a copy
        of its values.

    Raises:
        TypeError: if ``x`` is not a Tensor.
        RuntimeError: if ``x`` is not on a CUDA device.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError("lift expects a single Tensor argument")
    if x.device.type != "cuda":
        raise RuntimeError("lift: input tensor must be on a CUDA device")
    # The kernel copies a flat range of `numel` elements starting at the data
    # pointer, which is only correct when elements are dense in storage order.
    # Compact non-contiguous views (slices, transposes) first.
    if not x.is_contiguous():
        x = x.contiguous()
    out = torch.empty_like(x)
    n_elements = out.numel()
    if n_elements > 0:  # zero-size launch would be a wasted kernel call
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _copy_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out
def lift_out(x: torch.Tensor, out: torch.Tensor):
    """Copy ``x`` into ``out`` on the GPU via the Triton copy kernel.

    ``out`` is resized to ``x``'s shape when the shapes differ, then filled
    with a copy of ``x``'s values and returned.

    Args:
        x: CUDA tensor to copy from.
        out: CUDA tensor to copy into; same device and dtype as ``x``.

    Returns:
        ``out``.

    Raises:
        TypeError: if either argument is not a Tensor.
        RuntimeError: if either tensor is not on CUDA, devices or dtypes
            mismatch, or ``out`` is a non-contiguous tensor the flat kernel
            cannot safely write through.
    """
    if not isinstance(x, torch.Tensor) or not isinstance(out, torch.Tensor):
        raise TypeError("lift_out expects (Tensor x, Tensor out)")
    if x.device.type != "cuda" or out.device.type != "cuda":
        raise RuntimeError(
            "lift_out: both input and out tensors must be on a CUDA device"
        )
    if out.device != x.device:
        raise RuntimeError("lift_out: out tensor must be on the same device as input")
    if out.dtype != x.dtype:
        raise RuntimeError("lift_out: out tensor must have the same dtype as input")
    # Resize out to match shape; resize_ to a new size yields a contiguous
    # buffer of the right number of elements.
    if tuple(out.shape) != tuple(x.shape):
        out.resize_(x.shape)
    elif not out.is_contiguous():
        # The kernel writes a flat, dense range of out's storage; writing
        # through non-contiguous strides would scatter values incorrectly.
        raise RuntimeError("lift_out: out tensor must be contiguous")
    # The kernel likewise reads a flat range of x's storage, so compact any
    # non-contiguous view before launch.
    if not x.is_contiguous():
        x = x.contiguous()
    n_elements = x.numel()
    if n_elements > 0:  # skip the launch entirely for empty tensors
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _copy_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out