Coverage for src/flag_gems/runtime/backend/_mthreads/ops/resolve_conj.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import triton_lang_extension as tle 

8 

9logger = logging.getLogger(__name__) 

10 

11 

@triton.jit
def resolve_conj_kernel_1d(
    x_real_ptr,  # pointer to the separately-stored real parts (float32)
    x_img_ptr,  # pointer to the separately-stored imaginary parts (float32)
    output_ptr,  # float32 view of the interleaved complex output buffer
    n_elements_total,  # number of complex elements
    is_conj: tl.constexpr,  # compile-time flag: negate the imaginary part
    BLOCK_SIZE: tl.constexpr,  # complex elements handled per program
):
    """Flat 1-D copy from split real/imag buffers into interleaved output.

    Each program handles one contiguous block of complex elements; when
    is_conj is set the imaginary part is negated on the way through.
    """
    pid = tle.program_id(axis=0)

    # Complex-element indices (not float32 indices) covered by this program.
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements_total

    # Inputs are stored as two separate planes, so no x2 scaling is needed.
    re = tl.load(x_real_ptr + idx, mask=in_bounds)
    im = tl.load(x_img_ptr + idx, mask=in_bounds)

    # Compile-time specialization: the negation vanishes when is_conj is 0.
    if is_conj:
        im = -im

    # Output uses the interleaved complex layout: real at 2*i, imag at 2*i + 1.
    tl.store(output_ptr + 2 * idx, re, mask=in_bounds)
    tl.store(output_ptr + 2 * idx + 1, im, mask=in_bounds)

47 

48 

@triton.jit
def resolve_conj_kernel_2d_strided(
    x_real_ptr,  # pointer to the separately-stored real parts (float32)
    x_img_ptr,  # pointer to the separately-stored imaginary parts (float32)
    output_ptr,  # float32 view of the interleaved complex output buffer
    n_rows,  # number of rows
    n_cols,  # number of columns
    stride_row,  # row stride, in complex elements
    stride_col,  # column stride, in complex elements
    is_conj: tl.constexpr,  # compile-time flag: negate the imaginary part
    BLOCK_SIZE: tl.constexpr,  # columns handled per program
):
    """Strided 2-D copy: one program per (row, column-block) tile.

    Reads the split real/imag planes through the input strides and writes
    the result back in the interleaved complex layout, negating the
    imaginary part when is_conj is set.
    """
    # Consistency fix: use tle.program_id for BOTH axes — the original mixed
    # tl.program_id on axis 1 with tle.program_id on axis 0, unlike every
    # other kernel in this file.
    pid_row = tle.program_id(axis=0)
    pid_col_block = tle.program_id(axis=1)

    # Column indices (complex-element units) covered by this program.
    col_start = pid_col_block * BLOCK_SIZE
    col_offsets = col_start + tl.arange(0, BLOCK_SIZE)

    # Mask out-of-range columns and (defensively) out-of-range rows.
    mask = (col_offsets < n_cols) & (pid_row < n_rows)

    # Input offset into the split planes (no x2: real/imag are separate).
    base_offset = pid_row * stride_row + col_offsets * stride_col

    real = tl.load(x_real_ptr + base_offset, mask=mask)
    imag = tl.load(x_img_ptr + base_offset, mask=mask)

    # Compile-time specialization: negate only when is_conj is set.
    if is_conj:
        imag = -imag

    # Output is interleaved: real at 2*k, imag at 2*k + 1.
    output_base_offset = base_offset * 2
    tl.store(output_ptr + output_base_offset, real, mask=mask)
    tl.store(output_ptr + output_base_offset + 1, imag, mask=mask)

91 

92 

@triton.jit
def resolve_conj_kernel_large_2d(
    x_real_ptr,  # pointer to the separately-stored real parts (float32)
    x_img_ptr,  # pointer to the separately-stored imaginary parts (float32)
    output_ptr,  # float32 view of the interleaved complex output buffer
    n_rows,  # number of rows
    n_cols,  # number of columns
    stride_row,  # row stride, in complex elements
    stride_col,  # column stride, in complex elements
    is_conj: tl.constexpr,  # compile-time flag: negate the imaginary part
    BLOCK_SIZE_ROWS: tl.constexpr,  # rows per tile
    BLOCK_SIZE_COLS: tl.constexpr,  # columns per tile
):
    """Tiled 2-D copy: each program handles a ROWS x COLS tile of complex elements."""
    tile_row = tle.program_id(axis=0)
    tile_col = tle.program_id(axis=1)

    # Row/column indices (complex-element units) for this tile.
    rows = tile_row * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
    cols = tile_col * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)

    # 2-D bounds mask for the tile.
    mask = (rows < n_rows)[:, None] & (cols < n_cols)[None, :]

    # Input offsets into the split planes (no x2: real/imag are separate).
    src = rows[:, None] * stride_row + cols[None, :] * stride_col

    re = tl.load(x_real_ptr + src, mask=mask)
    im = tl.load(x_img_ptr + src, mask=mask)

    # Compile-time specialization: negate only when is_conj is set.
    if is_conj:
        im = -im

    # Output is interleaved: real at 2*k, imag at 2*k + 1.
    dst = src * 2
    tl.store(output_ptr + dst, re, mask=mask)
    tl.store(output_ptr + dst + 1, im, mask=mask)

137 

138 

def resolve_conj_triton(x: torch.Tensor, is_conj: bool) -> torch.Tensor:
    """
    resolve_conj implemented with Triton, supporting arbitrary shapes.

    The kernels read the real/imaginary parts through separate pointers
    (x.real / x.imag, avoiding a reinterpreting x.view()) and write the
    result back in the standard interleaved complex layout.

    Args:
        x: Input tensor.
        is_conj: Whether the conjugate flag is set on the input.

    Returns:
        Resolved tensor (same structure as the input).
    """
    # Ensure the tensor lives on the MUSA device before launching kernels.
    if not x.is_musa:
        x = x.musa()

    # Real tensors are unaffected by conjugation: a plain copy suffices.
    # (This single check also covers the original's redundant
    # "not is_conj and not is_complex" branch.)
    if not x.is_complex():
        return x.clone()

    if x.dtype != torch.complex64:
        # Unsupported complex dtype (e.g. complex128): fall back to PyTorch.
        # NOTE(review): torch.conj only toggles the lazy conjugate bit rather
        # than materializing values the way resolve_conj does — confirm this
        # fallback matches the intended semantics.
        return torch.conj(x) if is_conj else x.clone()

    # Output keeps the original complex tensor structure.
    output = torch.empty_like(x)

    # Split float32 views of the real/imaginary parts (no reinterpreting view()).
    x_real = x.real
    x_img = x.imag

    # Interleaved float32 view of the output, used only as a kernel store target;
    # the returned tensor itself stays complex.
    output_view = output.view(torch.float32)

    shape = x.shape
    n_elements_total = x.numel()

    if len(shape) == 2:
        rows, cols = shape
        if rows * cols > 1000000:
            # Large 2-D tensors: strided kernel, one program row per tensor row.
            stride_row = x.stride(0)  # row stride, in complex elements
            stride_col = x.stride(1)  # column stride, in complex elements
            BLOCK_SIZE_COLS = 128
            grid = (rows, triton.cdiv(cols, BLOCK_SIZE_COLS))
            resolve_conj_kernel_2d_strided[grid](
                x_real,
                x_img,
                output_view,
                rows,
                cols,
                stride_row,
                stride_col,
                is_conj,
                BLOCK_SIZE_COLS,
            )
        else:
            # Small 2-D tensors: flat 1-D kernel over all elements.
            BLOCK_SIZE = 256
            grid = (triton.cdiv(n_elements_total, BLOCK_SIZE),)
            resolve_conj_kernel_1d[grid](
                x_real, x_img, output_view, n_elements_total, is_conj, BLOCK_SIZE
            )
    elif len(shape) == 3:
        # 3-D tensors: flatten and use the 1-D kernel. tl.arange requires a
        # power-of-two extent, so clamp to a power of two — the original
        # min(1024, numel) could produce an invalid non-power-of-two block
        # size whenever numel < 1024 and not itself a power of two.
        BLOCK_SIZE = min(1024, triton.next_power_of_2(n_elements_total))
        grid = (triton.cdiv(n_elements_total, BLOCK_SIZE),)
        resolve_conj_kernel_1d[grid](
            x_real, x_img, output_view, n_elements_total, is_conj, BLOCK_SIZE
        )
    else:
        # 1-D or other ranks: generic 1-D kernel, larger blocks for big inputs.
        BLOCK_SIZE = 1024 if n_elements_total > 1000000 else 256
        grid = (triton.cdiv(n_elements_total, BLOCK_SIZE),)
        resolve_conj_kernel_1d[grid](
            x_real, x_img, output_view, n_elements_total, is_conj, BLOCK_SIZE
        )

    # Output is still a complex tensor with the input's structure.
    return output

233 else: 

234 # Unsupported complex type, fallback to PyTorch implementation 

235 if is_conj: 

236 return torch.conj(x) 

237 else: 

238 return x.clone() 

239 

240 

def resolve_conj(A: torch.Tensor):
    """Materialize A's lazy conjugate flag, dispatching to Triton for 2-D/3-D inputs."""
    logger.debug("GEMS_MTHREADS RESOLVE_CONJ")

    # Nothing to resolve: hand the tensor back untouched.
    if not A.is_conj():
        return A

    # 2-D / 3-D inputs go through the Triton path.
    if len(A.shape) in (2, 3):
        return resolve_conj_triton(A, is_conj=True)

    # Other ranks: rebuild from the real part and the negated imaginary part.
    return torch.complex(A.real, A.imag.neg())