Coverage for src/flag_gems/fused/rwkv_ka_fusion.py: 48%
40 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
@triton.jit
def rwkv_ka_fusion_kernel(
    k_ptr,
    kk_ptr,
    a_ptr,
    ka_ptr,
    o_k_ptr,
    o_kk_ptr,
    o_kka_ptr,
    T,
    C,
    H,
    N,
    N_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """Fused per-head scale / L2-normalise / gate kernel.

    Each program handles the ``block_size`` elements that start at flat
    offset ``pid * block_size`` (the host wrapper in this file passes
    ``block_size = C``, i.e. one program per row) and walks over ``H``
    consecutive groups ("heads") of ``N`` elements.  For every head it
    computes and stores:

    * ``o_kk  = (k * kk) / sqrt(sum((k * kk) ** 2) + 1e-12)``
    * ``o_k   = k * (1 + (a - 1) * ka)``
    * ``o_kka = o_kk * a``

    ``k`` and ``a`` are read at the flat row offset; ``kk`` and ``ka`` are
    read at the within-row channel offset, so they are shared by all rows.

    ``N_size`` is the compile-time lane count (``tl.arange`` needs a power of
    two) and may exceed ``N``.  Lanes ``>= N`` are masked out so they cannot
    pull elements of the following head (or row) into the reduction.
    """
    pid = tl.program_id(axis=0)
    k_start = pid * block_size
    total = T * C

    lane = tl.arange(0, N_size)
    # Padding lanes exist only because N_size is rounded up to a power of two;
    # without this mask they would alias the next head and corrupt the norm.
    in_head = lane < N

    for i in range(0, H):
        c_offs = i * N + lane
        offs = k_start + c_offs
        row_mask = in_head & (offs < total)
        col_mask = in_head & (c_offs < C)

        k = tl.load(k_ptr + offs, mask=row_mask, other=0.0)
        a = tl.load(a_ptr + offs, mask=row_mask, other=0.0)

        ka = tl.load(ka_ptr + c_offs, mask=col_mask, other=0.0)
        kk = tl.load(kk_ptr + c_offs, mask=col_mask, other=0.0)

        # L2-normalise k * kk over the head; masked lanes contribute 0.0.
        kt = k * kk
        kt2 = kt * kt
        norm_kt2 = tl.sum(kt2.to(tl.float32))
        # The epsilon keeps the division finite for an all-zero head.
        norm_kt = tl.sqrt(norm_kt2 + 1e-12)
        okk = kt / norm_kt
        tl.store(o_kk_ptr + offs, okk, mask=row_mask)

        ok = k * (1 + (a.to(tl.float32) - 1) * ka)
        okka = okk * a
        tl.store(o_k_ptr + offs, ok, mask=row_mask)
        tl.store(o_kka_ptr + offs, okka, mask=row_mask)
def rwkv_ka_fusion(
    k: torch.Tensor, kk: torch.Tensor, a: torch.Tensor, ka: torch.Tensor, H: int, N: int
):
    """Launch ``rwkv_ka_fusion_kernel`` and return ``(o_k, o_kk, o_kka)``.

    Args:
        k: 1-D tensor of ``C`` elements (treated as a single row) or 2-D
            tensor of shape ``(T, C)``.
        kk: Per-channel factor multiplied into ``k`` before normalisation.
            The kernel reads it at channel offsets ``0 .. C-1`` (presumably a
            tensor with ``C`` elements; confirm against callers).
        a: Tensor read at the same flat offsets as ``k`` (presumably the same
            shape as ``k``; confirm against callers).
        ka: Per-channel factor, read at channel offsets ``0 .. C-1`` like
            ``kk``.
        H: Number of heads per row.
        N: Number of elements per head; ``H * N`` must equal ``C``.

    Returns:
        Three tensors allocated like ``k``: ``o_k``, ``o_kk`` and ``o_kka``
        as written by the kernel.

    Raises:
        ValueError: If ``k`` is non-empty and ``H * N`` does not equal ``C``.
    """
    logger.debug("GEMS RWKV KA FUSION")
    if k.dim() == 1:
        T = 1
        C = k.shape[0]
    else:
        T, C = k.shape

    if T * C == 0:
        # Nothing to compute; also avoids dividing by a zero block size.
        return torch.empty_like(k), torch.empty_like(k), torch.empty_like(k)

    if H * N != C:
        # The kernel covers exactly H * N elements per row: fewer would leave
        # part of the outputs uninitialised, more would read into the next row.
        raise ValueError(
            f"rwkv_ka_fusion expects H * N == C, got H={H}, N={N}, C={C}"
        )

    # The kernel addresses every tensor with flat row-major offsets.
    k = k.contiguous()
    kk = kk.contiguous()
    a = a.contiguous()
    ka = ka.contiguous()

    o_k = torch.empty_like(k)
    o_kk = torch.empty_like(k)
    o_kka = torch.empty_like(k)

    # One program per row: block size C over T * C elements gives T programs.
    BLOCK_SIZE = C
    grid = (T,)
    # tl.arange needs a power-of-two extent; the kernel masks lanes beyond N.
    N_size = triton.next_power_of_2(N)
    rwkv_ka_fusion_kernel[grid](
        k, kk, a, ka, o_k, o_kk, o_kka, T, C, H, N, N_size, BLOCK_SIZE
    )

    return o_k, o_kk, o_kka