Coverage for src/flag_gems/experimental_ops/permute_copy.py: 0%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

import math

import torch

import triton
import triton.language as tl

4 

5 

@triton.jit
def permute_copy_kernel(
    x_ptr,  # *Pointer* to input tensor data
    y_ptr,  # *Pointer* to output tensor data
    numel,  # total number of elements
    out_shape_ptr,  # int64[N] sizes of output dimensions
    in_strides_ptr,  # int64[N] input strides (in elements)
    out_strides_ptr,  # int64[N] output strides (in elements)
    perm_ptr,  # int64[N] mapping from output dim -> input dim
    NDIMS: tl.constexpr,  # number of dimensions
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles one BLOCK_SIZE-wide slice of the
    # flattened output: it recovers each flat index's multi-dimensional
    # position over the output shape, then turns that position into element
    # offsets into both tensors via their respective strides.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    off = block_start + tl.arange(0, BLOCK_SIZE)
    mask = off < numel  # guards the final, possibly partial, block

    # Prepare offsets
    tmp = off.to(tl.int64)  # int64 index math avoids overflow on large tensors
    in_off = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
    out_off = tl.zeros([BLOCK_SIZE], dtype=tl.int64)

    # Decompose linear index into multi-index over output shape
    # and accumulate input/output offsets using strides.
    # Iterate from last dim to first for divmod-based digit extraction.
    for rev_i in range(NDIMS):
        i = NDIMS - 1 - rev_i
        size_i = tl.load(out_shape_ptr + i)  # scalar broadcasted to vector
        # Avoid div by zero if size_i could be 0 (numel==0 covered by mask; size 0 dims produce numel 0)
        size_i = tl.where(size_i == 0, 1, size_i)
        idx_i = tmp % size_i  # coordinate along output dim i
        tmp = tmp // size_i  # remaining digits for dims 0..i-1

        out_stride_i = tl.load(out_strides_ptr + i)
        # Output dim i is fed from input dim perm[i], so the input offset
        # uses the stride of that permuted axis.
        perm_i = tl.load(perm_ptr + i)
        in_stride_axis = tl.load(in_strides_ptr + perm_i)

        out_off += idx_i * out_stride_i
        in_off += idx_i * in_stride_axis

    x = tl.load(x_ptr + in_off, mask=mask, other=0)
    tl.store(y_ptr + out_off, x, mask=mask)

48 

49 

50def _normalize_dims(dims, ndim): 

51 if isinstance(dims, torch.Tensor): 

52 dims = dims.tolist() 

53 dims = list(dims) 

54 if len(dims) != ndim: 

55 raise ValueError(f"dims length {len(dims)} must equal tensor ndim {ndim}") 

56 norm = [] 

57 for d in dims: 

58 if d < 0: 

59 d += ndim 

60 if not (0 <= d < ndim): 

61 raise ValueError(f"dimension out of range: {d}") 

62 norm.append(d) 

63 if sorted(norm) != list(range(ndim)): 

64 raise ValueError(f"dims must be a permutation of [0..{ndim - 1}], got {norm}") 

65 return norm 

66 

67 

68def _launch_permute_copy(x: torch.Tensor, dims, out: torch.Tensor = None): 

69 assert x.is_cuda, "Input tensor must be on CUDA device for Triton kernels." 

70 dims = _normalize_dims(dims, x.dim()) 

71 out_shape = [x.size(d) for d in dims] 

72 n_elements = int( 

73 torch.tensor(out_shape, dtype=torch.int64).prod().item() 

74 if len(out_shape) > 0 

75 else 1 

76 ) 

77 

78 if out is None: 

79 out = torch.empty(out_shape, device=x.device, dtype=x.dtype) 

80 else: 

81 if not out.is_cuda: 

82 raise ValueError("Output tensor must be on CUDA device.") 

83 if tuple(out.shape) != tuple(out_shape): 

84 raise ValueError( 

85 f"Output shape {tuple(out.shape)} does not match expected {tuple(out_shape)}." 

86 ) 

87 if out.dtype != x.dtype: 

88 raise ValueError( 

89 f"Output dtype {out.dtype} must match input dtype {x.dtype}." 

90 ) 

91 if out.device != x.device: 

92 raise ValueError("Input and output must be on the same device.") 

93 

94 # Early exit for zero elements 

95 if n_elements == 0: 

96 return out 

97 

98 # Prepare metadata tensors on device (int64) 

99 NDIMS = x.dim() 

100 # Handle 0-dim tensors 

101 if NDIMS == 0: 

102 # trivial copy 

103 out.copy_(x) 

104 return out 

105 

106 out_shape_t = torch.tensor(out_shape, device=x.device, dtype=torch.int64) 

107 in_strides_t = torch.tensor(x.stride(), device=x.device, dtype=torch.int64) 

108 out_strides_t = torch.tensor(out.stride(), device=x.device, dtype=torch.int64) 

109 perm_t = torch.tensor(dims, device=x.device, dtype=torch.int64) 

110 

111 # Launch configuration 

112 BLOCK_SIZE = 1024 

113 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

114 

115 permute_copy_kernel[grid]( 

116 x, 

117 out, 

118 n_elements, 

119 out_shape_t, 

120 in_strides_t, 

121 out_strides_t, 

122 perm_t, 

123 NDIMS=NDIMS, 

124 BLOCK_SIZE=BLOCK_SIZE, 

125 ) 

126 return out 

127 

128 

def permute_copy(self: torch.Tensor, dims):
    """Return a new tensor holding ``self`` permuted by ``dims``."""
    return _launch_permute_copy(self, dims)

131 

132 

def permute_copy_out(self: torch.Tensor, dims, out: torch.Tensor):
    """Write ``self`` permuted by ``dims`` into ``out`` and return ``out``."""
    result = _launch_permute_copy(self, dims, out=out)
    return result