Coverage for src/flag_gems/fused/rwkv_mm_sparsity.py: 53%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def rwkv_mm_sparsity_kernel(
    k_ptr,
    v_ptr,
    output_ptr,
    v_cols: tl.constexpr,
    blk_size: tl.constexpr,
    k_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """Accumulate ``sum_i k[i] * v[i, col]`` for one tile of output columns.

    Each program along grid axis 0 owns ``block_size`` consecutive columns.
    The reduction walks the first ``k_size`` entries of ``k`` in chunks of
    ``blk_size``; rows whose ``k`` entry is zero are masked out of the ``v``
    load, so their elements are never fetched.

    Args:
        k_ptr: Pointer to the ``k`` values, read with unit stride.
        v_ptr: Pointer to ``v``, addressed as ``row * v_cols + col``.
        output_ptr: Pointer receiving one value per column.
        v_cols: Number of columns in ``v`` (also the row stride used here).
        blk_size: Number of ``k`` entries handled per loop iteration.
        k_size: Upper bound on the ``k`` indices that are read; indices
            ``>= k_size`` are masked off.
        block_size: Number of output columns handled per program.
    """
    # Columns owned by this program; the last tile may be partial.
    cols = tl.program_id(0) * block_size + tl.arange(0, block_size)
    cols_in_range = cols < v_cols

    # Accumulate in float32 regardless of the input element types.
    partial = tl.zeros((block_size,), dtype=tl.float32)

    for chunk in range(0, tl.cdiv(k_size, blk_size)):
        rows = chunk * blk_size + tl.arange(0, blk_size)
        rows_in_range = rows < k_size
        k_vals = tl.load(k_ptr + rows, mask=rows_in_range, other=0.0)

        # A zero k entry contributes nothing, so skip fetching that row of v.
        rows_needed = rows_in_range & (k_vals != 0)
        v_vals = tl.load(
            v_ptr + rows[:, None] * v_cols + cols[None, :],
            mask=rows_needed[:, None] & cols_in_range[None, :],
            other=0.0,
        )

        weighted = k_vals.to(tl.float32)[:, None] * v_vals.to(tl.float32)
        partial += tl.sum(weighted, axis=0)

    tl.store(output_ptr + cols, partial, mask=cols_in_range)

38 

39 

def rwkv_mm_sparsity(k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Return the vector-matrix product ``k @ v``, skipping rows where ``k`` is zero.

    Args:
        k: 1-D tensor of length ``K``.
        v: 2-D tensor of shape ``(K, N)``.

    Returns:
        A 1-D tensor of length ``N`` on ``k``'s device with ``k``'s dtype.

    Raises:
        AssertionError: If ``k`` is not 1-D, ``v`` is not 2-D, or their
            leading sizes differ.
    """
    assert k.dim() == 1 and v.dim() == 2
    assert k.size(0) == v.size(0)

    # The kernel reads k with unit stride and v as dense row-major
    # (row * v_cols + col), so strided views must be materialized first.
    # contiguous() returns the same tensor when it is already dense.
    k = k.contiguous()
    v = v.contiguous()

    v_cols = v.size(1)
    output = torch.empty(v_cols, device=k.device, dtype=k.dtype)

    # Tile sizes; both must be powers of two because the kernel feeds them
    # to tl.arange.
    blk_size = 512  # k entries reduced per loop iteration
    block_size = 16  # output columns per program

    # Pass the exact length. The kernel masks loads with `index < k_size`, so
    # rounding up to a power of two would let it read past the end of k (and
    # past the end of v for any non-zero garbage it found there).
    # NOTE: k_size is a tl.constexpr in the kernel, so each distinct length
    # compiles its own specialization.
    k_size = k.size(0)
    grid = (triton.cdiv(v_cols, block_size),)

    rwkv_mm_sparsity_kernel[grid](
        k,
        v,
        output,
        v_cols,
        blk_size,
        k_size,
        block_size,
    )
    return output