Coverage for src/flag_gems/ops/split_with_sizes_copy.py: 84%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@libentry()
@triton.jit
def split_copy_kernel(
    out_ptr,
    inp_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Copy ``n_elements`` values from ``inp_ptr`` to ``out_ptr``.

    Both buffers are addressed with the same linear offsets, so input and
    output are expected to be contiguous and have the same shape.
    Each program handles one ``BLOCK_SIZE``-wide chunk; the mask disables the
    lanes that fall past ``n_elements`` in the last chunk.
    """
    # program_id is a 32-bit value: widen it before scaling so that
    # pid * BLOCK_SIZE cannot wrap around for tensors holding more than
    # 2**31 - 1 elements.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    data = tl.load(inp_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, data, mask=mask)

32 

33 

def split_with_sizes_copy(inp, split_sizes, dim=0):
    """Split ``inp`` along ``dim`` into independent (non-aliasing) chunks.

    Args:
        inp: Tensor to split.
        split_sizes: Length of every chunk along ``dim``. May be a tensor or
            any iterable of integers; the entries must be non-negative and
            sum to exactly ``inp.shape[dim]``.
        dim: Dimension to split along. Negative values count from the end.

    Returns:
        A tuple with one newly allocated tensor per entry of ``split_sizes``,
        in order. No returned tensor shares storage with ``inp``.

    Raises:
        IndexError: If ``dim`` is outside ``[-inp.ndim, inp.ndim - 1]``.
        RuntimeError: If ``inp`` is 0-dimensional, if a split size is
            negative, or if the sizes do not sum to ``inp.shape[dim]``.
    """
    logger.debug("GEMS SPLIT_WITH_SIZES_COPY")

    ndim = inp.ndim
    if ndim == 0:
        raise RuntimeError(
            "split_with_sizes expects at least a 1-dimensional tensor"
        )
    # Validate before wrapping: a dim below -ndim would stay negative after
    # the shift and then be silently interpreted as a different axis.
    if not -ndim <= dim < ndim:
        raise IndexError(
            "Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    if dim < 0:
        dim += ndim

    # Get split sizes as a list
    if isinstance(split_sizes, torch.Tensor):
        split_sizes = split_sizes.tolist()

    # Handle SymInt[] (or any other iterable) - convert to Python list
    if hasattr(split_sizes, "__iter__"):
        split_sizes = list(split_sizes)

    dim_size = inp.shape[dim]
    if any(size < 0 for size in split_sizes):
        raise RuntimeError(
            "split_with_sizes expects split_sizes have only non-negative "
            f"entries, but got split_sizes={split_sizes}"
        )
    # Without this check a too-short list would silently drop the tail of
    # the input instead of reporting the mismatch.
    if sum(split_sizes) != dim_size:
        raise RuntimeError(
            "split_with_sizes expects split_sizes to sum exactly to "
            f"{dim_size} (input tensor's size at dimension {dim}), "
            f"but got split_sizes={split_sizes}"
        )

    # Standard block size for the element-wise copy kernel
    BLOCK_SIZE = 1024

    result = []
    offset = 0
    for size in split_sizes:
        if size == 0:
            # Zero-size split: nothing to copy, only the shape matters.
            out_shape = list(inp.shape)
            out_shape[dim] = 0
            result.append(
                torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
            )
            continue

        # Extract the split region using narrow (creates a view)
        split_view = inp.narrow(dim, offset, size)
        offset += size

        if not split_view.is_contiguous():
            # contiguous() on a non-contiguous view already materialises a
            # fresh tensor that does not alias ``inp``; copying it a second
            # time through the kernel would only double the memory traffic.
            result.append(split_view.contiguous())
            continue

        # The view is contiguous but still aliases ``inp``: copy it into a
        # new buffer with the Triton kernel.
        out = torch.empty_like(split_view)
        n_elements = out.numel()
        if n_elements > 0:
            grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
            split_copy_kernel[grid](
                out,
                split_view,
                n_elements,
                BLOCK_SIZE=BLOCK_SIZE,
            )

        result.append(out)

    return tuple(result)