Coverage for src/flag_gems/experimental_ops/alias_copy.py: 0%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _alias_copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Copy a flat range of elements from src_ptr to dst_ptr.

    Each program instance handles one BLOCK_SIZE-wide slice of the flat
    element range [0, n_elements); a bounds mask guards the final,
    possibly partial, block so out-of-range lanes neither read nor write.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    data = tl.load(src_ptr + idx, mask=in_bounds)
    tl.store(dst_ptr + idx, data, mask=in_bounds)

14 

15 

def alias_copy(x: torch.Tensor) -> torch.Tensor:
    """
    Wrapper for aten::alias_copy
    Creates and returns a copy of `x` with identical content.

    Args:
        x: a CUDA tensor of any shape and dtype.

    Returns:
        A new contiguous CUDA tensor with the same shape, dtype, device,
        and values as `x`.

    Raises:
        RuntimeError: if `x` is not a CUDA tensor.
    """
    if not x.is_cuda:
        raise RuntimeError("alias_copy: Triton kernel requires CUDA tensors.")
    # Allocate the destination contiguously up front. empty_like already
    # guarantees matching dtype and device, so the previous dtype/device
    # mismatch checks were dead code and have been removed; requesting
    # contiguous_format also avoids a second allocation via .contiguous()
    # when `x` has non-contiguous strides.
    out = torch.empty_like(x, memory_format=torch.contiguous_format)
    n_elements = out.numel()
    if n_elements == 0:
        return out
    # .contiguous() is a no-op (returns `x` itself) when already contiguous,
    # so no explicit is_contiguous() branch is needed.
    src = x.contiguous()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _alias_copy_kernel[grid](src, out, n_elements, BLOCK_SIZE=1024)
    return out

38 

39 

def alias_copy_out(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """
    Wrapper for aten::alias_copy.out
    Copies `x` into `out` element-for-element (flat order) and returns `out`.

    Args:
        x: source CUDA tensor.
        out: destination CUDA tensor; must be contiguous, on the same
            device as `x`, with matching dtype and element count.
            NOTE(review): only numel is checked, not shape — a (2, 3)
            source fills a (3, 2) output linearly; confirm this matches
            the intended aten semantics.

    Returns:
        `out`, with the contents of `x` written into it.

    Raises:
        RuntimeError: if either tensor is not CUDA, dtypes differ,
            element counts differ, devices differ, or `out` is not
            contiguous.
    """
    if not x.is_cuda or not out.is_cuda:
        raise RuntimeError("alias_copy_out: Triton kernel requires CUDA tensors.")
    if x.dtype != out.dtype:
        raise RuntimeError("alias_copy_out: dtype of input and output must match.")
    if x.numel() != out.numel():
        raise RuntimeError(
            "alias_copy_out: input and output must have the same number of elements."
        )
    # Both tensors are CUDA at this point, but they may still sit on
    # different GPU devices — the kernel needs them co-located.
    if x.device != out.device:
        raise RuntimeError(
            "alias_copy_out: input and output must be on the same device."
        )
    if not out.is_contiguous():
        raise RuntimeError("alias_copy_out: output tensor must be contiguous.")
    n_elements = out.numel()
    if n_elements == 0:
        return out
    # Hoisted above src materialization so an empty copy never pays for a
    # .contiguous() call; .contiguous() itself is a no-op on contiguous x.
    src = x.contiguous()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _alias_copy_kernel[grid](src, out, n_elements, BLOCK_SIZE=1024)
    return out