Coverage for src/flag_gems/fused/rwkv_ka_fusion.py: 48%

40 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

# Module-level logger; rwkv_ka_fusion emits a debug message through it on every call.
logger = logging.getLogger(__name__)

8 

9 

# Fused per-head key preprocessing (the name suggests the RWKV k / kk / a path;
# NOTE(review): confirm against the model code).
#
# Each program handles `block_size` contiguous elements starting at
# pid * block_size. The wrapper passes block_size = C, so one program owns one
# row of the flattened (T, C) input. The row is split into H heads of N
# channels, and for every head the kernel writes:
#   o_kk  = (k * kk) / sqrt(sum((k * kk)^2) + 1e-12)   (per-head L2 normalize)
#   o_k   = k * (1 + (a - 1) * ka)
#   o_kka = o_kk * a
# k, a and the outputs are addressed with flat offsets, which assumes
# row-major contiguous storage. kk and ka are indexed by channel only, i.e.
# they are read as C per-channel values.
# N_size is N rounded up to a power of two because tl.arange requires one.
@triton.jit
def rwkv_ka_fusion_kernel(
    k_ptr,
    kk_ptr,
    a_ptr,
    ka_ptr,
    o_k_ptr,
    o_kk_ptr,
    o_kka_ptr,
    T,
    C,
    H,
    N,
    N_size: tl.constexpr,
    block_size: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    k_start = pid * block_size
    total = T * C

    # Lanes [N, N_size) are padding introduced by the power-of-two arange.
    # Without this mask they would read (and write) the next head's channels.
    # For the last head that is the next row, which is owned by a different
    # program, so the stray write would race with that program's real output.
    lane = tl.arange(0, N_size)
    lane_mask = lane < N

    for i in range(0, H):
        c_offs = i * N + lane          # channel index within the row
        offs = k_start + c_offs        # flat element index
        mask = lane_mask & (offs < total)
        c_mask = lane_mask & (c_offs < C)

        k = tl.load(k_ptr + offs, mask=mask, other=0.0)
        a = tl.load(a_ptr + offs, mask=mask, other=0.0)
        ka = tl.load(ka_ptr + c_offs, mask=c_mask, other=0.0)
        kk = tl.load(kk_ptr + c_offs, mask=c_mask, other=0.0)

        kt = k * kk
        # Square in fp32 so each term is not rounded to the (possibly
        # low-precision) input dtype before being accumulated.
        # Masked lanes load 0, so they contribute nothing to the norm.
        kt_f32 = kt.to(tl.float32)
        norm_kt = tl.sqrt(tl.sum(kt_f32 * kt_f32) + 1e-12)
        okk = kt / norm_kt
        tl.store(o_kk_ptr + offs, okk, mask=mask)

        ok = k * (1 + (a.to(tl.float32) - 1) * ka)
        okka = okk * a
        # fp32 results are cast to the output pointer's element type on store.
        tl.store(o_k_ptr + offs, ok, mask=mask)
        tl.store(o_kka_ptr + offs, okka, mask=mask)

49 

50 

def rwkv_ka_fusion(
    k: torch.Tensor, kk: torch.Tensor, a: torch.Tensor, ka: torch.Tensor, H: int, N: int
):
    """Run the fused k / kk / a preprocessing kernel.

    The last dimension of ``k`` is the channel dimension C. All leading
    dimensions are flattened into T rows (a 1-D ``k`` is a single row). Each
    row is split into H heads of N channels.

    Args:
        k: Input tensor whose last dimension is C.
        kk: Per-channel multiplier; its first C elements are used.
        a: Tensor with the same number of elements (and layout) as ``k``.
        ka: Per-channel coefficient; its first C elements are used.
        H: Number of heads.
        N: Channels per head; ``H * N`` must equal C.

    Returns:
        A tuple ``(o_k, o_kk, o_kka)``. Each element has ``k``'s shape and
        dtype, and is contiguous:
          o_k   = k * (1 + (a - 1) * ka)
          o_kk  = per-head L2-normalized (k * kk)
          o_kka = o_kk * a

    Raises:
        ValueError: If ``k`` is 0-dimensional, ``H * N != C``, ``a`` has a
            different element count than ``k``, or ``kk``/``ka`` hold fewer
            than C elements.
    """
    logger.debug("GEMS RWKV KA FUSION")
    if k.dim() == 0:
        raise ValueError("rwkv_ka_fusion expects k with at least one dimension")

    C = k.shape[-1]
    if H * N != C:
        raise ValueError(f"H * N ({H} * {N}) must equal the channel dim C ({C})")
    if a.numel() != k.numel():
        raise ValueError(
            f"a must have as many elements as k ({a.numel()} vs {k.numel()})"
        )
    if kk.numel() < C or ka.numel() < C:
        raise ValueError(
            f"kk and ka must hold at least C={C} elements "
            f"(got {kk.numel()} and {ka.numel()})"
        )

    # The kernel uses flat row-major offsets. empty_like() would copy the
    # strides of a non-contiguous input, so force a dense layout first
    # (this is a no-op for tensors that are already contiguous).
    k = k.contiguous()
    kk = kk.contiguous()
    a = a.contiguous()
    ka = ka.contiguous()

    o_k = torch.empty_like(k)
    o_kk = torch.empty_like(k)
    o_kka = torch.empty_like(k)

    # Nothing to compute. This also avoids a zero-sized grid and a division
    # by zero when C == 0.
    if k.numel() == 0:
        return o_k, o_kk, o_kka

    T = k.numel() // C
    BLOCK_SIZE = C  # one program per row
    N_size = triton.next_power_of_2(N)
    rwkv_ka_fusion_kernel[(T,)](
        k, kk, a, ka, o_k, o_kk, o_kka, T, C, H, N, N_size, BLOCK_SIZE
    )

    return o_k, o_kk, o_kka