Coverage for src/flag_gems/ops/resolve_conj.py: 13%
102 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
@triton.jit
def resolve_conj_kernel_1d(
    x_real_ptr,  # pointer to real parts (float32)
    x_img_ptr,  # pointer to imaginary parts (float32)
    output_ptr,  # pointer to interleaved float32 output buffer
    n_elements_total,  # number of complex elements
    is_conj: tl.constexpr,  # compile-time flag: negate the imaginary part
    BLOCK_SIZE: tl.constexpr,  # complex elements per program
):
    """Flat 1D copy/conjugate of a complex tensor.

    Reads real and imaginary parts from two separate buffers (indexed by
    complex-element offset) and writes them back interleaved: real at 2k,
    imaginary at 2k + 1. When ``is_conj`` is set, the imaginary part is
    negated on the way out.

    NOTE(review): the loads index ``x_real_ptr``/``x_img_ptr`` with unit
    stride, i.e. they assume densely packed real/imag buffers. Raw
    ``x.real``/``x.imag`` views of an interleaved complex tensor have
    stride 2 in float32 units — confirm callers pass packed buffers.
    """
    pid = tl.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements_total

    re = tl.load(x_real_ptr + idx, mask=in_bounds)
    im = tl.load(x_img_ptr + idx, mask=in_bounds)

    # The real part is written unchanged in either mode; only the
    # imaginary part depends on the constexpr conjugate flag.
    tl.store(output_ptr + 2 * idx, re, mask=in_bounds)
    if is_conj:
        im = -im
    tl.store(output_ptr + 2 * idx + 1, im, mask=in_bounds)
@triton.jit
def resolve_conj_kernel_2d_strided(
    x_real_ptr,  # pointer to real parts (float32)
    x_img_ptr,  # pointer to imaginary parts (float32)
    output_ptr,  # pointer to interleaved float32 output buffer
    n_rows,  # number of rows
    n_cols,  # number of columns
    stride_row,  # row stride, in complex-element units
    stride_col,  # column stride, in complex-element units
    is_conj: tl.constexpr,  # compile-time flag: negate the imaginary part
    BLOCK_SIZE: tl.constexpr,  # columns per program
):
    """Strided 2D copy/conjugate: one program per row on axis 0,
    BLOCK_SIZE columns per program on axis 1.

    Input offsets are computed from the caller-supplied strides (complex
    units); output offsets are doubled to address the interleaved float32
    layout (real first, imaginary second).

    NOTE(review): as with the 1D kernel, the input pointers are indexed
    as if real/imag were separate packed buffers — verify against the
    actual memory layout of what callers pass in.
    """
    row = tl.program_id(axis=0)
    col_block = tl.program_id(axis=1)

    cols = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = (cols < n_cols) & (row < n_rows)

    src = row * stride_row + cols * stride_col
    re = tl.load(x_real_ptr + src, mask=valid)
    im = tl.load(x_img_ptr + src, mask=valid)

    # Interleaved destination: float offset = 2 * complex offset.
    dst = src * 2
    tl.store(output_ptr + dst, re, mask=valid)
    if is_conj:
        im = -im
    tl.store(output_ptr + dst + 1, im, mask=valid)
@triton.jit
def resolve_conj_kernel_large_2d(
    x_real_ptr,  # pointer to real parts (float32)
    x_img_ptr,  # pointer to imaginary parts (float32)
    output_ptr,  # pointer to interleaved float32 output buffer
    n_rows,  # number of rows
    n_cols,  # number of columns
    stride_row,  # row stride, in complex-element units
    stride_col,  # column stride, in complex-element units
    is_conj: tl.constexpr,  # compile-time flag: negate the imaginary part
    BLOCK_SIZE_ROWS: tl.constexpr,  # tile height
    BLOCK_SIZE_COLS: tl.constexpr,  # tile width
):
    """Tiled 2D copy/conjugate: each program handles one
    BLOCK_SIZE_ROWS x BLOCK_SIZE_COLS tile of complex elements.

    Same input/output addressing conventions as
    ``resolve_conj_kernel_2d_strided`` (strided complex-unit input
    offsets, doubled interleaved output offsets).

    NOTE(review): this kernel is not launched by the dispatcher in this
    file — the large-2D path uses the strided kernel instead.
    """
    tile_r = tl.program_id(axis=0)
    tile_c = tl.program_id(axis=1)

    rows = tile_r * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
    cols = tile_c * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)

    row_ok = rows < n_rows
    col_ok = cols < n_cols
    valid = row_ok[:, None] & col_ok[None, :]

    src = rows[:, None] * stride_row + cols[None, :] * stride_col
    re = tl.load(x_real_ptr + src, mask=valid)
    im = tl.load(x_img_ptr + src, mask=valid)

    dst = src * 2
    tl.store(output_ptr + dst, re, mask=valid)
    if is_conj:
        im = -im
    tl.store(output_ptr + dst + 1, im, mask=valid)
def resolve_conj_triton(x: torch.Tensor, is_conj: bool) -> torch.Tensor:
    """Materialize (or copy) a complex tensor with Triton kernels.

    The output keeps the input's complex structure; the kernels write into
    a raw float32 view of the output in interleaved (real, imag) order.

    Args:
        x: Input tensor. Moved to CUDA if it is not already there.
        is_conj: When True, the imaginary part is negated in the output
            (i.e. the conjugate is materialized).

    Returns:
        A new tensor with the same structure as ``x``. Real tensors are
        returned as plain clones; complex dtypes other than complex64
        fall back to the PyTorch implementation.
    """
    if not x.is_cuda:
        x = x.cuda()

    # Real tensors carry no conjugate information: a clone is always
    # the correct answer, regardless of is_conj.
    if not x.is_complex():
        return x.clone()

    if x.dtype != torch.complex64:
        # Unsupported complex dtype (e.g. complex128): fall back to PyTorch.
        return torch.conj(x) if is_conj else x.clone()

    n_elements_total = x.numel()
    output = torch.empty_like(x)

    # Nothing to launch for empty tensors; also avoids BLOCK_SIZE == 0
    # (and the resulting division by zero in triton.cdiv) on the 3D path.
    if n_elements_total == 0:
        return output

    # Separate real/imaginary views of the input.
    # NOTE(review): x.real / x.imag are stride-2 views into the interleaved
    # complex storage, while the kernels index them with unit stride as if
    # they were packed buffers — confirm this layout assumption (or make
    # the views contiguous) before relying on these paths.
    x_real = x.real
    x_img = x.imag

    # float32 view of the output, used purely as a raw interleaved buffer;
    # the returned `output` keeps its complex structure.
    output_view = output.view(torch.float32)

    shape = x.shape

    if len(shape) == 2 and shape[0] * shape[1] > 1000000:
        # Large 2D tensors: row-per-program strided kernel.
        # NOTE(review): resolve_conj_kernel_large_2d (tiled variant) exists
        # but is intentionally not used here.
        rows, cols = shape
        BLOCK_SIZE_COLS = 128
        grid = (rows, triton.cdiv(cols, BLOCK_SIZE_COLS))
        resolve_conj_kernel_2d_strided[grid](
            x_real,
            x_img,
            output_view,
            rows,
            cols,
            x.stride(0),  # strides in complex-element units
            x.stride(1),
            is_conj,
            BLOCK_SIZE_COLS,
        )
    else:
        # All remaining shapes are handled by the flat 1D kernel; only the
        # block size differs per historical tuning.
        if len(shape) == 3:
            BLOCK_SIZE = min(1024, n_elements_total)
        elif len(shape) == 2:
            BLOCK_SIZE = 256
        else:
            BLOCK_SIZE = 1024 if n_elements_total > 1000000 else 256
        grid = (triton.cdiv(n_elements_total, BLOCK_SIZE),)
        resolve_conj_kernel_1d[grid](
            x_real, x_img, output_view, n_elements_total, is_conj, BLOCK_SIZE
        )

    return output
def resolve_conj(A: torch.Tensor):
    """Materialize the lazy conjugate bit of ``A``.

    Tensors without the conjugate flag are returned unchanged. 2D/3D
    conjugated tensors are resolved with the Triton path; every other
    rank falls back to reconstructing the tensor from its real part and
    the negated imaginary part.
    """
    logger.debug("GEMS RESOLVE_CONJ")
    if not A.is_conj():
        return A
    if len(A.shape) in (2, 3):
        return resolve_conj_triton(A, is_conj=True)
    return torch.complex(A.real, A.imag.neg())