Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/rwkv_mm_sparsity.py: 0%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def rwkv_mm_sparsity_kernel(
    k_ptr,
    v_ptr,
    output_ptr,
    v_cols: tl.constexpr,
    k_size: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Vector-matrix product kept in 2D form: output[1, N] = k[1, K] @ V[K, N].

    Every intermediate stays two-dimensional, so no ``tl.sum`` reduction is
    needed to collapse an axis; the (1, BLOCK_SIZE_N) accumulator is only
    flattened right before the final store.

    Args:
        k_ptr: Pointer to the vector; element ``i`` is read at ``k_ptr + i``.
        v_ptr: Pointer to the matrix; element ``(r, c)`` is read at
            ``v_ptr + r * v_cols + c``.
        output_ptr: Pointer to the result; column ``c`` is written at
            ``output_ptr + c``.
        v_cols: Number of matrix columns (N); also used as the row stride.
        k_size: Upper bound applied to the row/K offsets.
        BLOCK_SIZE_N: Number of output columns handled by one program.
        BLOCK_SIZE_K: Number of rows consumed per loop iteration.
    """
    col_block = tl.program_id(axis=0)

    # Output columns owned by this program; lanes past v_cols are masked off.
    cols = col_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    cols_ok = cols < v_cols
    acc = tl.zeros((1, BLOCK_SIZE_N), dtype=tl.float32)

    for step in range(0, tl.cdiv(k_size, BLOCK_SIZE_K)):
        rows = step * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        rows_ok = rows < k_size

        # Slice of the vector, promoted to float32 and viewed as (1, BLOCK_SIZE_K).
        k_vals = tl.load(k_ptr + rows, mask=rows_ok, other=0.0).to(tl.float32)
        k_row = k_vals[None, :]

        # Matching (BLOCK_SIZE_K, BLOCK_SIZE_N) tile of the matrix; masked
        # lanes load 0.0 and therefore add nothing to the dot product.
        tile_mask = rows_ok[:, None] & cols_ok[None, :]
        tile_ptrs = v_ptr + rows[:, None] * v_cols + cols[None, :]
        v_tile = tl.load(tile_ptrs, mask=tile_mask, other=0.0).to(tl.float32)

        # (1, BLOCK_SIZE_K) @ (BLOCK_SIZE_K, BLOCK_SIZE_N) -> (1, BLOCK_SIZE_N)
        acc = acc + tl.dot(k_row, v_tile, allow_tf32=False)

    # Drop the leading unit dimension so the result matches the 1D output.
    flat = tl.view(acc, (BLOCK_SIZE_N,))
    tl.store(output_ptr + cols, flat, mask=cols_ok)

42 

43 

def rwkv_mm_sparsity(k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Compute the vector-matrix product ``k @ v`` with the Triton kernel.

    Args:
        k: 1D tensor of length K.
        v: 2D tensor of shape (K, N).

    Returns:
        A 1D tensor of length N, allocated on ``k.device`` with ``k.dtype``.

    Raises:
        AssertionError: If ``k`` is not 1D, ``v`` is not 2D, or their leading
            sizes differ.
    """
    assert k.dim() == 1 and v.dim() == 2
    assert k.size(0) == v.size(0)

    k_size = k.size(0)
    v_cols = v.size(1)

    # Nothing to launch for an empty result, and an empty reduction is all
    # zeros; both cases would otherwise yield a zero-sized grid or block.
    if v_cols == 0:
        return torch.empty(v_cols, device=k.device, dtype=k.dtype)
    if k_size == 0:
        return torch.zeros(v_cols, device=k.device, dtype=k.dtype)

    # The kernel addresses v as row * v_cols + col and k with unit stride, so
    # strided views must be materialised first (no-op when already contiguous).
    k = k.contiguous()
    v = v.contiguous()

    output = torch.empty(v_cols, device=k.device, dtype=k.dtype)

    block_n = 256
    # Only the tile size is rounded up to a power of two; capped at 128.
    block_k = min(128, triton.next_power_of_2(k_size))

    # Each program covers block_n output columns, so the grid is sized by
    # block_n (not by the K tile size).
    grid = (triton.cdiv(v_cols, block_n),)

    rwkv_mm_sparsity_kernel[grid](
        k,
        v,
        output,
        v_cols,
        # Pass the true length: the kernel masks rows with `offs_k < k_size`,
        # and a padded value would let it read past the end of k and v.
        k_size,
        BLOCK_SIZE_N=block_n,
        BLOCK_SIZE_K=block_k,
    )
    return output