Coverage report (coverage.py v7.6.9, created at 2026-03-09 01:57 +0800):
src/flag_gems/experimental_ops/permute.py — 64 statements, 0% covered.

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def permute_kernel(
    x_ptr,  # *Pointer* to input tensor (may be non-contiguous)
    y_ptr,  # *Pointer* to output tensor (written contiguously)
    n_elements,  # total number of output elements
    ndim,  # number of dimensions actually used
    in_strides_perm_ptr,  # int64[ndim]: input strides permuted by dims
    out_shape_ptr,  # int64[ndim]: output shape
    out_postfix_ptr,  # int64[ndim]: product of sizes after each axis in output
    BLOCK_SIZE: tl.constexpr,
    MAX_DIMS: tl.constexpr,  # compile-time unroll bound for the per-axis loop
):
    # Copy a permuted view of x into a contiguous output y.
    #
    # Each program handles BLOCK_SIZE consecutive linear offsets of the
    # output. For every output offset, the multi-dimensional coordinate is
    # recovered by mixed-radix decomposition (divide by the postfix product,
    # then modulo the axis size), and the linear input offset is the dot
    # product of that coordinate with the permuted input strides.
    #
    # NOTE(review): correctness requires ndim <= MAX_DIMS — axes beyond
    # MAX_DIMS are never visited by the unrolled loop, so the caller must
    # enforce this bound.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements  # guard the final partial block

    # 64-bit indexing so large tensors do not overflow 32-bit arithmetic.
    offs64 = offs.to(tl.int64)
    in_index = tl.zeros([BLOCK_SIZE], dtype=tl.int64)

    # Statically unrolled loop over at most MAX_DIMS axes. Axes at or past
    # ndim are neutralized by the masked loads' `other` values
    # (step=1, size=1, stride=0), so they contribute nothing to in_index.
    for k in range(MAX_DIMS):
        cond = k < ndim
        step_k = tl.load(out_postfix_ptr + k, mask=cond, other=1).to(tl.int64)
        size_k = tl.load(out_shape_ptr + k, mask=cond, other=1).to(tl.int64)
        stride_k = tl.load(in_strides_perm_ptr + k, mask=cond, other=0).to(tl.int64)
        coord_k = (offs64 // step_k) % size_k  # coordinate along output axis k
        in_index += coord_k * stride_k

    # Gather from the strided input, scatter contiguously into the output.
    vals = tl.load(x_ptr + in_index, mask=mask)
    tl.store(y_ptr + offs64, vals, mask=mask)

36 

37 

def permute(*args, **kwargs):
    """Return a contiguous copy of a CUDA tensor with its dimensions permuted.

    Supports the common ``torch.permute`` calling patterns:

        permute(x, (2, 0, 1))       # dims as a sequence
        permute(x, 2, 0, 1)         # dims as varargs
        permute(x, dims=(2, 0, 1))  # dims as a keyword

    Args:
        x: CUDA tensor to permute (first positional argument).
        dims: a permutation of ``[0, ndim)``; negative values wrap around
            as in PyTorch.

    Returns:
        A new contiguous CUDA tensor with shape ``[x.shape[d] for d in dims]``.

    Raises:
        TypeError: the input tensor is missing or not a torch.Tensor.
        ValueError: ``dims`` has the wrong length, contains out-of-range or
            repeated entries, or the tensor has more dimensions than the
            kernel's compile-time axis limit.
        AssertionError: the input tensor is not on a CUDA device.
    """
    # Parse arguments to support common PyTorch calling patterns.
    if len(args) == 0:
        raise TypeError("permute() missing required argument: 'input'")

    x = args[0]
    if not isinstance(x, torch.Tensor):
        raise TypeError("First argument to permute must be a torch.Tensor")

    # Determine dims from args/kwargs.
    dims = kwargs.get("dims", None)
    if dims is None:
        if len(args) == 2 and isinstance(args[1], (list, tuple)):
            # permute(x, (2, 0, 1)) / permute(x, [2, 0, 1])
            dims = args[1]
        else:
            # permute(x, 2, 0, 1): remaining positional args are the dims.
            dims = args[1:]
    dims = tuple(int(d) for d in dims)

    ndim = x.dim()
    if len(dims) != ndim:
        raise ValueError(
            f"permute(): dims length {len(dims)} does not match tensor ndim {ndim}"
        )

    # BUG FIX: the kernel statically unrolls its per-axis loop MAX_DIMS
    # times, so tensors with more axes would silently produce wrong results.
    # Fail loudly instead.
    MAX_DIMS = 16
    if ndim > MAX_DIMS:
        raise ValueError(
            f"permute(): tensors with more than {MAX_DIMS} dimensions are not supported"
        )

    # BUG FIX: validate range before normalizing — plain `d % ndim` would
    # silently accept out-of-range values such as d=5 for a 3-d tensor.
    for d in dims:
        if not -ndim <= d < ndim:
            raise ValueError(
                f"permute(): dimension {d} is out of range for tensor ndim {ndim}"
            )

    # Normalize negative dims and validate it is a true permutation.
    dims = tuple(d % ndim for d in dims)
    if len(set(dims)) != ndim:
        raise ValueError(
            "permute(): dims must be a permutation of [0..ndim-1] with no repeats"
        )

    if not x.is_cuda:
        raise AssertionError("Input tensor must be on CUDA device")

    device = x.device
    dtype = x.dtype

    in_shape = tuple(x.shape)
    out_shape = tuple(in_shape[d] for d in dims)

    # Input strides (in elements), reordered by the permutation so that
    # output axis k walks the input with stride in_strides_perm[k].
    in_strides = tuple(x.stride())
    in_strides_perm = tuple(in_strides[d] for d in dims)

    # Postfix products of the output shape: out_postfix[k] = prod(out_shape[k+1:]).
    out_postfix = []
    p = 1
    for size in reversed(out_shape):
        out_postfix.append(p)
        p *= int(size)
    out_postfix = list(reversed(out_postfix))

    # Output tensor with a contiguous layout.
    out = torch.empty(out_shape, dtype=dtype, device=device)

    n_elements = out.numel()
    if n_elements == 0:
        return out  # nothing to copy; skip the kernel launch

    # Ship the per-axis metadata to the device as int64 arrays.
    in_strides_perm_t = torch.tensor(in_strides_perm, dtype=torch.int64, device=device)
    out_shape_t = torch.tensor(out_shape, dtype=torch.int64, device=device)
    out_postfix_t = torch.tensor(out_postfix, dtype=torch.int64, device=device)

    BLOCK_SIZE = 1024
    # BUG FIX: unroll only as many axes as the tensor actually has (at
    # least one). The original computed this bound but then passed a
    # hard-coded 16 to the kernel, leaving the computed value dead and
    # forcing maximum unrolling regardless of ndim.
    max_dims = max(1, min(MAX_DIMS, ndim))

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    permute_kernel[grid](
        x,
        out,
        n_elements,
        ndim,
        in_strides_perm_t,
        out_shape_t,
        out_postfix_t,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_DIMS=max_dims,
    )
    return out