Coverage for src/flag_gems/runtime/backend/_ascend/ops/dot.py: 0%
72 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Module-scoped logger keyed to this op file's basename within the Ascend backend.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit
def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Single-program dot product: one tile reduces the whole vector.

    Launched with grid (1, 1, 1) and BLOCK_SIZE >= N (see `dot`), so a
    single masked tile covers every element.

    Args:
        x_ptr, y_ptr: pointers to the two input 1D vectors.
        out_ptr: pointer to a scalar output (written as float32).
        N: number of elements.
        BLOCK_SIZE: compile-time tile width; must be >= N.
    """
    pid = tle.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    # Upcast to float32 so low-precision inputs accumulate accurately.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    # Renamed from `sum` to avoid shadowing the Python builtin.
    dot_sum = tl.sum(x * y)
    tl.store(out_ptr, dot_sum)
@libentry()
@triton.jit
def dot_kernel_1(x_ptr, y_ptr, mid_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of the two-stage dot product: per-tile partial sums.

    The vector is split into cdiv(N, BLOCK_SIZE) tasks; tasks are dealt
    round-robin across the launched programs, and each task writes one
    float32 partial sum to mid_ptr[task_id].

    Args:
        x_ptr, y_ptr: pointers to the two input 1D vectors.
        mid_ptr: pointer to the partial-sum buffer of length cdiv(N, BLOCK_SIZE).
        N: number of elements.
        BLOCK_SIZE: compile-time tile width.
    """
    n_workers = tle.num_programs(0)
    pid = tle.program_id(0)
    n_tasks = tl.cdiv(N, BLOCK_SIZE)
    tasks_per_worker = tl.cdiv(n_tasks, n_workers)
    for task_index in range(tasks_per_worker):
        task_id = pid + task_index * n_workers
        # Bug fix: when n_tasks is not a multiple of n_workers, the last
        # round hands some programs a task_id >= n_tasks. The load mask
        # covered that case, but the store did not, writing past the end
        # of the mid buffer. Skip those tasks entirely.
        if task_id < n_tasks:
            block_start = task_id * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < N
            x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
            y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
            partial_sum = tl.sum(x * y)
            tl.store(mid_ptr + task_id, partial_sum)
@libentry()
@triton.jit
def dot_kernel_2(mid_ptr, out_ptr, M, BLOCK_MID: tl.constexpr):
    """Stage 2 of the two-stage dot product: reduce partial sums to a scalar.

    Launched with grid (1, 1, 1); a single program walks mid_ptr in chunks
    of BLOCK_MID and accumulates the total.

    Args:
        mid_ptr: pointer to M float32 partial sums from dot_kernel_1.
        out_ptr: pointer to the scalar result.
        M: number of partial sums.
        BLOCK_MID: compile-time chunk width.
    """
    n_tasks = tl.cdiv(M, BLOCK_MID)
    # Bug fix: the previous version stored each chunk's sum straight to
    # out_ptr inside the loop, so with M > BLOCK_MID every chunk overwrote
    # the one before it and only the last chunk survived. Accumulate across
    # chunks and store the total once.
    total = 0.0
    for task_index in range(n_tasks):
        offset = task_index * BLOCK_MID + tl.arange(0, BLOCK_MID)
        mask = offset < M
        mid_val = tl.load(mid_ptr + offset, mask=mask, other=0.0)
        total += tl.sum(mid_val)
    tl.store(out_ptr, total)
def dot(x, y):
    """Dot product of two 1D tensors on the Ascend backend.

    Small vectors (N < 4096) are reduced by a single kernel launch;
    larger ones use a two-stage reduction (per-tile partial sums, then a
    final reduce). Accumulation happens in float32; the result is returned
    in the input dtype.

    Args:
        x, y: 1D tensors of identical shape.

    Returns:
        A 0-dim tensor holding sum(x * y), with x's dtype.
    """
    logger.debug("GEMS_ASCEND DOT")

    assert x.shape == y.shape, "Input vectors must have the same shape"
    assert x.dim() == 1, "Input must be 1D tensors"

    N = x.shape[0]

    # Only when N is less than TRITON_MAX_TENSOR_NUMEL can it be processed
    # with a single kernel, and performance is better when N < 4096.
    if N >= 4096:
        # Tile width ~ sqrt(N), rounded up to a power of two, capped at 12032.
        block_size = min(triton.next_power_of_2(math.ceil(math.sqrt(N))), 12032)
        mid_size = triton.cdiv(N, block_size)
        # Stage-2 chunk width covers mid in one pass when possible, capped at 16384.
        block_mid = min(triton.next_power_of_2(mid_size), 16384)

        # Launch at most 240 programs; dot_kernel_1 loops over surplus tasks.
        grid_1 = (min(mid_size, 240), 1, 1)
        grid_2 = (1, 1, 1)

        mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
        out = torch.empty([], dtype=x.dtype, device=x.device)

        with torch_device_fn.device(x.device):
            dot_kernel_1[grid_1](x, y, mid, N, block_size)
            dot_kernel_2[grid_2](mid, out, mid_size, block_mid)
    else:
        # One tile spans the whole vector.
        block_size = triton.next_power_of_2(N)
        grid = (1, 1, 1)

        # Accumulate in float32 for accuracy, then cast back to the input dtype.
        out = torch.empty([], dtype=torch.float32, device=x.device)

        with torch_device_fn.device(x.device):
            dot_kernel[grid](x, y, out, N, block_size)
        out = out.to(x.dtype)

    return out