Coverage for src/flag_gems/ops/resolve_conj.py: 17%
103 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7import flag_gems
9logger = logging.getLogger(__name__)
@triton.jit
def resolve_conj_kernel_1d(
    x_real_ptr,  # real-part view pointer (float32); aliases the interleaved complex storage
    x_img_ptr,  # imaginary-part view pointer (float32); base = real base + 1 float
    output_ptr,  # output pointer (float32 view of the complex output, interleaved layout)
    n_elements_total,  # number of complex elements
    is_conj: tl.constexpr,  # True: negate the imaginary part (materialize conjugation)
    BLOCK_SIZE: tl.constexpr,  # number of complex elements per program
):
    """Copy (and optionally conjugate) a flat, contiguous run of complex64 values."""
    pid = tl.program_id(axis=0)

    # Complex-element indices handled by this program (not float32 indices).
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements_total

    # BUG FIX: torch's `.real`/`.imag` are *views* into the interleaved complex
    # storage, not compact copies, so complex element i lives at float32 offset
    # 2*i from each view's base pointer. The previous code loaded at `offsets`
    # directly, which read the real/imag halves of the wrong elements.
    input_offsets = 2 * offsets
    real = tl.load(x_real_ptr + input_offsets, mask=mask)
    imag = tl.load(x_img_ptr + input_offsets, mask=mask)

    # Output keeps the interleaved layout: real at 2*i, imag at 2*i + 1.
    tl.store(output_ptr + 2 * offsets, real, mask=mask)
    if is_conj:
        # Conjugation: real part unchanged, imaginary part negated.
        imag = -imag
    tl.store(output_ptr + 2 * offsets + 1, imag, mask=mask)
@triton.jit
def resolve_conj_kernel_2d_strided(
    x_real_ptr,  # real-part view pointer (float32); aliases the interleaved complex storage
    x_img_ptr,  # imaginary-part view pointer (float32); base = real base + 1 float
    output_ptr,  # output pointer (float32 view of the complex output, interleaved layout)
    n_rows,  # number of rows
    n_cols,  # number of columns
    stride_row,  # row stride, in complex elements
    stride_col,  # column stride, in complex elements
    is_conj: tl.constexpr,  # True: negate the imaginary part (materialize conjugation)
    BLOCK_SIZE: tl.constexpr,  # number of columns per program
):
    """Copy/conjugate one row-block of a strided 2-D complex64 tensor."""
    pid_row = tl.program_id(axis=0)
    pid_col_block = tl.program_id(axis=1)

    # Column indices (complex-element units) handled by this program.
    col_offsets = pid_col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = (col_offsets < n_cols) & (pid_row < n_rows)

    # Offset of each complex element, in complex-element units.
    base_offset = pid_row * stride_row + col_offsets * stride_col

    # BUG FIX: `.real`/`.imag` are views into the interleaved storage
    # (complex element k lives at float32 offset 2*k from each view's base),
    # so the loads must use the doubled offsets, just like the stores below.
    float_offset = base_offset * 2
    real = tl.load(x_real_ptr + float_offset, mask=mask)
    imag = tl.load(x_img_ptr + float_offset, mask=mask)

    # NOTE(review): stores reuse the input strides, which assumes the output
    # tensor has the same layout (torch.empty_like with preserve_format).
    tl.store(output_ptr + float_offset, real, mask=mask)
    if is_conj:
        imag = -imag
    tl.store(output_ptr + float_offset + 1, imag, mask=mask)
@triton.jit
def resolve_conj_kernel_large_2d(
    x_real_ptr,  # real-part view pointer (float32); aliases the interleaved complex storage
    x_img_ptr,  # imaginary-part view pointer (float32); base = real base + 1 float
    output_ptr,  # output pointer (float32 view of the complex output, interleaved layout)
    n_rows,  # number of rows
    n_cols,  # number of columns
    stride_row,  # row stride, in complex elements
    stride_col,  # column stride, in complex elements
    is_conj: tl.constexpr,  # True: negate the imaginary part (materialize conjugation)
    BLOCK_SIZE_ROWS: tl.constexpr,  # tile height (rows per program)
    BLOCK_SIZE_COLS: tl.constexpr,  # tile width (columns per program)
):
    """Tiled copy/conjugate of a strided 2-D complex64 tensor."""
    pid_row = tl.program_id(axis=0)
    pid_col = tl.program_id(axis=1)

    # Row/column indices (complex-element units) for this tile.
    row_offsets = pid_row * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
    col_offsets = pid_col * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)
    mask = (row_offsets[:, None] < n_rows) & (col_offsets[None, :] < n_cols)

    # Offset of each complex element, in complex-element units.
    base_offsets = row_offsets[:, None] * stride_row + col_offsets[None, :] * stride_col

    # BUG FIX: `.real`/`.imag` are views into the interleaved storage
    # (complex element k lives at float32 offset 2*k from each view's base),
    # so the loads must use the doubled offsets, just like the stores below.
    float_offsets = base_offsets * 2
    real = tl.load(x_real_ptr + float_offsets, mask=mask)
    imag = tl.load(x_img_ptr + float_offsets, mask=mask)

    # NOTE(review): stores reuse the input strides, which assumes the output
    # tensor has the same layout (torch.empty_like with preserve_format).
    tl.store(output_ptr + float_offsets, real, mask=mask)
    if is_conj:
        imag = -imag
    tl.store(output_ptr + float_offsets + 1, imag, mask=mask)
def resolve_conj_triton(x: torch.Tensor, is_conj: bool) -> torch.Tensor:
    """resolve_conj implemented with Triton, supporting arbitrary shapes.

    The kernels read the interleaved complex64 storage through the
    ``x.real`` / ``x.imag`` views and write the result in the same
    interleaved layout, so the output keeps the input's structure.

    Args:
        x: Input tensor (moved to the flag_gems device if necessary).
        is_conj: Whether a conjugate flag must be materialized.

    Returns:
        A new tensor with the same structure as the input.
    """
    # Ensure the tensor lives on the flag_gems device.
    if x.device.type != flag_gems.device:
        x = x.to(flag_gems.device)

    # Real tensors carry no conjugate bit to materialize: plain copy.
    # (Previously this was two identical early-return branches.)
    if not x.is_complex():
        return x.clone()

    if x.dtype != torch.complex64:
        # Unsupported complex dtype (e.g. complex128): fall back to PyTorch.
        return torch.conj(x) if is_conj else x.clone()

    # Output keeps the original complex structure; empty_like preserves
    # strides for dense tensors, which the kernels' offset math relies on.
    output = torch.empty_like(x)

    # Separate real/imaginary views (no copy, no x.view()) plus a float32
    # view of the output for the kernels to store into.
    x_real = x.real
    x_img = x.imag
    output_view = output.view(torch.float32)

    n_elements_total = x.numel()
    if n_elements_total == 0:
        # Nothing to launch; an empty grid would do no work anyway.
        return output

    shape = x.shape
    if len(shape) == 2 and shape[0] * shape[1] > 1000000:
        # Large 2-D tensors: one program row per tensor row, blocked columns.
        rows, cols = shape
        BLOCK_SIZE_COLS = 128
        grid = (rows, triton.cdiv(cols, BLOCK_SIZE_COLS))
        resolve_conj_kernel_2d_strided[grid](
            x_real,
            x_img,
            output_view,
            rows,
            cols,
            x.stride(0),  # row stride, in complex elements
            x.stride(1),  # column stride, in complex elements
            is_conj,
            BLOCK_SIZE_COLS,
        )
    else:
        # Small 2-D, 3-D, 1-D and other shapes: flat 1-D kernel.
        if len(shape) == 3:
            # BUG FIX: tl.arange requires a power-of-two extent, so clamp to a
            # power of two instead of the raw element count (min(1024, numel)
            # produced invalid block sizes for e.g. numel == 6).
            BLOCK_SIZE = min(1024, triton.next_power_of_2(n_elements_total))
        else:
            BLOCK_SIZE = 1024 if n_elements_total > 1000000 else 256
        grid = (triton.cdiv(n_elements_total, BLOCK_SIZE),)
        resolve_conj_kernel_1d[grid](
            x_real, x_img, output_view, n_elements_total, is_conj, BLOCK_SIZE
        )

    # Still a complex tensor with the input's structure.
    return output
def resolve_conj(A: torch.Tensor):
    """Materialize A's lazy conjugate bit (mirrors torch.Tensor.resolve_conj).

    If A carries the conjugate flag, return a new tensor with the
    conjugation applied to memory; otherwise return A itself (no copy).
    """
    logger.debug("GEMS RESOLVE_CONJ")
    if not A.is_conj():
        return A
    if len(A.shape) in (2, 3):
        return resolve_conj_triton(A, is_conj=True)
    # BUG FIX: for a conj-flagged tensor, A.imag already reflects the
    # conjugation (it is a negative view of the stored imaginary part), so
    # negating it again (`A.imag.neg()`) rebuilt the *unconjugated* input.
    # Combining A.real and A.imag directly materializes the conjugate.
    return torch.complex(A.real, A.imag)