Coverage for src/flag_gems/fused/rwkv_ka_fusion.py: 43%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def rwkv_ka_fusion_kernel(
    k_ptr,
    kk_ptr,
    a_ptr,
    ka_ptr,
    o_k_ptr,
    o_kk_ptr,
    o_kka_ptr,
    T,
    C,
    H,
    N,
    N_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """Fused per-segment normalisation and gating over flat buffers.

    Each program starts at flat offset ``program_id * block_size`` and walks
    ``H`` consecutive segments of ``N`` elements. For every segment it writes:

    * ``o_kk``  = ``(k * kk) / sqrt(sum((k * kk) ** 2) + 1e-12)``, the sum
      being taken over the ``N`` elements of that segment (in float32),
    * ``o_k``   = ``k * (1 + (a - 1) * ka)``,
    * ``o_kka`` = ``o_kk * a``.

    ``k_ptr``/``a_ptr`` and the three output pointers are indexed with the
    same flat offsets (bounded by ``T * C``); ``kk_ptr``/``ka_ptr`` are
    indexed only by the within-block offset (bounded by ``C``).

    ``N_size`` is the compile-time lane count; it must be >= ``N``. Lanes with
    index >= ``N`` are padding and are masked out of every load and store.
    """
    pid = tl.program_id(axis=0)
    k_start = pid * block_size

    # Loop-invariant lane indices. N_size may exceed N (tl.arange needs a
    # power-of-two extent), so lanes >= N must not touch memory: otherwise
    # they would pull the neighbouring segment into the norm below and write
    # outside this segment.
    lane = tl.arange(0, N_size)
    in_segment = lane < N

    for i in range(0, H):
        offs = k_start + i * N + lane
        mask = in_segment & (offs < T * C)
        k = tl.load(k_ptr + offs, mask=mask, other=0.0)
        a = tl.load(a_ptr + offs, mask=mask, other=0.0)

        c_offs = i * N + lane
        c_mask = in_segment & (c_offs < C)
        ka = tl.load(ka_ptr + c_offs, mask=c_mask, other=0.0)
        kk = tl.load(kk_ptr + c_offs, mask=c_mask, other=0.0)

        # L2-normalise k * kk over the segment; masked lanes contribute 0.
        kt = k * kk
        kt2 = kt * kt
        norm_kt2 = tl.sum(kt2.to(tl.float32))
        norm_kt = tl.sqrt(norm_kt2 + 1e-12)  # epsilon keeps the divisor non-zero
        okk = kt / norm_kt
        tl.store(o_kk_ptr + offs, okk, mask=mask)

        ok = k * (1 + (a.to(tl.float32) - 1) * ka)
        okka = okk * a
        tl.store(o_k_ptr + offs, ok, mask=mask)
        tl.store(o_kka_ptr + offs, okka, mask=mask)

45 

46 

def rwkv_ka_fusion(
    k: torch.Tensor, kk: torch.Tensor, a: torch.Tensor, ka: torch.Tensor, H: int, N: int
):
    """Launch ``rwkv_ka_fusion_kernel`` over ``k`` and return its three outputs.

    Args:
        k: 1-D tensor of ``C`` elements or 2-D tensor of shape ``(T, C)``.
        kk: Per-channel factor; the kernel reads its first ``C`` elements.
        a: Tensor read with the same flat offsets as ``k``.
        ka: Per-channel factor; the kernel reads its first ``C`` elements.
        H: Number of segments per row.
        N: Elements per segment; ``H * N`` must equal ``C``.

    Returns:
        ``(o_k, o_kk, o_kka)``, each with the shape and dtype of ``k``:
        ``o_kk`` is ``k * kk`` L2-normalised over every ``N``-element segment,
        ``o_k`` is ``k * (1 + (a - 1) * ka)`` and ``o_kka`` is ``o_kk * a``.

    Raises:
        ValueError: If ``H * N != C``, or if ``k`` has more than two
            dimensions (from unpacking ``k.shape``).
    """
    if k.dim() == 1:
        T = 1
        C = k.shape[0]
    else:
        T, C = k.shape

    # The kernel visits exactly H * N elements per row; any other C would
    # leave outputs unwritten or spill into the following row.
    if H * N != C:
        raise ValueError(
            f"H * N must equal the channel size: H={H}, N={N}, C={C}"
        )

    # The kernel addresses every buffer with flat offsets, so strided views
    # must be materialised first (a no-op for already-contiguous tensors).
    k = k.contiguous()
    a = a.contiguous()
    kk = kk.contiguous()
    ka = ka.contiguous()

    o_k = torch.empty_like(k)
    o_kk = torch.empty_like(k)
    o_kka = torch.empty_like(k)

    # Nothing to compute; also avoids dividing by a zero block size below.
    if k.numel() == 0:
        return o_k, o_kk, o_kka

    # One program per row: each program covers C contiguous elements.
    BLOCK_SIZE = 1 * C
    grid = (triton.cdiv(T * C, BLOCK_SIZE),)
    N_size = triton.next_power_of_2(N)
    rwkv_ka_fusion_kernel[grid](
        k, kk, a, ka, o_k, o_kk, o_kka, T, C, H, N, N_size, BLOCK_SIZE
    )

    return o_k, o_kk, o_kka