Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/rwkv_mm_sparsity.py: 0%
33 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def rwkv_mm_sparsity_kernel(
    k_ptr,
    v_ptr,
    output_ptr,
    v_cols: tl.constexpr,
    k_size: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """
    Vector-matrix product written entirely with 2D tiles:
    output[1, N] = k[1, K] @ V[K, N].

    Every intermediate keeps a 2D shape, so no tl.sum reduction is needed to
    squeeze a dimension away.

    Each program along grid axis 0 produces BLOCK_SIZE_N consecutive output
    elements. `k` is read at `k_ptr + row`, and `v` is read at
    `v_ptr + row * v_cols + col`. Row indices >= k_size and column indices
    >= v_cols are masked and contribute 0.0. Accumulation is done in float32.
    """
    col_block = tl.program_id(axis=0)

    # Output columns owned by this program, and which of them really exist.
    cols = col_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    col_in_range = cols < v_cols

    acc = tl.zeros((1, BLOCK_SIZE_N), dtype=tl.float32)

    # Walk the K dimension one BLOCK_SIZE_K tile at a time.
    for tile_idx in range(0, tl.cdiv(k_size, BLOCK_SIZE_K)):
        rows = tile_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        row_in_range = rows < k_size

        # Slice of k, promoted to float32 and lifted to shape (1, BLOCK_SIZE_K).
        k_vals = tl.load(k_ptr + rows, mask=row_in_range, other=0.0)
        k_row = k_vals.to(tl.float32)[None, :]

        # Matching (BLOCK_SIZE_K, BLOCK_SIZE_N) tile of v, promoted to float32.
        v_tile = tl.load(
            v_ptr + (rows[:, None] * v_cols) + cols[None, :],
            mask=row_in_range[:, None] & col_in_range[None, :],
            other=0.0,
        ).to(tl.float32)

        # (1, BLOCK_SIZE_K) @ (BLOCK_SIZE_K, BLOCK_SIZE_N) -> (1, BLOCK_SIZE_N)
        acc += tl.dot(k_row, v_tile, allow_tf32=False)

    # Flatten the single-row accumulator and write only the valid columns.
    flat_acc = tl.view(acc, (BLOCK_SIZE_N,))
    tl.store(output_ptr + cols, flat_acc, mask=col_in_range)
def rwkv_mm_sparsity(k: torch.Tensor, v: torch.Tensor):
    """Compute the vector-matrix product ``k @ v`` with the Triton kernel.

    Args:
        k: 1-D tensor of length K.
        v: 2-D tensor of shape (K, N).

    Returns:
        A 1-D tensor of length N, allocated on ``k.device`` with ``k.dtype``.

    Raises:
        AssertionError: if ``k`` is not 1-D, ``v`` is not 2-D, or their K
            dimensions differ.
    """
    assert k.dim() == 1 and v.dim() == 2
    assert k.size(0) == v.size(0)

    k_rows = k.size(0)
    v_cols = v.size(1)

    # Degenerate shapes never reach the kernel: an empty output needs no
    # launch, and a product over an empty K dimension is all zeros.
    if v_cols == 0:
        return torch.empty(0, device=k.device, dtype=k.dtype)
    if k_rows == 0:
        return torch.zeros(v_cols, device=k.device, dtype=k.dtype)

    # The kernel reads k with unit stride and v as row-major with a row
    # stride of v_cols, so both operands must be contiguous.
    k = k.contiguous()
    v = v.contiguous()

    output = torch.empty(v_cols, device=k.device, dtype=k.dtype)

    block_size_n = 256
    # The K tile must be a power of two; cap it at 128.
    block_size_k = min(128, triton.next_power_of_2(k_rows))
    # Each program covers block_size_n output columns, so the grid is sized
    # by the N tile, not the K tile.
    grid = (triton.cdiv(v_cols, block_size_n),)

    rwkv_mm_sparsity_kernel[grid](
        k,
        v,
        output,
        v_cols,
        # Pass the true K, not its power-of-two padding: the kernel uses this
        # value as the mask bound, and a padded bound reads out of range.
        k_rows,
        block_size_n,
        block_size_k,
    )
    return output