Coverage for src/flag_gems/fused/moe_sum.py: 50%
30 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=2),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=8),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["hidden_size", "topk"],
)
@triton.jit
def moe_sum_kernel(
    input_ptr,
    output_ptr,
    num_tokens,
    topk,
    hidden_size,
    input_stride_token,
    input_stride_topk,
    input_stride_hidden,
    output_stride_token,
    output_stride_hidden,
    BLOCK_SIZE: tl.constexpr,
):
    """Sum the ``topk`` slices of one token over one block of the hidden dim.

    Program axis 0 selects the token and program axis 1 selects a
    ``BLOCK_SIZE``-wide window of the hidden dimension. For every element in
    the window the kernel accumulates ``input[token, k, h]`` over
    ``k in range(topk)`` in float32 and stores the result to
    ``output[token, h]``.

    Args:
        input_ptr: Base pointer of the input, addressed with the three
            ``input_stride_*`` values (in elements).
        output_ptr: Base pointer of the output, addressed with the two
            ``output_stride_*`` values (in elements).
        num_tokens: Number of tokens; programs beyond it exit immediately.
        topk: Number of slices summed per token.
        hidden_size: Extent of the hidden dimension; used to mask the tail
            block.
        input_stride_token / input_stride_topk / input_stride_hidden:
            Element strides of the input along each of its dimensions.
        output_stride_token / output_stride_hidden: Element strides of the
            output along each of its dimensions.
        BLOCK_SIZE: Compile-time width of the hidden window (autotuned).
    """
    # Widen the token index so token_idx * stride cannot wrap in 32 bits on
    # very large tensors.
    token_idx = tl.program_id(0).to(tl.int64)
    block_idx = tl.program_id(1)

    hidden_start = block_idx * BLOCK_SIZE
    hidden_offsets = hidden_start + tl.arange(0, BLOCK_SIZE)
    hidden_mask = hidden_offsets < hidden_size

    if token_idx >= num_tokens:
        return

    # Accumulate in float32 regardless of the storage dtype.
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Honour the hidden-dimension stride so inputs whose last dimension is not
    # unit-stride are still read correctly.
    input_base = (
        input_ptr
        + token_idx * input_stride_token
        + hidden_offsets * input_stride_hidden
    )

    for expert_idx in range(topk):
        expert_ptrs = input_base + expert_idx * input_stride_topk
        # Masked-out lanes contribute 0.0 to the sum and are never stored.
        expert_data = tl.load(expert_ptrs, mask=hidden_mask, other=0.0)
        acc += expert_data

    output_ptrs = (
        output_ptr
        + token_idx * output_stride_token
        + hidden_offsets * output_stride_hidden
    )

    # Cast to whatever the output buffer holds rather than keying the cast on
    # the input dtype.
    tl.store(
        output_ptrs,
        acc.to(output_ptr.dtype.element_ty),
        mask=hidden_mask,
    )
def moe_sum(
    input: torch.Tensor,
    output: torch.Tensor,
):
    """Launch ``moe_sum_kernel`` to reduce ``input`` over its middle dimension.

    Args:
        input: 3-D tensor; its shape is unpacked as
            ``(num_tokens, topk, hidden_size)``.
        output: Tensor the kernel writes into, addressed through its first
            two strides. Presumably shaped ``(num_tokens, hidden_size)`` --
            this is not checked here, so verify against the caller.

    Returns:
        None. The result is written into ``output`` by the kernel.
    """
    logger.debug("GEMS MOE SUM")

    num_tokens, topk, hidden_size = input.shape
    in_stride_token, in_stride_topk, in_stride_hidden = input.stride()
    out_stride_token = output.stride(0)
    out_stride_hidden = output.stride(1)

    def launch_grid(meta):
        # One program per (token, hidden block); the block width comes from
        # the autotuner-selected config.
        hidden_blocks = triton.cdiv(hidden_size, meta["BLOCK_SIZE"])
        return (num_tokens, hidden_blocks)

    moe_sum_kernel[launch_grid](
        input,
        output,
        num_tokens,
        topk,
        hidden_size,
        in_stride_token,
        in_stride_topk,
        in_stride_hidden,
        out_stride_token,
        out_stride_hidden,
    )