Coverage for src/flag_gems/experimental_ops/permute.py: 0%
64 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def permute_kernel(
    x_ptr,  # *Pointer* to input tensor
    y_ptr,  # *Pointer* to output tensor (contiguous layout)
    n_elements,  # total number of elements
    ndim,  # number of dimensions actually in use (runtime value)
    in_strides_perm_ptr,  # int64[ndim]: input strides permuted by dims
    out_shape_ptr,  # int64[ndim]: output shape
    out_postfix_ptr,  # int64[ndim]: product of sizes after each axis in output
    BLOCK_SIZE: tl.constexpr,  # output elements handled per program instance
    MAX_DIMS: tl.constexpr,  # compile-time unroll bound; must be >= ndim
):
    """Gather-copy kernel implementing a dimension permutation.

    Each program instance owns BLOCK_SIZE consecutive linear offsets of the
    contiguous output.  For every output offset it decodes the
    multi-dimensional output coordinate (via out_shape / out_postfix), maps
    that coordinate to a linear input offset using the permuted input
    strides, and copies the element.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    # Tail guard: only offsets below n_elements are real elements.
    mask = offs < n_elements

    # Index math in int64 so large tensors cannot overflow 32-bit offsets.
    offs64 = offs.to(tl.int64)
    in_index = tl.zeros([BLOCK_SIZE], dtype=tl.int64)

    # MAX_DIMS is constexpr, so this loop is unrolled at compile time.
    # Iterations with k >= ndim are neutralized by the masked loads'
    # `other=` defaults (step=1, size=1, stride=0): they add 0 to in_index.
    for k in range(MAX_DIMS):
        cond = k < ndim
        step_k = tl.load(out_postfix_ptr + k, mask=cond, other=1).to(tl.int64)
        size_k = tl.load(out_shape_ptr + k, mask=cond, other=1).to(tl.int64)
        stride_k = tl.load(in_strides_perm_ptr + k, mask=cond, other=0).to(tl.int64)
        # k-th output coordinate of each linear offset.
        coord_k = (offs64 // step_k) % size_k
        in_index += coord_k * stride_k

    vals = tl.load(x_ptr + in_index, mask=mask)
    tl.store(y_ptr + offs64, vals, mask=mask)
def permute(*args, **kwargs):
    """Permute the dimensions of a CUDA tensor via a Triton gather kernel.

    Mirrors ``torch.permute`` calling conventions::

        permute(x, (2, 0, 1))
        permute(x, 2, 0, 1)
        permute(x, dims=(2, 0, 1))

    Args:
        input (torch.Tensor): CUDA tensor to permute (first positional arg).
        dims: a permutation of ``range(input.dim())``; negative values are
            accepted in ``[-ndim, ndim - 1]`` as in PyTorch.

    Returns:
        torch.Tensor: a new *contiguous* tensor with permuted dimensions
        (unlike ``torch.permute``, which returns a view).

    Raises:
        TypeError: missing/非-tensor input, or dims given both positionally
            and as a keyword.
        ValueError: dims has the wrong length or repeats an axis.
        IndexError: a dim lies outside ``[-ndim, ndim - 1]``.
        AssertionError: input is not on a CUDA device.
    """
    if len(args) == 0:
        raise TypeError("permute() missing required argument: 'input'")

    x = args[0]
    if not isinstance(x, torch.Tensor):
        raise TypeError("First argument to permute must be a torch.Tensor")

    # Determine dims from args/kwargs.
    dims = kwargs.get("dims", None)
    if dims is None:
        # Either a single sequence argument or varargs of ints.
        if len(args) == 2 and isinstance(args[1], (list, tuple)):
            dims = args[1]
        else:
            dims = args[1:]
    elif len(args) > 1:
        # Bug fix: extra positional dims used to be silently ignored when
        # the keyword was also supplied; make the ambiguity an explicit error.
        raise TypeError("permute() got dims both positionally and as a keyword")
    dims = tuple(int(d) for d in dims)

    ndim = x.dim()
    if len(dims) != ndim:
        raise ValueError(
            f"permute(): dims length {len(dims)} does not match tensor ndim {ndim}"
        )

    # Bug fix: validate the range like torch.permute before normalizing.
    # The old `d % ndim` silently wrapped out-of-range values (e.g. dim=ndim),
    # accepting permutations PyTorch would reject.
    for d in dims:
        if not -ndim <= d < ndim:
            raise IndexError(
                f"permute(): dimension out of range (expected to be in "
                f"[{-ndim}, {ndim - 1}], but got {d})"
            )
    dims = tuple(d % ndim for d in dims)
    if len(set(dims)) != ndim:
        raise ValueError(
            "permute(): dims must be a permutation of [0..ndim-1] with no repeats"
        )

    if not x.is_cuda:
        raise AssertionError("Input tensor must be on CUDA device")

    device = x.device
    out_shape = tuple(x.shape[d] for d in dims)

    # Input strides (in elements) reordered by the permutation: output axis k
    # walks the input with stride in_strides_perm[k].
    in_strides = x.stride()
    in_strides_perm = tuple(in_strides[d] for d in dims)

    # postfix[k] = prod(out_shape[k+1:]) — the linear step between consecutive
    # values of output coordinate k in the flattened contiguous output.
    out_postfix = []
    p = 1
    for size in reversed(out_shape):
        out_postfix.append(p)
        p *= int(size)
    out_postfix.reverse()

    # Output tensor in contiguous layout.
    out = torch.empty(out_shape, dtype=x.dtype, device=device)

    n_elements = out.numel()
    if n_elements == 0:
        return out

    # Move metadata to device (int64).
    in_strides_perm_t = torch.tensor(in_strides_perm, dtype=torch.int64, device=device)
    out_shape_t = torch.tensor(out_shape, dtype=torch.int64, device=device)
    out_postfix_t = torch.tensor(out_postfix, dtype=torch.int64, device=device)

    BLOCK_SIZE = 1024
    # Bug fix: the unroll bound was previously computed but never passed —
    # the kernel always specialized with MAX_DIMS=16, unrolling dead
    # iterations for small ndim and silently corrupting results for
    # ndim > 16.  Specialize on the real rank instead (>= 1 for 0-d input).
    max_dims = max(1, ndim)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    permute_kernel[grid](
        x,
        out,
        n_elements,
        ndim,
        in_strides_perm_t,
        out_shape_t,
        out_postfix_t,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_DIMS=max_dims,
    )
    return out