Coverage for src/flag_gems/fused/moe_sum.py: 44%
27 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import torch
2import triton
3import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=2),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["hidden_size", "topk"],
)
@triton.jit
def moe_sum_kernel(
    input_ptr,
    output_ptr,
    num_tokens,
    topk,
    hidden_size,
    input_stride_token,
    input_stride_topk,
    input_stride_hidden,
    output_stride_token,
    output_stride_hidden,
    BLOCK_SIZE: tl.constexpr,
):
    """Sum ``topk`` slices of the input for one token and one hidden block.

    Program axis 0 selects the token and program axis 1 selects a block of
    ``BLOCK_SIZE`` consecutive hidden indices. For every hidden index in the
    block, the kernel accumulates ``topk`` input elements in float32 and
    stores the sum to the output, cast to the output's element type.

    Args:
        input_ptr: Base pointer of the input, addressed as
            ``token * input_stride_token + k * input_stride_topk
            + h * input_stride_hidden``.
        output_ptr: Base pointer of the output, addressed as
            ``token * output_stride_token + h * output_stride_hidden``.
        num_tokens: Number of tokens; programs past this bound exit early.
        topk: Number of slices summed per token.
        hidden_size: Number of valid hidden indices (used for masking).
        input_stride_token: Element stride between tokens in the input.
        input_stride_topk: Element stride between top-k slices in the input.
        input_stride_hidden: Element stride between hidden indices in the
            input.
        output_stride_token: Element stride between tokens in the output.
        output_stride_hidden: Element stride between hidden indices in the
            output.
        BLOCK_SIZE: Compile-time number of hidden indices per program.
    """
    token_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    if token_idx >= num_tokens:
        return

    hidden_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    hidden_mask = hidden_offsets < hidden_size

    # Do the address arithmetic in 64-bit: token_idx * stride can exceed the
    # int32 range for large inputs and would otherwise wrap silently.
    token_idx64 = token_idx.to(tl.int64)
    hidden_offsets64 = hidden_offsets.to(tl.int64)

    # Honour the hidden-dimension strides so that inputs/outputs whose last
    # dimension is not unit-stride are addressed correctly.
    expert_ptrs = (
        input_ptr
        + token_idx64 * input_stride_token
        + hidden_offsets64 * input_stride_hidden
    )

    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for _ in range(topk):
        expert_data = tl.load(expert_ptrs, mask=hidden_mask, other=0.0)
        acc += expert_data.to(tl.float32)
        # Step to the next top-k slice by advancing the pointers.
        expert_ptrs += input_stride_topk

    output_ptrs = (
        output_ptr
        + token_idx64 * output_stride_token
        + hidden_offsets64 * output_stride_hidden
    )
    # Cast to whatever the output buffer holds (fp16, bf16, fp32, ...)
    # instead of keying the cast on the input dtype.
    tl.store(
        output_ptrs,
        acc.to(output_ptr.dtype.element_ty),
        mask=hidden_mask,
    )
def moe_sum(
    input: torch.Tensor,
    output: torch.Tensor,
) -> None:
    """Sum ``input`` over its second dimension and write into ``output``.

    ``input`` must be 3-D; its shape is unpacked as
    ``(num_tokens, topk, hidden_size)``. The result is written in place into
    ``output`` by ``moe_sum_kernel``; nothing is returned.

    Args:
        input: 3-D tensor to reduce over dimension 1.
        output: Destination tensor. Only its first two strides are read, so
            it is presumably 2-D with shape ``(num_tokens, hidden_size)`` --
            NOTE(review): the shape is not validated here; confirm with
            callers.

    Raises:
        ValueError: If ``input`` is not 3-D (raised by the shape unpacking).
    """
    num_tokens, topk, hidden_size = input.shape

    # An empty token or hidden dimension means there is no output element to
    # write; skip autotuning and a launch with a zero-sized grid dimension.
    if num_tokens == 0 or hidden_size == 0:
        return

    input_strides = input.stride()
    output_strides = output.stride()

    def grid(meta):
        # One program per (token, block of BLOCK_SIZE hidden indices).
        return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

    moe_sum_kernel[grid](
        input,
        output,
        num_tokens,
        topk,
        hidden_size,
        input_strides[0],
        input_strides[1],
        input_strides[2],
        output_strides[0],
        output_strides[1],
    )