Coverage for src/flag_gems/ops/resolve_conj.py: 16%

102 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

@triton.jit
def resolve_conj_kernel_1d(
    x_real_ptr,  # pointer to densely packed float32 real parts
    x_img_ptr,  # pointer to densely packed float32 imaginary parts
    output_ptr,  # float32 view of the complex output (interleaved re/im)
    n_elements_total,  # number of complex elements
    is_conj: tl.constexpr,  # negate the imaginary part when True
    BLOCK_SIZE: tl.constexpr,  # elements handled per program (power of 2)
):
    # Flat 1D launch: each program covers one contiguous block of complex
    # elements. NOTE: offsets index x_real_ptr / x_img_ptr as *dense* arrays —
    # the caller must pass contiguous real/imag buffers.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements_total

    real = tl.load(x_real_ptr + offsets, mask=mask)
    imag = tl.load(x_img_ptr + offsets, mask=mask)

    # is_conj is a compile-time constant, so this branch is resolved at
    # trace time; conjugation negates only the imaginary part.
    if is_conj:
        imag = -imag

    # Output is interleaved: real at even float32 slots, imag at odd ones.
    tl.store(output_ptr + 2 * offsets, real, mask=mask)
    tl.store(output_ptr + 2 * offsets + 1, imag, mask=mask)

45 

46 

@triton.jit
def resolve_conj_kernel_2d_strided(
    x_real_ptr,  # pointer to float32 real parts
    x_img_ptr,  # pointer to float32 imaginary parts
    output_ptr,  # float32 view of the complex output (interleaved re/im)
    n_rows,  # number of rows
    n_cols,  # number of columns
    stride_row,  # row stride of the real/imag buffers, in elements
    stride_col,  # column stride of the real/imag buffers, in elements
    is_conj: tl.constexpr,  # negate the imaginary part when True
    BLOCK_SIZE: tl.constexpr,  # columns handled per program (power of 2)
):
    # One program per (row, column-block) pair.
    row = tl.program_id(axis=0)
    col_block = tl.program_id(axis=1)

    cols = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = (cols < n_cols) & (row < n_rows)

    # Element offset into the real/imag buffers (strides are in elements).
    src = row * stride_row + cols * stride_col
    real = tl.load(x_real_ptr + src, mask=mask)
    imag = tl.load(x_img_ptr + src, mask=mask)

    # Compile-time branch: conjugation negates the imaginary part only.
    if is_conj:
        imag = -imag

    # Interleaved output layout: complex element k lives at float32 slots
    # 2k (real) and 2k + 1 (imag).
    dst = src * 2
    tl.store(output_ptr + dst, real, mask=mask)
    tl.store(output_ptr + dst + 1, imag, mask=mask)

89 

90 

@triton.jit
def resolve_conj_kernel_large_2d(
    x_real_ptr,  # pointer to float32 real parts
    x_img_ptr,  # pointer to float32 imaginary parts
    output_ptr,  # float32 view of the complex output (interleaved re/im)
    n_rows,  # number of rows
    n_cols,  # number of columns
    stride_row,  # row stride of the real/imag buffers, in elements
    stride_col,  # column stride of the real/imag buffers, in elements
    is_conj: tl.constexpr,  # negate the imaginary part when True
    BLOCK_SIZE_ROWS: tl.constexpr,  # tile height (power of 2)
    BLOCK_SIZE_COLS: tl.constexpr,  # tile width (power of 2)
):
    # 2D tiling: each program processes a BLOCK_SIZE_ROWS x BLOCK_SIZE_COLS
    # tile of complex elements.
    tile_r = tl.program_id(axis=0)
    tile_c = tl.program_id(axis=1)

    rows = tile_r * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
    cols = tile_c * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)
    mask = (rows < n_rows)[:, None] & (cols < n_cols)[None, :]

    # Element offsets into the real/imag buffers (strides in elements).
    src = rows[:, None] * stride_row + cols[None, :] * stride_col
    real = tl.load(x_real_ptr + src, mask=mask)
    imag = tl.load(x_img_ptr + src, mask=mask)

    # Compile-time branch: conjugation negates the imaginary part only.
    if is_conj:
        imag = -imag

    # Interleaved output: complex element k -> float32 slots 2k, 2k + 1.
    dst = src * 2
    tl.store(output_ptr + dst, real, mask=mask)
    tl.store(output_ptr + dst + 1, imag, mask=mask)

135 

136 

def resolve_conj_triton(x: torch.Tensor, is_conj: bool) -> torch.Tensor:
    """resolve_conj implemented with Triton, supporting arbitrary shapes.

    Materializes the (optionally conjugated) values of ``x`` into a fresh
    complex tensor whose conjugate bit is clear.

    Args:
        x: Input tensor (moved to CUDA if necessary).
        is_conj: Whether the conjugate flag is set, i.e. whether the
            imaginary part must be negated in the output.

    Returns:
        A new tensor with the same shape and dtype as ``x`` holding the
        resolved values. Real tensors are returned as plain clones.
    """
    # Triton kernels require device tensors.
    if not x.is_cuda:
        x = x.cuda()

    # Real tensors carry no conjugate bit: nothing to resolve.
    # (Collapsed the two redundant `not is_complex` early returns into one.)
    if not x.is_complex():
        return x.clone()

    if x.dtype != torch.complex64:
        # Unsupported complex dtype (e.g. complex128): fall back to PyTorch.
        # .real/.imag respect the conjugate bit (imag is a negated view), so
        # torch.complex materializes the logical values directly. The old
        # fallback `torch.conj(x)` merely toggled the bit and returned the
        # raw, un-conjugated values.
        return torch.complex(x.real, x.imag) if is_conj else x.clone()

    # Fresh contiguous output so its float32 view below is plain interleaved
    # (re, im) storage, matching the kernels' `2k / 2k + 1` write pattern.
    output = torch.empty_like(x, memory_format=torch.contiguous_format)

    # Strip the conjugate bit (if set) so .real/.imag expose the raw stored
    # values; the kernels apply the negation themselves when is_conj=True.
    base = x.conj() if x.is_conj() else x

    # .real/.imag are *strided* views into interleaved storage, but the
    # kernels index their pointers as dense float32 arrays — materialize
    # dense copies. (The old code passed the strided views straight through,
    # so the kernels read interleaved re/im values as if they were all-real.)
    x_real = base.real.contiguous()
    x_img = base.imag.contiguous()

    # Output still uses view() to get a float32 pointer for the kernels;
    # this does not change the structure of `output`.
    output_view = output.view(torch.float32)

    n_elements_total = x.numel()
    if n_elements_total == 0:
        # Avoid launching an empty grid.
        return output

    shape = x.shape
    if len(shape) == 2 and shape[0] * shape[1] > 1000000:
        # Row-parallel kernel for large 2D tensors.
        rows, cols = shape
        # x_real/x_img are contiguous, so pass their dense strides
        # (row stride == cols, col stride == 1), in float32 elements.
        BLOCK_SIZE_COLS = 128
        grid = (rows, triton.cdiv(cols, BLOCK_SIZE_COLS))
        resolve_conj_kernel_2d_strided[grid](
            x_real,
            x_img,
            output_view,
            rows,
            cols,
            x_real.stride(0),
            x_real.stride(1),
            is_conj,
            BLOCK_SIZE_COLS,
        )
    else:
        # Flat 1D launch for every other shape (small 2D, 3D, 1D, ND).
        # BLOCK_SIZE must be a power of two — tl.arange requires it — so it
        # is never derived from numel (the old `min(1024, numel)` could
        # produce e.g. 30 and break compilation).
        BLOCK_SIZE = 1024 if n_elements_total > 1000000 else 256
        grid = (triton.cdiv(n_elements_total, BLOCK_SIZE),)
        resolve_conj_kernel_1d[grid](
            x_real, x_img, output_view, n_elements_total, is_conj, BLOCK_SIZE
        )

    # Output is a plain complex tensor with the conjugate bit cleared.
    return output

237 

238 

def resolve_conj(A: torch.Tensor):
    """Materialize the conjugate bit of ``A`` (aten resolve_conj).

    Returns ``A`` unchanged when no conjugate bit is set. 2D/3D conjugated
    tensors go through the Triton path; other ranks are materialized with
    PyTorch ops.
    """
    # getLogger(__name__) returns the same module logger defined at the top
    # of this file; referenced directly so the function is self-contained.
    logging.getLogger(__name__).debug("GEMS RESOLVE_CONJ")
    if not A.is_conj():
        # No conjugate bit: resolve_conj is the identity.
        return A
    if A.dim() in (2, 3):
        return resolve_conj_triton(A, is_conj=True)
    # .real/.imag already respect the conjugate bit (imag is a negated view),
    # so torch.complex materializes the logical (conjugated) values directly.
    # The previous `A.imag.neg()` undid that built-in negation and returned
    # the *unconjugated* values.
    return torch.complex(A.real, A.imag)