Coverage for src/flag_gems/experimental_ops/lift_fresh_copy.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _copy_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise copy kernel: ``out[i] = in[i]`` for ``i < n_elements``.

    Each program instance handles one contiguous tile of ``BLOCK_SIZE``
    elements; the mask keeps the final, partially-filled tile in bounds.
    """
    tile_id = tl.program_id(axis=0)
    # Absolute element indices covered by this program instance.
    idx = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(in_ptr + idx, mask=in_bounds)
    tl.store(out_ptr + idx, vals, mask=in_bounds)

14 

15 

def lift_fresh_copy(*args, **kwargs):
    """Return a fresh, contiguous copy of the first Tensor found in the arguments.

    Mirrors ``aten::lift_fresh_copy``: the result is a brand-new tensor that
    does not alias the input.

    Args:
        *args / **kwargs: searched for the input tensor — the first positional
            argument, then the ATen-style ``self`` keyword, then any remaining
            tensor argument.

    Returns:
        A new contiguous CUDA tensor with the same shape, dtype and values
        as the input.

    Raises:
        ValueError: if no Tensor argument is found, or the tensor is not on CUDA.
    """
    # Locate the input tensor among the loosely-typed arguments.
    x = None
    if args and isinstance(args[0], torch.Tensor):
        x = args[0]
    elif isinstance(kwargs.get("self"), torch.Tensor):
        x = kwargs["self"]
    else:
        for candidate in list(args) + list(kwargs.values()):
            if isinstance(candidate, torch.Tensor):
                x = candidate
                break
    if x is None:
        raise ValueError("lift_fresh_copy expects a Tensor argument")

    if not x.is_cuda:
        raise ValueError("lift_fresh_copy Triton kernel requires a CUDA tensor")

    x_contig = x.contiguous()
    out = torch.empty_like(x_contig, memory_format=torch.contiguous_format)

    n_elements = x_contig.numel()
    # Empty tensors need no kernel launch — a zero-sized grid is an invalid
    # launch configuration, so skip it and return the (empty) output as-is.
    if n_elements == 0:
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _copy_kernel[grid](x_contig, out, n_elements, BLOCK_SIZE=1024)

    # 'out' already has the input's shape; return it directly rather than a
    # redundant view of it.
    return out

42 

43 

def lift_fresh_copy_out(x: torch.Tensor, out: torch.Tensor = None):
    """Copy ``x`` into ``out`` (allocating one if needed) and return it.

    Out-variant of :func:`lift_fresh_copy`. When the caller passes ``out``,
    the copied data is written into that very tensor (resizing it to ``x``'s
    shape if necessary, per standard PyTorch out= semantics), so the caller's
    reference observes the result.

    Args:
        x: the CUDA tensor to copy.
        out: optional CUDA destination tensor of the same dtype.

    Returns:
        The destination tensor containing a copy of ``x``.

    Raises:
        ValueError: if ``x`` is not a Tensor, any tensor is not on CUDA,
            or ``out`` has a mismatched dtype.
    """
    # isinstance(None, torch.Tensor) is False, so this also rejects None.
    if not isinstance(x, torch.Tensor):
        raise ValueError("lift_fresh_copy_out expects 'x' to be a Tensor")
    if not x.is_cuda:
        raise ValueError("lift_fresh_copy_out Triton kernel requires CUDA tensors")

    x_contig = x.contiguous()

    if out is None:
        out = torch.empty_like(x_contig, memory_format=torch.contiguous_format)
    else:
        if not out.is_cuda:
            raise ValueError("Output tensor 'out' must be on CUDA")
        if out.dtype != x_contig.dtype:
            raise ValueError("Output tensor 'out' must have the same dtype as input")
        # Standard out= semantics: resize the destination to the result shape.
        if out.shape != x_contig.shape:
            out.resize_(x_contig.shape)

    # The kernel needs a contiguous destination. If 'out' is not contiguous,
    # write into a staging buffer and copy back afterwards so the caller's
    # tensor is actually updated. (The previous code rebound 'out' to a new
    # contiguous clone, so the kernel wrote into a detached tensor and the
    # caller-provided 'out' was never modified.)
    if out.is_contiguous():
        dest = out
    else:
        dest = torch.empty_like(x_contig, memory_format=torch.contiguous_format)

    n_elements = x_contig.numel()
    # Skip the launch for empty tensors — a zero-sized grid is invalid.
    if n_elements > 0:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _copy_kernel[grid](x_contig, dest, n_elements, BLOCK_SIZE=1024)

    if dest is not out:
        out.copy_(dest)
    return out