Coverage for src/flag_gems/experimental_ops/permute_copy.py: 0%
74 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
3import triton.language as tl
6@triton.jit
7def permute_copy_kernel(
8 x_ptr, # *Pointer* to input tensor data
9 y_ptr, # *Pointer* to output tensor data
10 numel, # total number of elements
11 out_shape_ptr, # int64[N] sizes of output dimensions
12 in_strides_ptr, # int64[N] input strides (in elements)
13 out_strides_ptr, # int64[N] output strides (in elements)
14 perm_ptr, # int64[N] mapping from output dim -> input dim
15 NDIMS: tl.constexpr, # number of dimensions
16 BLOCK_SIZE: tl.constexpr,
17):
18 pid = tl.program_id(axis=0)
19 block_start = pid * BLOCK_SIZE
20 off = block_start + tl.arange(0, BLOCK_SIZE)
21 mask = off < numel
23 # Prepare offsets
24 tmp = off.to(tl.int64)
25 in_off = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
26 out_off = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
28 # Decompose linear index into multi-index over output shape
29 # and accumulate input/output offsets using strides.
30 # Iterate from last dim to first for divmod-based digit extraction.
31 for rev_i in range(NDIMS):
32 i = NDIMS - 1 - rev_i
33 size_i = tl.load(out_shape_ptr + i) # scalar broadcasted to vector
34 # Avoid div by zero if size_i could be 0 (numel==0 covered by mask; size 0 dims produce numel 0)
35 size_i = tl.where(size_i == 0, 1, size_i)
36 idx_i = tmp % size_i
37 tmp = tmp // size_i
39 out_stride_i = tl.load(out_strides_ptr + i)
40 perm_i = tl.load(perm_ptr + i)
41 in_stride_axis = tl.load(in_strides_ptr + perm_i)
43 out_off += idx_i * out_stride_i
44 in_off += idx_i * in_stride_axis
46 x = tl.load(x_ptr + in_off, mask=mask, other=0)
47 tl.store(y_ptr + out_off, x, mask=mask)
50def _normalize_dims(dims, ndim):
51 if isinstance(dims, torch.Tensor):
52 dims = dims.tolist()
53 dims = list(dims)
54 if len(dims) != ndim:
55 raise ValueError(f"dims length {len(dims)} must equal tensor ndim {ndim}")
56 norm = []
57 for d in dims:
58 if d < 0:
59 d += ndim
60 if not (0 <= d < ndim):
61 raise ValueError(f"dimension out of range: {d}")
62 norm.append(d)
63 if sorted(norm) != list(range(ndim)):
64 raise ValueError(f"dims must be a permutation of [0..{ndim - 1}], got {norm}")
65 return norm
68def _launch_permute_copy(x: torch.Tensor, dims, out: torch.Tensor = None):
69 assert x.is_cuda, "Input tensor must be on CUDA device for Triton kernels."
70 dims = _normalize_dims(dims, x.dim())
71 out_shape = [x.size(d) for d in dims]
72 n_elements = int(
73 torch.tensor(out_shape, dtype=torch.int64).prod().item()
74 if len(out_shape) > 0
75 else 1
76 )
78 if out is None:
79 out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
80 else:
81 if not out.is_cuda:
82 raise ValueError("Output tensor must be on CUDA device.")
83 if tuple(out.shape) != tuple(out_shape):
84 raise ValueError(
85 f"Output shape {tuple(out.shape)} does not match expected {tuple(out_shape)}."
86 )
87 if out.dtype != x.dtype:
88 raise ValueError(
89 f"Output dtype {out.dtype} must match input dtype {x.dtype}."
90 )
91 if out.device != x.device:
92 raise ValueError("Input and output must be on the same device.")
94 # Early exit for zero elements
95 if n_elements == 0:
96 return out
98 # Prepare metadata tensors on device (int64)
99 NDIMS = x.dim()
100 # Handle 0-dim tensors
101 if NDIMS == 0:
102 # trivial copy
103 out.copy_(x)
104 return out
106 out_shape_t = torch.tensor(out_shape, device=x.device, dtype=torch.int64)
107 in_strides_t = torch.tensor(x.stride(), device=x.device, dtype=torch.int64)
108 out_strides_t = torch.tensor(out.stride(), device=x.device, dtype=torch.int64)
109 perm_t = torch.tensor(dims, device=x.device, dtype=torch.int64)
111 # Launch configuration
112 BLOCK_SIZE = 1024
113 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
115 permute_copy_kernel[grid](
116 x,
117 out,
118 n_elements,
119 out_shape_t,
120 in_strides_t,
121 out_strides_t,
122 perm_t,
123 NDIMS=NDIMS,
124 BLOCK_SIZE=BLOCK_SIZE,
125 )
126 return out
129def permute_copy(self: torch.Tensor, dims):
130 return _launch_permute_copy(self, dims, out=None)
133def permute_copy_out(self: torch.Tensor, dims, out: torch.Tensor):
134 return _launch_permute_copy(self, dims, out=out)