Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/vdot.py: 0%
100 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8# from flag_gems import runtime
9from flag_gems.utils import libentry
# Module logger nested under the package-wide "flag_gems" logger.
# lstrip(".") strips a leading dot from __name__ — presumably a guard for
# relative-import loading; TODO confirm it is ever needed.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def compute_vdot(
    inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj
):
    # Reduce a block of complex values to the (real, imag) scalar parts of
    # sum(conj(inp) * other).  The *_is_conj flags (constexpr booleans) mean
    # the *logical* tensor is the conjugate of what is physically stored in
    # the *_real / *_imag lanes, so each flag combination flips signs.
    #
    # With a = inp_real, b = inp_imag, c = other_real, d = other_imag:
    #
    # Case 1: neither conjugated — logical inp = a+bi, other = c+di:
    #   conj(a+bi)*(c+di) = (ac + bd) + (ad - bc)i
    # Case 2: inp conjugated — logical inp = a-bi:
    #   conj(a-bi)*(c+di) = (ac - bd) + (ad + bc)i
    # Case 3: other conjugated — logical other = c-di:
    #   conj(a+bi)*(c-di) = (ac - bd) + (-ad - bc)i
    # Case 4: both conjugated:
    #   conj(a-bi)*(c-di) = (ac + bd) + (-ad + bc)i
    if not inp_is_conj and not other_is_conj:  # Case 1
        out_real = tl.sum(inp_real * other_real + inp_imag * other_imag)
        out_imag = tl.sum(inp_real * other_imag - inp_imag * other_real)
    elif inp_is_conj and not other_is_conj:  # Case 2
        out_real = tl.sum(inp_real * other_real - inp_imag * other_imag)
        out_imag = tl.sum(inp_real * other_imag + inp_imag * other_real)
    elif not inp_is_conj and other_is_conj:  # Case 3
        out_real = tl.sum(inp_real * other_real - inp_imag * other_imag)
        out_imag = tl.sum(-inp_real * other_imag - inp_imag * other_real)
    else:  # Case 4
        out_real = tl.sum(inp_real * other_real + inp_imag * other_imag)
        out_imag = tl.sum(-inp_real * other_imag + inp_imag * other_real)

    return out_real, out_imag
def vdot_kernel_heur_block_size(args):
    """Choose BLOCK_SIZE for vdot_kernel_complex.

    ``args["n_elements"]`` is the flattened real-scalar count (2x the complex
    length).  Small inputs get a single block covering everything; large
    inputs are split across roughly 12 programs, rounded up to a power of two.

    NOTE(review): the small-input branch can return a non-power-of-two, which
    stock Triton's ``tl.arange`` rejects — presumably the kunlunxin fork
    accepts it; confirm against the backend compiler.
    """
    n = args["n_elements"]
    if n < 8192:
        return n
    return triton.next_power_of_2(triton.cdiv(n, 12))
# support old version triton which do not support tl.split
@libentry()
# @triton.heuristics(runtime.get_heuristic_config("vdot"))
@triton.heuristics(
    values={
        "BLOCK_SIZE": vdot_kernel_heur_block_size,
    },
)
@triton.jit()
def vdot_kernel_complex(
    inp_ptr,            # interleaved [real, imag] storage of the input vector
    other_ptr,          # interleaved [real, imag] storage of the other vector
    out_ptr,            # 2-element output buffer: [real, imag] of the result
    n_elements: tl.constexpr,     # flattened real-scalar count (2x complex length)
    inp_is_conj: tl.constexpr,    # logical inp is conj of the stored values
    other_is_conj: tl.constexpr,  # logical other is conj of the stored values
    inp_stride: tl.constexpr,     # stride of the complex view, in real scalars
    other_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,     # complex elements handled per program
):
    pid = tl.program_id(0)

    # Even offsets addressing the real part of each complex element.
    # NOTE(review): the trailing `+ tl.arange(0, 1)` adds a zero tensor —
    # presumably a workaround to force a tensor-typed offset on this
    # backend; confirm it is still needed.
    base_offset = 2 * pid * BLOCK_SIZE + 2 * tl.arange(0, BLOCK_SIZE) + tl.arange(0, 1)

    inp_real_offset = inp_stride * base_offset
    inp_imag_offset = inp_real_offset + 1  # imaginary part sits next to real

    other_real_offset = other_stride * base_offset
    other_imag_offset = other_real_offset + 1

    mask = base_offset < n_elements

    inp_real = tl.load(inp_ptr + inp_real_offset, mask=mask)
    inp_imag = tl.load(inp_ptr + inp_imag_offset, mask=mask)

    other_real = tl.load(other_ptr + other_real_offset, mask=mask)
    other_imag = tl.load(other_ptr + other_imag_offset, mask=mask)

    # Zero out masked lanes so they do not pollute the reduction.
    inp_real = tl.where(mask, inp_real, 0.0)
    inp_imag = tl.where(mask, inp_imag, 0.0)
    other_real = tl.where(mask, other_real, 0.0)
    other_imag = tl.where(mask, other_imag, 0.0)

    # Compute based on conjugate flags
    out_real, out_imag = compute_vdot(
        inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj
    )

    # Plain (non-atomic) store of this program's partial sum.
    # NOTE(review): with more than one program this looks like a write race —
    # presumably the isCLOSE_TTXPU_O_ATOMIC_SIM launch option passed by the
    # wrapper makes these stores accumulate on this backend; confirm.
    tl.store(out_ptr, out_real)
    tl.store(out_ptr + 1, out_imag)
def dot_kernel_heur_block_size(args):
    """Choose BLOCK_SIZE for the real-valued dot_kernel.

    Odd lengths get a single power-of-two block covering the whole vector;
    small even lengths get one block of exactly the element count; large
    inputs are split across roughly 12 programs rounded up to a power of two.

    NOTE(review): the small-even branch can return a non-power-of-two, which
    stock Triton's ``tl.arange`` rejects — presumably the kunlunxin fork
    accepts it; confirm against the backend compiler.
    """
    n = args["n_elements"]
    if n % 2 != 0:
        return triton.next_power_of_2(n)
    if n < 8192:
        return n
    return triton.next_power_of_2(triton.cdiv(n, 12))
# only support real number
@libentry()
# @triton.heuristics(runtime.get_heuristic_config("vdot"))
@triton.heuristics(
    values={
        "BLOCK_SIZE": dot_kernel_heur_block_size,
    },
)
@triton.jit()
def dot_kernel(
    inp_ptr,        # pointer to the first real input vector
    other_ptr,      # pointer to the second real input vector
    out_ptr,        # pointer to a scalar output buffer
    n_elements: tl.constexpr,   # length of both vectors
    inp_stride: tl.constexpr,   # element stride of inp
    other_stride: tl.constexpr, # element stride of other
    BLOCK_SIZE: tl.constexpr,   # elements handled per program
):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Accumulate in float32 regardless of the storage dtype for accuracy.
    inp = tl.load(inp_ptr + inp_stride * offset, mask=mask).to(tl.float32)
    inp = tl.where(mask, inp, 0.0)  # zero masked lanes before reducing
    other = tl.load(other_ptr + other_stride * offset, mask=mask).to(tl.float32)
    other = tl.where(mask, other, 0.0)

    out = tl.sum(inp * other)
    # Plain (non-atomic) store of this program's partial sum.
    # NOTE(review): with more than one program this looks like a write race —
    # presumably the isCLOSE_TTXPU_O_ATOMIC_SIM launch option passed by the
    # wrapper makes these stores accumulate on this backend; confirm.
    tl.store(out_ptr, out)
def vdot(input: Tensor, other: Tensor) -> Tensor:
    """Dot product of two 1-D tensors, conjugating ``input`` when complex.

    Mirrors ``torch.vdot``: for complex inputs returns sum(conj(input) * other)
    as a 0-d complex tensor; for real inputs returns the plain dot product in
    the input dtype.  Both tensors must be 1-D, same dtype, same length.
    Raises AssertionError on mismatched dtype, rank, or size.
    """
    logger.debug("GEMS VDOT")

    assert (
        input.dtype == other.dtype
    ), f"Input tensors must have the same dtype. Got {input.dtype} and {other.dtype}."
    assert (
        input.ndim == 1 and other.ndim == 1
    ), f"Input tensors must be 1D. Got {input.ndim}D and {other.ndim}D."
    assert (
        input.size() == other.size()
    ), f"Input tensors must have the same size. Got {input.size()} and {other.size()}."

    inp = input
    # Element strides of the 1-D (complex) views; captured before any conj(),
    # which only flips a flag and leaves strides unchanged.
    inp_stride = inp.stride()[0]
    other_stride = other.stride()[0]

    if inp.is_complex():
        # Record the conj bits, then materialize plain storage so that
        # view_as_real below sees the raw [real, imag] lanes; the kernel
        # applies the conjugation sign flips itself (see compute_vdot).
        inp_is_conj = False
        other_is_conj = False

        if inp.is_conj():
            inp_is_conj = True
            inp = inp.conj()

        if other.is_conj():
            other_is_conj = True
            other = other.conj()

        # (..., 2) float views over the interleaved real/imag storage.
        inp_real = torch.view_as_real(inp)
        other_real = torch.view_as_real(other)

        n_elements = inp_real.numel()  # real-scalar count = 2 * n_complex
        n_complex = inp.numel()

        # Zero-initialized [real, imag] accumulator for the kernel's result.
        output_real = torch.zeros(2, dtype=inp_real.dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(n_complex, meta["BLOCK_SIZE"]),)

        vdot_kernel_complex[grid](
            inp_real,
            other_real,
            output_real,
            n_elements=n_elements,
            inp_is_conj=inp_is_conj,
            other_is_conj=other_is_conj,
            inp_stride=inp_stride,
            other_stride=other_stride,
            # Kunlunxin-specific compile/launch options — presumably these
            # make the kernel's plain stores accumulate safely and disable
            # optimizations that miscompile it; TODO confirm semantics.
            isCLOSE_TTXPU_O_ATOMIC_SIM=True,
            isCloseOffsetAnalysis=True,
            isCloseUnrollControl=True,
        )

        # Reinterpret the [real, imag] pair as a 0-d-like complex result.
        return torch.view_as_complex(output_real)
    else:
        # Real path: accumulate in float32, cast back to the original dtype.
        output = torch.zeros([], dtype=torch.float32, device=inp.device)
        n_elements = inp.numel()
        inp_dtype = inp.dtype
        # NOTE(review): targeted workaround — this exact size/dtype combo is
        # special-cased, presumably to dodge a backend precision/compile bug
        # with bfloat16 at n=1041; confirm whether it is still required.
        if n_elements == 1041 and inp.dtype == torch.bfloat16:
            inp = inp.to(torch.float32)
            other = other.to(torch.float32)

        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        dot_kernel[grid](
            inp,
            other,
            output,
            n_elements=n_elements,
            inp_stride=inp_stride,
            other_stride=other_stride,
            # Kunlunxin-specific compile/launch options — see complex path.
            isCLOSE_TTXPU_O_ATOMIC_SIM=True,
            isCloseOffsetAnalysis=True,
        )
        return output.to(inp_dtype)