Coverage for src/flag_gems/experimental_ops/lift.py: 0%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Copy ``n_elements`` scalars from ``src_ptr`` to ``dst_ptr``.

    One program instance handles one BLOCK_SIZE-wide slice of the flat
    index space; the mask guards the ragged tail of the last block.
    Both pointers are treated as dense 1-D buffers (flat offsets), so
    callers must pass contiguous memory.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(src_ptr + idx, mask=in_bounds)
    tl.store(dst_ptr + idx, vals, mask=in_bounds)

14 

15 

def lift(x: torch.Tensor) -> torch.Tensor:
    """Return a fresh CUDA tensor holding a copy of ``x`` (identity op).

    Args:
        x: input tensor; must live on a CUDA device.

    Returns:
        A new contiguous tensor with the same shape, dtype and device as
        ``x``, containing the same values.

    Raises:
        TypeError: if ``x`` is not a ``torch.Tensor``.
        RuntimeError: if ``x`` is not on a CUDA device.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError("lift expects a single Tensor argument")
    if x.device.type != "cuda":
        raise RuntimeError("lift: input tensor must be on a CUDA device")
    # _copy_kernel addresses memory with flat offsets, so both source and
    # destination must be contiguous; a transposed/strided view would
    # otherwise be copied in the wrong element order.
    src = x.contiguous()
    out = torch.empty_like(src)  # contiguous because src is
    n_elements = out.numel()
    if n_elements > 0:  # zero-size launch is invalid; skip for empty tensors
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _copy_kernel[grid](src, out, n_elements, BLOCK_SIZE=1024)
    return out

27 

28 

def lift_out(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Copy ``x`` into ``out`` (identity op with a caller-provided buffer).

    Args:
        x: input tensor on a CUDA device.
        out: destination tensor; must match ``x``'s device and dtype.
            It is resized to ``x.shape`` when the shapes differ.

    Returns:
        ``out``, now holding the values of ``x``.

    Raises:
        TypeError: if either argument is not a ``torch.Tensor``.
        RuntimeError: if either tensor is not on CUDA, or the devices or
            dtypes of ``x`` and ``out`` disagree.
    """
    if not isinstance(x, torch.Tensor) or not isinstance(out, torch.Tensor):
        raise TypeError("lift_out expects (Tensor x, Tensor out)")
    if x.device.type != "cuda" or out.device.type != "cuda":
        raise RuntimeError(
            "lift_out: both input and out tensors must be on a CUDA device"
        )
    if out.device != x.device:
        raise RuntimeError("lift_out: out tensor must be on the same device as input")
    if out.dtype != x.dtype:
        raise RuntimeError("lift_out: out tensor must have the same dtype as input")
    # Resize out to match shape; resize_ yields a contiguous layout.
    if tuple(out.shape) != tuple(x.shape):
        out.resize_(x.shape)
    n_elements = x.numel()
    if n_elements > 0:  # zero-size launch is invalid; skip for empty tensors
        # The kernel uses flat offsets, so both endpoints must be contiguous.
        src = x.contiguous()
        # If out kept a non-contiguous layout (shape already matched, so no
        # resize_ happened), a flat store would place values at the wrong
        # logical positions; stage through a contiguous temp instead.
        dst = out if out.is_contiguous() else torch.empty_like(src)
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024)
        if dst is not out:
            out.copy_(dst)  # stride-aware copy back into the caller's buffer
    return out