# Coverage for src/flag_gems/ops/vdot.py: 53% (140 statements)
# coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8from flag_gems import runtime
9from flag_gems.utils import libentry
11logger = logging.getLogger(__name__)
@triton.jit
def compute_vdot(
    inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj
):
    # Reduce one block's contribution to vdot(inp, other) = sum(conj(inp) * other),
    # working on the *stored* real/imag components.  When a conj flag is set,
    # the stored data represents the conjugate of the logical value, so the
    # sign of that operand's imaginary part is flipped inside the product.
    #
    # # Given inp storage: [inp_real, inp_imag], other: [other_real, other_imag]
    #
    # # Case 1: inp_is_conj = False, other_is_conj = False
    # out_real = inp_real * other_real + inp_imag * other_imag
    # out_imag = inp_real * other_imag - inp_imag * other_real
    #
    # # Case 2: inp_is_conj = True, other_is_conj = False
    # out_real = inp_real * other_real - inp_imag * other_imag
    # out_imag = inp_real * other_imag + inp_imag * other_real
    #
    # # Case 3: inp_is_conj = False, other_is_conj = True
    # out_real = inp_real * other_real - inp_imag * other_imag
    # out_imag = -inp_real * other_imag - inp_imag * other_real
    #
    # # Case 4: inp_is_conj = True, other_is_conj = True
    # # (conj(conj(a)) * conj(b) = a * conj(b))
    # out_real = inp_real * other_real + inp_imag * other_imag
    # out_imag = -inp_real * other_imag + inp_imag * other_real
    if not inp_is_conj and not other_is_conj:  # Case 1
        out_real = tl.sum(inp_real * other_real + inp_imag * other_imag)
        out_imag = tl.sum(inp_real * other_imag - inp_imag * other_real)
    elif inp_is_conj and not other_is_conj:  # Case 2
        out_real = tl.sum(inp_real * other_real - inp_imag * other_imag)
        out_imag = tl.sum(inp_real * other_imag + inp_imag * other_real)
    elif not inp_is_conj and other_is_conj:  # Case 3
        out_real = tl.sum(inp_real * other_real - inp_imag * other_imag)
        out_imag = tl.sum(-inp_real * other_imag - inp_imag * other_real)
    else:  # Case 4
        out_real = tl.sum(inp_real * other_real + inp_imag * other_imag)
        out_imag = tl.sum(-inp_real * other_imag + inp_imag * other_real)

    return out_real, out_imag
# support old version triton which do not support tl.split
@libentry()
@triton.jit()
def vdot_kernel_complex(
    inp_ptr,
    other_ptr,
    out_ptr,
    n_elements,
    inp_is_conj: tl.constexpr,
    other_is_conj: tl.constexpr,
    inp_stride: tl.constexpr,
    other_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Per-program partial vdot over interleaved [real, imag] storage (the
    # torch.view_as_real layout produced by the caller).  n_elements counts
    # real scalars, so n_elements // 2 complex values are processed; each
    # program writes its (real, imag) partial sum to out_ptr[2 * pid] and
    # out_ptr[2 * pid + 1].
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)
    # Grid-stride loop: all programs together cover grid_stride complex
    # elements per iteration, so any n fits a capped launch size.
    grid_stride = num_progs * BLOCK_SIZE

    acc_real = tl.zeros([], dtype=tl.float32)
    acc_imag = tl.zeros([], dtype=tl.float32)

    for current_start in range(0, n_elements // 2, grid_stride):
        complex_idx = current_start + pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = complex_idx < n_elements // 2
        # inp_stride/other_stride come from the complex tensor (complex-element
        # units); real part of element i lives at 2*i*stride, imag at +1.
        real_offset = complex_idx * 2
        inp_real = tl.load(inp_ptr + real_offset * inp_stride, mask=mask, other=0.0)
        inp_imag = tl.load(inp_ptr + real_offset * inp_stride + 1, mask=mask, other=0.0)
        other_real = tl.load(
            other_ptr + real_offset * other_stride, mask=mask, other=0.0
        )
        other_imag = tl.load(
            other_ptr + real_offset * other_stride + 1, mask=mask, other=0.0
        )
        out_real, out_imag = compute_vdot(
            inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj
        )
        acc_real += out_real
        acc_imag += out_imag

    # Store this program's partial (real, imag) pair.
    temp_offset = pid * 2
    tl.store(out_ptr + temp_offset, acc_real)
    tl.store(out_ptr + temp_offset + 1, acc_imag)
@libentry()
@triton.jit()
def reduce_kernel_complex(input_ptr, out_ptr, n_blocks, BLOCK_SIZE: tl.constexpr):
    # Fold the n_blocks (real, imag) partial sums from vdot_kernel_complex
    # into a single complex result at out_ptr[0:2].  Intended for a
    # single-program launch with BLOCK_SIZE >= n_blocks so one masked load
    # covers all partials.
    pid = tl.program_id(0)
    base_offset = tl.arange(0, BLOCK_SIZE)
    mask = base_offset < n_blocks
    # Partials are interleaved: [real0, imag0, real1, imag1, ...]
    inp_real = tl.load(input_ptr + base_offset * 2, mask=mask, other=0.0)
    inp_imag = tl.load(input_ptr + base_offset * 2 + 1, mask=mask, other=0.0)
    final_out_real = tl.sum(inp_real)
    final_out_imag = tl.sum(inp_imag)
    if pid == 0:
        tl.store(out_ptr, final_out_real)
        tl.store(out_ptr + 1, final_out_imag)
# only support real number
@libentry()
@triton.heuristics(runtime.get_heuristic_config("vdot"))
@triton.jit()
def dot_kernel(
    inp_ptr,
    other_ptr,
    out_ptr,
    n_elements,
    inp_stride: tl.constexpr,
    other_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid-stride partial dot product for real dtypes; each program writes
    # one fp32 partial sum to out_ptr[pid], to be reduced by reduce_kernel.
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)
    grid_stride = num_progs * BLOCK_SIZE

    acc = tl.zeros([], dtype=tl.float32)

    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    for current_start in range(0, n_elements, grid_stride):
        cur_offsets = current_start + offsets
        mask = cur_offsets < n_elements
        # Upcast to fp32 so low-precision inputs do not lose accuracy
        # while accumulating.
        inp = tl.load(inp_ptr + inp_stride * cur_offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        other = tl.load(
            other_ptr + other_stride * cur_offsets, mask=mask, other=0.0
        ).to(tl.float32)
        acc += tl.sum(inp * other)

    tl.store(out_ptr + pid, acc)
@libentry()
@triton.jit()
def reduce_kernel(
    partial_sums_ptr,
    output_ptr,
    n_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    # Sum the n_blocks per-program partials from dot_kernel into a single
    # scalar.  Intended for a single-program launch with
    # BLOCK_SIZE >= n_blocks so one masked load covers everything.
    offset = tl.arange(0, BLOCK_SIZE)
    mask = offset < n_blocks
    partial_sums = tl.load(partial_sums_ptr + offset, mask=mask, other=0.0)
    final_sum = tl.sum(partial_sums)
    if tl.program_id(0) == 0:
        tl.store(output_ptr, final_sum)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("vdot"))
@triton.jit()
def dot_kernel_fp32(
    inp_ptr,
    other_ptr,
    out_ptr,
    n_elements,
    inp_stride: tl.constexpr,
    other_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Single-pass fp32 dot product: each program reduces its BLOCK_SIZE slice
    # and atomically adds the result into the scalar at out_ptr, so the caller
    # must zero-initialize the output buffer (vdot does, via torch.zeros).
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements
    # other=0.0 is required: tl.load leaves masked-off lanes *undefined* when
    # `other` is omitted, and undefined lanes would otherwise leak into tl.sum
    # for the tail block.
    inp = tl.load(inp_ptr + inp_stride * offset, mask=mask, other=0.0)
    other = tl.load(other_ptr + other_stride * offset, mask=mask, other=0.0)

    out = tl.sum(inp * other)
    tl.atomic_add(out_ptr, out)
def vdot(input: Tensor, other: Tensor):
    """Dot product of two 1-D tensors with ``torch.vdot`` semantics.

    For complex inputs the first argument is conjugated before the product;
    conjugation is handled via conj-bit flags passed to the kernel, so no
    data copy is made.  Real inputs use a plain dot product.

    Args:
        input: 1-D tensor; conjugated (for complex dtypes) before the product.
        other: 1-D tensor with the same dtype and length as ``input``.

    Returns:
        0-D tensor (complex for complex inputs, otherwise ``input.dtype``;
        the fp32 fast path returns fp32).
    """
    logger.debug("GEMS VDOT")

    assert (
        input.dtype == other.dtype
    ), f"Input tensors must have the same dtype. Got {input.dtype} and {other.dtype}."
    assert (
        input.ndim == 1 and other.ndim == 1
    ), f"Input tensors must be 1D. Got {input.ndim}D and {other.ndim}D."
    assert (
        input.size() == other.size()
    ), f"Input tensors must have the same size. Got {input.size()} and {other.size()}."

    inp = input
    inp_stride = inp.stride()[0]
    other_stride = other.stride()[0]

    if inp.is_complex():
        # Record and clear the conj bits: torch.view_as_real rejects tensors
        # with the conjugate bit set, so the sign flips are done in-kernel.
        inp_is_conj = False
        other_is_conj = False

        if inp.is_conj():
            inp_is_conj = True
            inp = inp.conj()

        if other.is_conj():
            other_is_conj = True
            other = other.conj()

        inp_real = torch.view_as_real(inp)
        other_real = torch.view_as_real(other)

        n_elements = inp_real.numel()  # real scalars: 2 per complex element
        n_complex = inp.numel()

        block_size = runtime.get_heuristic_config("vdot")["BLOCK_SIZE"](
            {"n_elements": n_elements}
        )
        num_blocks = triton.cdiv(n_complex, block_size)
        # Cap the launch; the kernel grid-strides over the remainder.
        grid_size = min(num_blocks, 1024)

        # Each program writes an interleaved (real, imag) pair at
        # out[2*pid] / out[2*pid + 1], so the buffer needs 2 * grid_size
        # slots (a grid_size-sized buffer would be written out of bounds).
        partial_real_sums = torch.empty(
            grid_size * 2, dtype=inp_real.dtype, device=inp.device
        )
        grid = (grid_size,)
        vdot_kernel_complex[grid](
            inp_real,
            other_real,
            partial_real_sums,
            n_elements=n_elements,
            inp_is_conj=inp_is_conj,
            other_is_conj=other_is_conj,
            inp_stride=inp_stride,
            other_stride=other_stride,
            BLOCK_SIZE=block_size,
        )
        # Second pass folds the per-program partials into one complex value.
        output_real = torch.empty(2, dtype=inp_real.dtype, device=inp.device)
        reduce_kernel_complex[(1,)](
            partial_real_sums,
            output_real,
            grid_size,
            BLOCK_SIZE=triton.next_power_of_2(grid_size),
        )
        return torch.view_as_complex(output_real)
    elif inp.dtype == torch.float32:
        # fp32 fast path: kernel atomically accumulates into a single scalar,
        # so the output must start at zero and no reduction pass is needed.
        output = torch.zeros([], dtype=torch.float32, device=inp.device)
        n_elements = inp.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        dot_kernel_fp32[grid](
            inp,
            other,
            output,
            n_elements=n_elements,
            inp_stride=inp_stride,
            other_stride=other_stride,
        )
        return output
    else:
        n_elements = inp.numel()
        block_size = runtime.get_heuristic_config("vdot")["BLOCK_SIZE"](
            {"n_elements": n_elements}
        )
        num_blocks = triton.cdiv(n_elements, block_size)
        grid_size = min(num_blocks, 1024)

        # Launch exactly grid_size programs: dot_kernel grid-strides, and
        # partial_sums has grid_size slots — launching num_blocks programs
        # would write past the buffer whenever num_blocks > 1024.
        grid = (grid_size,)
        partial_sums = torch.empty(grid_size, dtype=torch.float32, device=inp.device)
        dot_kernel[grid](
            inp,
            other,
            partial_sums,
            n_elements=n_elements,
            inp_stride=inp_stride,
            other_stride=other_stride,
            BLOCK_SIZE=block_size,
        )
        output = torch.empty([], dtype=input.dtype, device=inp.device)
        reduce_bs = min(triton.next_power_of_2(grid_size), 1024)
        # Reduce exactly the grid_size partials that were written (the
        # original passed num_blocks, reading past the buffer when capped).
        reduce_kernel[(1,)](
            partial_sums,
            output,
            grid_size,
            BLOCK_SIZE=reduce_bs,
        )
        return output