Coverage for src/flag_gems/fused/rwkv_mm_sparsity.py: 58%
33 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
@triton.jit
def rwkv_mm_sparsity_kernel(
    k_ptr,
    v_ptr,
    output_ptr,
    v_cols: tl.constexpr,
    blk_size: tl.constexpr,
    k_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """Vector-matrix product ``output = k @ v`` that skips rows where ``k`` is 0.

    Each program instance produces ``block_size`` consecutive output columns.
    It walks the ``k_size`` entries of ``k`` in chunks of ``blk_size``; for
    every chunk it loads the matching rows of ``v`` (only where the ``k``
    entry is nonzero) and adds ``k[row] * v[row, col]`` into a float32
    accumulator, which is finally stored to ``output_ptr``.

    Addressing is ``k_ptr + row`` and ``v_ptr + row * v_cols + col``, i.e. the
    kernel indexes ``k`` with unit stride and ``v`` as a dense row-major
    matrix whose row stride is ``v_cols``.

    Args:
        k_ptr: pointer to the ``k`` vector.
        v_ptr: pointer to the ``v`` matrix.
        output_ptr: pointer to the output vector (``v_cols`` elements).
        v_cols: number of columns of ``v`` / elements of the output.
        blk_size: number of ``k`` entries (rows of ``v``) handled per loop step.
        k_size: upper bound (exclusive) on the row indices that are read.
        block_size: number of output columns handled by one program.
    """
    # Output columns owned by this program, plus the guard for the ragged tail.
    cols = tl.program_id(0) * block_size + tl.arange(0, block_size)
    cols_in_range = cols < v_cols

    # Accumulate in float32 regardless of the input element type.
    total = tl.zeros((block_size,), dtype=tl.float32)

    for chunk in range(0, tl.cdiv(k_size, blk_size)):
        rows = chunk * blk_size + tl.arange(0, blk_size)
        rows_in_range = rows < k_size
        k_vals = tl.load(k_ptr + rows, mask=rows_in_range, other=0.0)

        # Rows whose k entry is zero contribute nothing, so their v values are
        # never fetched; the masked-out lanes read as 0.0 instead.
        rows_needed = rows_in_range & (k_vals != 0)
        v_tile = tl.load(
            v_ptr + rows[:, None] * v_cols + cols[None, :],
            mask=rows_needed[:, None] & cols_in_range[None, :],
            other=0.0,
        )

        weighted = k_vals.to(tl.float32)[:, None] * v_tile.to(tl.float32)
        total += tl.sum(weighted, axis=0)

    tl.store(output_ptr + cols, total, mask=cols_in_range)
def rwkv_mm_sparsity(k: torch.Tensor, v: torch.Tensor):
    """Compute the vector-matrix product ``k @ v``, skipping zero entries of ``k``.

    Launches ``rwkv_mm_sparsity_kernel`` with one program per 16 output
    columns; each program walks ``k`` in chunks of 512 entries and only loads
    the rows of ``v`` whose ``k`` entry is nonzero.

    Args:
        k: 1-D tensor of length ``K``.
        v: 2-D tensor of shape ``(K, N)``.

    Returns:
        A 1-D tensor of length ``N`` on ``k.device`` with dtype ``k.dtype``.

    Raises:
        AssertionError: if ``k`` is not 1-D, ``v`` is not 2-D, or their
            leading dimensions differ.
    """
    logger.debug("GEMS RWKV MM SPARSITY")
    assert k.dim() == 1 and v.dim() == 2
    assert k.size(0) == v.size(0)

    # The kernel addresses k with unit stride and v as a dense row-major
    # matrix (row stride == v_cols). Make that true for strided/transposed
    # inputs; this is a no-op when the tensors are already contiguous.
    k = k.contiguous()
    v = v.contiguous()

    v_cols = v.size(1)
    output = torch.empty(v_cols, device=k.device, dtype=k.dtype)

    # tl.arange requires power-of-two extents for both tile dimensions.
    blk_size = triton.next_power_of_2(512)
    block_size = triton.next_power_of_2(16)
    # Pass the true length of k. The kernel uses k_size as its row bound
    # (`row < k_size`), so rounding it up to a power of two would let the
    # kernel read past the end of k and v whenever K is not a power of two.
    k_size = k.size(0)
    grid = (triton.cdiv(v_cols, block_size),)

    rwkv_mm_sparsity_kernel[grid](
        k,
        v,
        output,
        v_cols,
        blk_size,
        k_size,
        block_size,
    )
    return output