Coverage for src/flag_gems/ops/split_with_sizes_copy.py: 84%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import libentry
# Module-scoped logger, named after this module so log filtering by package works.
logger = logging.getLogger(name=__name__)
@libentry()
@triton.jit
def split_copy_kernel(
    out_ptr,
    inp_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Copy ``n_elements`` elements from ``inp_ptr`` to ``out_ptr``.

    Each program instance copies one BLOCK_SIZE-wide chunk of the flat element
    range; the final partial chunk is masked. Both buffers are addressed by
    flat element index, so the caller must pass contiguous tensors holding the
    same number of elements.
    """
    # Widen the program id to int64 before scaling: with int32 arithmetic,
    # pid * BLOCK_SIZE wraps around once a tensor exceeds 2**31 - 1 elements.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    data = tl.load(inp_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, data, mask=mask)
def split_with_sizes_copy(inp, split_sizes, dim=0):
    """
    Split ``inp`` along ``dim`` into chunks of the given sizes, returning copies.

    Unlike ``torch.split_with_sizes``, every returned tensor owns fresh,
    contiguous storage instead of being a view into ``inp``.

    Args:
        inp: Input tensor.
        split_sizes: Size of each chunk along ``dim``; an iterable of ints
            (e.g. a SymInt[] list) or an integer tensor.
        dim: Dimension to split along; negative values count from the end.

    Returns:
        A tuple with one tensor per entry of ``split_sizes``.

    Raises:
        RuntimeError: if any size is negative, or the sizes do not sum to
            ``inp.size(dim)`` (same contract as ``torch.split_with_sizes_copy``).
    """
    logger.debug("GEMS SPLIT_WITH_SIZES_COPY")

    if dim < 0:
        dim = dim + inp.ndim

    if isinstance(split_sizes, torch.Tensor):
        split_sizes = split_sizes.tolist()
    else:
        # Materialize SymInt[] / tuples / generators into a plain list so the
        # sizes can be validated and then iterated.
        split_sizes = list(split_sizes)

    # Validate up front: without this, sizes summing to less than the
    # dimension silently produced a truncated split instead of an error.
    dim_size = inp.size(dim)
    if any(size < 0 for size in split_sizes):
        raise RuntimeError(
            "split_with_sizes expects split_sizes have only non-negative "
            f"entries, but got split_sizes={split_sizes}"
        )
    if sum(split_sizes) != dim_size:
        raise RuntimeError(
            f"split_with_sizes expects split_sizes to sum exactly to {dim_size} "
            f"(input tensor's size at dimension {dim}), "
            f"but got split_sizes={split_sizes}"
        )

    # Standard block size for the element-wise copy kernel.
    BLOCK_SIZE = 1024

    result = []
    offset = 0
    for size in split_sizes:
        if size == 0:
            # Empty chunk: nothing to copy, just allocate the right shape.
            out_shape = list(inp.shape)
            out_shape[dim] = 0
            result.append(torch.empty(out_shape, dtype=inp.dtype, device=inp.device))
            continue

        # narrow() returns a view into inp's storage.
        split_view = inp.narrow(dim, offset, size)
        offset += size

        if not split_view.is_contiguous():
            # .contiguous() on a non-contiguous view allocates and fills a new
            # tensor, which is already a standalone contiguous copy; running
            # the copy kernel on it again would only duplicate memory traffic.
            result.append(split_view.contiguous())
            continue

        # A contiguous view still aliases inp, so copy it into fresh storage
        # with the flat (linear-index) Triton copy kernel.
        out = torch.empty_like(split_view)
        n_elements = out.numel()
        if n_elements > 0:
            grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
            split_copy_kernel[grid](
                out,
                split_view,
                n_elements,
                BLOCK_SIZE=BLOCK_SIZE,
            )
        result.append(out)

    return tuple(result)