Coverage for src/flag_gems/fused/rwkv_mm_sparsity.py: 58%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

@triton.jit
def rwkv_mm_sparsity_kernel(
    k_ptr,
    v_ptr,
    output_ptr,
    v_cols: tl.constexpr,
    blk_size: tl.constexpr,
    k_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """Accumulate ``output[j] = sum_i k[i] * v[i, j]`` for one block of columns.

    Each program handles ``block_size`` consecutive output columns and walks
    the ``k`` vector in chunks of ``blk_size`` elements. Rows of ``v`` whose
    ``k`` entry is zero are masked out of the load, so they contribute
    nothing to the sum.

    Args:
        k_ptr: Pointer to the ``k`` vector; element ``i`` is read at
            ``k_ptr + i``.
        v_ptr: Pointer to the ``v`` matrix; element ``(i, j)`` is read at
            ``v_ptr + i * v_cols + j``.
        output_ptr: Pointer to the output vector; column ``j`` is written at
            ``output_ptr + j``.
        v_cols: Number of columns of ``v`` (also the output length).
        blk_size: Number of ``k`` elements processed per loop iteration.
        k_size: Upper bound (exclusive) on the ``k`` indices that are read.
        block_size: Number of output columns handled by one program.
    """
    pid = tl.program_id(0)
    col_idx = pid * block_size + tl.arange(0, block_size)
    col_mask = col_idx < v_cols

    # Accumulate in float32 regardless of the input element type.
    acc = tl.zeros((block_size,), dtype=tl.float32)

    for i in range(0, tl.cdiv(k_size, blk_size)):
        k_offset = i * blk_size + tl.arange(0, blk_size)
        k_mask = k_offset < k_size
        k = tl.load(k_ptr + k_offset, mask=k_mask, other=0.0)
        k_nonzero_mask = k != 0

        # tl.arange yields 32-bit indices; widen the row index before scaling
        # by v_cols so the flat offset cannot wrap around for large matrices.
        v_offset = k_offset[:, None].to(tl.int64) * v_cols + col_idx[None, :]
        v = tl.load(
            v_ptr + v_offset,
            mask=k_mask[:, None] & col_mask[None, :] & k_nonzero_mask[:, None],
            other=0.0,
        )
        acc += tl.sum(k[:, None].to(tl.float32) * v.to(tl.float32), axis=0)

    tl.store(output_ptr + col_idx, acc, mask=col_mask)

42 

43 

def rwkv_mm_sparsity(k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Compute the vector-matrix product ``k @ v``, skipping zero entries of ``k``.

    Launches ``rwkv_mm_sparsity_kernel`` with one program per group of 16
    output columns.

    Args:
        k: 1-D tensor of length ``K``.
        v: 2-D tensor with ``K`` rows.

    Returns:
        A 1-D tensor of length ``v.size(1)`` on ``k``'s device with ``k``'s
        dtype.

    Raises:
        AssertionError: If ``k`` is not 1-D, ``v`` is not 2-D, or their
            leading sizes differ.
    """
    logger.debug("GEMS RWKV MM SPARSITY")
    assert k.dim() == 1 and v.dim() == 2
    assert k.size(0) == v.size(0)

    v_cols = v.size(1)
    output = torch.empty(v_cols, device=k.device, dtype=k.dtype)
    if v_cols == 0:
        # No output columns: nothing to compute, so skip the launch.
        return output

    # The kernel addresses k with unit stride and v with a row stride of
    # v_cols, so both must be laid out contiguously (no-op if they already are).
    k = k.contiguous()
    v = v.contiguous()

    # Both tile sizes feed tl.arange in the kernel and must be powers of two.
    blk_size = 512
    block_size = 16
    # Pass the true length of k: the kernel uses k_size as the bound of its
    # load mask, so rounding it up would let it read past the end of k (and
    # past the end of v wherever that out-of-bounds data is non-zero).
    k_size = k.size(0)
    grid = (triton.cdiv(v_cols, block_size),)

    rwkv_mm_sparsity_kernel[grid](
        k,
        v,
        output,
        v_cols,
        blk_size,
        k_size,
        block_size,
    )
    return output