Coverage for src/flag_gems/experimental_ops/fft_ifftshift.py: 0%
75 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def fft_ifftshift(
    in_ptr_u8,  # pointer to input tensor as bytes (uint8 view of the input)
    out_ptr_u8,  # pointer to output tensor as bytes (uint8 view of the output)
    sizes_ptr,  # int64[NDIMS], size of each dimension
    in_strides_ptr,  # int64[NDIMS], in elements
    out_strides_ptr,  # int64[NDIMS], in elements
    adds_ptr,  # int64[NDIMS], per-dim add = floor(size/2) if shifted else 0
    n_elements,  # total number of elements
    ELEMENT_SIZE: tl.constexpr,  # number of bytes per element
    NDIMS: tl.constexpr,  # number of dimensions
    BLOCK_SIZE: tl.constexpr,  # tile size
):
    """Dtype-agnostic ifftshift copy kernel.

    For each output element at multi-index (i_0, ..., i_{NDIMS-1}), copies the
    input element at index ((i_d + adds[d]) mod sizes[d]) along every
    dimension d, moving ELEMENT_SIZE raw bytes per element. Operating on raw
    bytes lets one kernel serve any element dtype (the host passes uint8
    views and strides measured in elements).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements  # guard the ragged final block
    offs64 = offs.to(tl.int64)  # 64-bit math so large tensors don't overflow

    # Compute multi-dimensional indices from linear index (row-major)
    tmp = offs64
    in_off_elems = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
    out_off_elems = tl.zeros([BLOCK_SIZE], dtype=tl.int64)

    # Iterate from last dim to first to compute indices
    for d in range(NDIMS - 1, -1, -1):
        size_d = tl.load(sizes_ptr + d)  # scalar int64
        idx_d = tmp % size_d
        tmp = tmp // size_d

        add_d = tl.load(adds_ptr + d)  # scalar int64
        idx_in_d = idx_d + add_d
        # modulo to wrap within size_d (idx_in_d is non-negative here, so the
        # floor-division form equals a true mod)
        idx_in_d = idx_in_d - (idx_in_d // size_d) * size_d

        in_stride_d = tl.load(in_strides_ptr + d)
        out_stride_d = tl.load(out_strides_ptr + d)

        in_off_elems += idx_in_d * in_stride_d
        out_off_elems += idx_d * out_stride_d

    # Element offsets -> byte offsets into the uint8-viewed buffers.
    in_byte_base = in_off_elems * ELEMENT_SIZE
    out_byte_base = out_off_elems * ELEMENT_SIZE

    # Copy ELEMENT_SIZE bytes per element; ELEMENT_SIZE is a constexpr, so
    # the trip count is fixed at compile time.
    for b in range(ELEMENT_SIZE):
        src_addr = in_ptr_u8 + in_byte_base + b
        dst_addr = out_ptr_u8 + out_byte_base + b
        val = tl.load(src_addr, mask=mask, other=0)
        tl.store(dst_addr, val, mask=mask)
# Keep a handle to the Triton kernel before the Python wrapper below is
# defined, since the wrapper deliberately reuses the public name
# `fft_ifftshift` and would otherwise shadow the kernel.
fft_ifftshift_kernel = fft_ifftshift
def fft_ifftshift(*args, **kwargs):
    """Inverse FFT shift, mirroring ``torch.fft.ifftshift(input, dim=None)``.

    Rolls each selected dimension by ``-floor(size/2)`` (implemented as a
    gather with per-dimension offset ``floor(size/2)`` modulo the size) and
    returns a new tensor; the input is never modified.

    Accepts the tensor positionally or as ``input``/``self``/``tensor`` in
    kwargs, and the dimensions positionally or as ``dim``/``dims`` (an int, an
    iterable of ints, or None meaning all dimensions).

    Raises:
        ValueError: if no input tensor is supplied.
    """
    # Parse input tensor. NOTE: do not use `a or b` chains here — tensor
    # truthiness raises for multi-element tensors and is False for zero
    # scalars/empty tensors, which would silently drop a valid input.
    if len(args) >= 1:
        x = args[0]
    else:
        x = None
        for key in ("input", "self", "tensor"):
            val = kwargs.get(key)
            if val is not None:
                x = val
                break
    if x is None:
        raise ValueError("fft_ifftshift expects at least one tensor argument as input.")

    # Parse dims (can be in args[1], or kwargs 'dim'/'dims')
    if len(args) >= 2:
        dims = args[1]
    else:
        dims = kwargs.get("dim", kwargs.get("dims", None))

    # Handle scalars or empty tensors before normalizing dims: a shift is a
    # no-op for them, and `% x.ndim` below would divide by zero when ndim == 0.
    if x.ndim == 0 or x.numel() == 0:
        return x.clone()

    # Normalize dims to a duplicate-free list of non-negative axes.
    if dims is None:
        dims_list = list(range(x.ndim))
    else:
        dims_list = [dims] if isinstance(dims, int) else list(dims)
        # Wrap negative dims. (Out-of-range dims also wrap here rather than
        # raising — preserved from the original behavior.)
        dims_list = [(d + x.ndim) % x.ndim for d in dims_list]
        # Remove duplicates while preserving order.
        dims_list = list(dict.fromkeys(dims_list))

    device = x.device
    # The byte-level view below (`x.view(torch.uint8)`) requires a dense last
    # dimension; make the input contiguous first (no-op when it already is).
    x = x.contiguous()
    out = torch.empty_like(x)

    # Per-element (not per-byte) metadata for the kernel.
    sizes = torch.tensor(list(x.shape), device=device, dtype=torch.int64)
    in_strides = torch.tensor(list(x.stride()), device=device, dtype=torch.int64)
    out_strides = torch.tensor(list(out.stride()), device=device, dtype=torch.int64)

    # Per-dimension add amount = floor(size/2) if the dimension is shifted,
    # else 0. Computed from x.shape on the host to avoid one device sync per
    # dimension (sizes[d].item() would force a transfer each time).
    shifted = set(dims_list)
    add_list = [x.shape[d] // 2 if d in shifted else 0 for d in range(x.ndim)]
    adds = torch.tensor(add_list, device=device, dtype=torch.int64)

    n_elements = x.numel()
    NDIMS = x.ndim
    ELEMENT_SIZE = x.element_size()

    # Reinterpret both tensors as raw bytes (same storage, no copy) so the
    # kernel can move arbitrary dtypes.
    x_u8 = x.view(torch.uint8)
    out_u8 = out.view(torch.uint8)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    fft_ifftshift_kernel[grid](
        x_u8,
        out_u8,
        sizes,
        in_strides,
        out_strides,
        adds,
        n_elements,
        ELEMENT_SIZE=ELEMENT_SIZE,
        NDIMS=NDIMS,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out