Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/dot.py: 0%
63 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
10from flag_gems.utils.libentry import libentry
# Module logger nested under the shared "flag_gems" logger; lstrip(".") trims
# any leading dots from __name__ before using it as the child suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Single-program dot product: one block reduces the whole vector.

    Loads up to BLOCK_SIZE elements of x and y (masked past N, upcast to
    float32), multiplies elementwise, sums, and stores the scalar at out_ptr.
    """
    pid = tle.program_id(0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N
    a = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    # Named `total` rather than `sum` to avoid shadowing the builtin.
    total = tl.sum(a * b)
    tl.store(out_ptr, total)
@libentry()
@triton.jit
def dot_kernel_1(x_ptr, y_ptr, mid_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Stage one of a two-stage dot product.

    Each program reduces its own BLOCK_SIZE slice of x*y (masked past N,
    accumulated in float32) and writes one partial sum to mid_ptr[pid].
    """
    pid = tle.program_id(0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = idx < N
    a = tl.load(x_ptr + idx, mask=valid, other=0.0).to(tl.float32)
    b = tl.load(y_ptr + idx, mask=valid, other=0.0).to(tl.float32)
    tl.store(mid_ptr + pid, tl.sum(a * b))
@libentry()
@triton.jit
def dot_kernel_2(mid_ptr, out_ptr, M, BLOCK_MID: tl.constexpr):
    """Stage two of the two-stage dot product.

    A single program loads the M partial sums (masked, zero-filled) and
    stores their total at out_ptr.
    """
    idx = tl.arange(0, BLOCK_MID)
    partials = tl.load(mid_ptr + idx, mask=idx < M, other=0.0)
    tl.store(out_ptr, tl.sum(partials))
def dot(x, y):
    """Dot product of two 1-D tensors of equal shape via Triton kernels.

    Small inputs (N < 4096) are reduced by a single-program kernel; larger
    inputs use a two-stage reduction (per-block float32 partial sums, then
    a final sum). Returns a 0-d tensor with x's dtype on x's device.
    """
    logger.debug("Triton Dot Product")
    assert x.shape == y.shape, "Input vectors must have the same shape"
    assert x.dim() == 1, "Input must be 1D tensors"

    N = x.shape[0]
    # Only when N is less than TRITON_MAX_TENSOR_NUMEL can it be processed
    # with a single kernel, and performance is better when N < 4096.
    if N < 4096:
        out = torch.empty([], dtype=torch.float32, device=x.device)
        with torch_device_fn.device(x.device):
            dot_kernel[(1, 1, 1)](x, y, out, N, triton.next_power_of_2(N))
        return out.to(x.dtype)

    # Two-stage path: a block size near sqrt(N) balances work between stages.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))
    mid_size = triton.cdiv(N, block_size)
    mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
    out = torch.empty([], dtype=x.dtype, device=x.device)
    with torch_device_fn.device(x.device):
        dot_kernel_1[(mid_size, 1, 1)](x, y, mid, N, block_size)
        dot_kernel_2[(1, 1, 1)](mid, out, mid_size, triton.next_power_of_2(mid_size))
    return out