Coverage for src/flag_gems/experimental_ops/fft_ifftshift.py: 0%

75 statements  

coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

import torch
import triton
import triton.language as tl



@triton.jit
def fft_ifftshift(
    in_ptr_u8,  # pointer to input tensor as bytes
    out_ptr_u8,  # pointer to output tensor as bytes
    sizes_ptr,  # int64[NDIMS]
    in_strides_ptr,  # int64[NDIMS], in elements
    out_strides_ptr,  # int64[NDIMS], in elements
    adds_ptr,  # int64[NDIMS], per-dim add = floor(size/2) if shifted else 0
    n_elements,  # total number of elements
    ELEMENT_SIZE: tl.constexpr,  # number of bytes per element
    NDIMS: tl.constexpr,  # number of dimensions
    BLOCK_SIZE: tl.constexpr,  # tile size
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    offs64 = offs.to(tl.int64)

    # Compute multi-dimensional indices from the linear index (row-major)
    tmp = offs64
    in_off_elems = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
    out_off_elems = tl.zeros([BLOCK_SIZE], dtype=tl.int64)

    # Iterate from the last dim to the first to peel off per-dim indices
    for d in range(NDIMS - 1, -1, -1):
        size_d = tl.load(sizes_ptr + d)  # scalar int64
        idx_d = tmp % size_d
        tmp = tmp // size_d

        add_d = tl.load(adds_ptr + d)  # scalar int64
        idx_in_d = idx_d + add_d
        # Wrap within size_d; since idx_d < size_d and add_d < size_d, this
        # subtracts size_d at most once
        idx_in_d = idx_in_d - (idx_in_d // size_d) * size_d

        in_stride_d = tl.load(in_strides_ptr + d)
        out_stride_d = tl.load(out_strides_ptr + d)

        in_off_elems += idx_in_d * in_stride_d
        out_off_elems += idx_d * out_stride_d

    in_byte_base = in_off_elems * ELEMENT_SIZE
    out_byte_base = out_off_elems * ELEMENT_SIZE

    # Copy ELEMENT_SIZE bytes per element, so the kernel works for any dtype
    for b in range(ELEMENT_SIZE):
        src_addr = in_ptr_u8 + in_byte_base + b
        dst_addr = out_ptr_u8 + out_byte_base + b
        val = tl.load(src_addr, mask=mask, other=0)
        tl.store(dst_addr, val, mask=mask)

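# Worked example of the shift arithmetic above (illustrative, not from the
# original source): for a dimension of size 5, add = 5 // 2 = 2, so
# out[i] = in[(i + 2) % 5], i.e. ifftshift maps [0, 1, 2, 3, 4] to
# [2, 3, 4, 0, 1] along that dim, matching torch.fft.ifftshift.
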

# Keep a handle to the kernel before defining the wrapper with the same name
fft_ifftshift_kernel = fft_ifftshift



def fft_ifftshift(*args, **kwargs):
    x = None
    dims = None

    # Parse input tensor
    if len(args) >= 1:
        x = args[0]
    else:
        # Probe kwargs one key at a time; chaining with `or` would call
        # Tensor.__bool__ and raise for tensors with more than one element
        for key in ("input", "self", "tensor"):
            x = kwargs.get(key)
            if x is not None:
                break
    if x is None:
        raise ValueError("fft_ifftshift expects at least one tensor argument as input.")


    # Parse dims (can be in args[1], or kwargs 'dim'/'dims')
    if len(args) >= 2:
        dims = args[1]
    else:
        dims = kwargs.get("dim", kwargs.get("dims", None))


    # Normalize dims
    if dims is None:
        dims_list = list(range(x.ndim))
    else:
        if isinstance(dims, int):
            dims_list = [dims]
        else:
            dims_list = list(dims)
        # Normalize negative dims
        dims_list = [(d + x.ndim) % x.ndim for d in dims_list]
        # Remove duplicates while preserving order
        seen = set()
        tmp = []
        for d in dims_list:
            if d not in seen:
                tmp.append(d)
                seen.add(d)
        dims_list = tmp
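    # Illustrative example (not in the original source): with x.ndim == 3,
    # dims = (-1, 0, -1) wraps to [2, 0, 2] and de-duplicates to [2, 0]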


    # Handle scalars or empty tensors quickly
    if x.ndim == 0 or x.numel() == 0:
        return x.clone()


    device = x.device
    # Tensor.view(dtype) below requires the last dimension to be contiguous,
    # so run the kernel on a contiguous copy when needed (.contiguous() is a
    # no-op for already-contiguous inputs); empty_like then also yields a
    # contiguous output
    x_src = x.contiguous()
    out = torch.empty_like(x_src)

    # Prepare metadata
    sizes = torch.tensor(list(x_src.shape), device=device, dtype=torch.int64)
    in_strides = torch.tensor(list(x_src.stride()), device=device, dtype=torch.int64)
    out_strides = torch.tensor(list(out.stride()), device=device, dtype=torch.int64)


    # Per-dimension add amount = floor(size/2) if the dim is shifted, else 0;
    # read sizes from x.shape instead of the device tensor to avoid one GPU
    # sync per dimension from .item()
    dim_set = set(dims_list)
    add_list = [(x.shape[d] // 2) if d in dim_set else 0 for d in range(x.ndim)]
    adds = torch.tensor(add_list, device=device, dtype=torch.int64)
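    # For example (illustrative): x.shape == (4, 5) with dims_list == [1]
    # yields add_list == [0, 2]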


    n_elements = x.numel()
    NDIMS = x.ndim
    ELEMENT_SIZE = x.element_size()

    # Use byte pointers by viewing as uint8 without changing storage
    x_u8 = x_src.view(torch.uint8)
    out_u8 = out.view(torch.uint8)
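    # For example (illustrative): a (4, 5) float32 tensor views as a
    # (4, 20) uint8 tensor sharing the same storage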


    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    BLOCK_SIZE = 1024

    fft_ifftshift_kernel[grid](
        x_u8,
        out_u8,
        sizes,
        in_strides,
        out_strides,
        adds,
        n_elements,
        ELEMENT_SIZE=ELEMENT_SIZE,
        NDIMS=NDIMS,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out
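

# Minimal usage sketch (not part of the original file; assumes a CUDA device
# and Triton are available): sanity-check the Triton path against
# torch.fft.ifftshift.
if __name__ == "__main__":
    if torch.cuda.is_available():
        x = torch.randn(4, 5, 6, device="cuda")
        got = fft_ifftshift(x, dim=(0, 2))
        want = torch.fft.ifftshift(x, dim=(0, 2))
        assert torch.equal(got, want)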