Coverage for src/flag_gems/fused/rwkv_mm_sparsity.py: 53%
30 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def rwkv_mm_sparsity_kernel(
    k_ptr,
    v_ptr,
    output_ptr,
    v_cols: tl.constexpr,
    blk_size: tl.constexpr,
    k_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """Compute one ``block_size``-wide slice of the vector-matrix product ``k @ v``.

    The program with axis-0 id ``pid`` produces output columns
    ``[pid * block_size, (pid + 1) * block_size)``. The reduction dimension
    (elements of ``k`` / rows of ``v``) is walked in chunks of ``blk_size``
    rows, and ``k_size`` is the bound used to mask those rows.

    Rows whose ``k`` coefficient is zero are masked out of the ``v`` load.
    This is the sparsity skip: such rows contribute nothing, so they are not
    fetched.

    Addressing assumes ``k`` is contiguous and ``v`` is row-major with row
    stride ``v_cols`` (TODO confirm every caller guarantees this).
    Accumulation is done in float32. The final ``tl.store`` writes that
    accumulator through ``output_ptr``.
    """
    cols = tl.program_id(0) * block_size + tl.arange(0, block_size)
    in_cols = cols < v_cols
    total = tl.zeros((block_size,), dtype=tl.float32)

    for chunk in range(0, tl.cdiv(k_size, blk_size)):
        rows = chunk * blk_size + tl.arange(0, blk_size)
        in_rows = rows < k_size
        k_vals = tl.load(k_ptr + rows, mask=in_rows, other=0.0)

        # A row of v is only worth loading if it is in range AND its k
        # coefficient is nonzero; everything else is filled with 0.0.
        live_rows = in_rows & (k_vals != 0)
        v_tile = tl.load(
            v_ptr + rows[:, None] * v_cols + cols[None, :],
            mask=live_rows[:, None] & in_cols[None, :],
            other=0.0,
        )

        weighted = k_vals[:, None].to(tl.float32) * v_tile.to(tl.float32)
        total += tl.sum(weighted, axis=0)

    tl.store(output_ptr + cols, total, mask=in_cols)
def rwkv_mm_sparsity(k: torch.Tensor, v: torch.Tensor):
    """Compute ``k @ v`` with a Triton kernel that skips zero entries of ``k``.

    Args:
        k: 1-D tensor of length ``K``.
        v: 2-D tensor of shape ``(K, N)``.

    Returns:
        1-D tensor of length ``N`` with ``out[j] = sum_i k[i] * v[i, j]``.
        The sum is accumulated in float32 inside the kernel and stored into a
        tensor with ``k``'s dtype, on ``k``'s device.

    Raises:
        AssertionError: if ``k`` is not 1-D, ``v`` is not 2-D, or their
            leading dimensions differ.
    """
    assert k.dim() == 1 and v.dim() == 2, "expected 1-D k and 2-D v"
    assert k.size(0) == v.size(0), "k and v must share their first dimension"

    # The kernel addresses k with flat offsets and v as `row * v_cols + col`.
    # That is only valid for contiguous, row-major storage. A transposed or
    # sliced view would otherwise be read with the wrong strides.
    k = k.contiguous()
    v = v.contiguous()

    k_rows = k.size(0)
    v_cols = v.size(1)
    output = torch.empty(v_cols, device=k.device, dtype=k.dtype)
    if v_cols == 0:
        # No output columns: nothing to compute and no programs to launch.
        return output

    blk_size = triton.next_power_of_2(512)  # rows of k/v reduced per loop step
    block_size = triton.next_power_of_2(16)  # output columns per program
    # The kernel's only row bound is `offset < k_size`, so this must be the
    # true length of k. Rounding it up to a power of two let the kernel read
    # past the end of k and v whenever K was not a power of two.
    # (k_size is a constexpr, so the kernel compiles once per distinct K.)
    k_size = k_rows
    grid = (triton.cdiv(v_cols, block_size),)

    rwkv_mm_sparsity_kernel[grid](
        k,
        v,
        output,
        v_cols,
        blk_size,
        k_size,
        block_size,
    )
    return output