Coverage for src/flag_gems/fused/moe_sum.py: 44%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

6@triton.autotune( 

7 configs=[ 

8 triton.Config({"BLOCK_SIZE": 128}, num_warps=2), 

9 triton.Config({"BLOCK_SIZE": 256}, num_warps=4), 

10 triton.Config({"BLOCK_SIZE": 512}, num_warps=8), 

11 triton.Config({"BLOCK_SIZE": 1024}, num_warps=8), 

12 ], 

13 key=["hidden_size", "topk"], 

14) 

15@triton.jit 

16def moe_sum_kernel( 

17 input_ptr, 

18 output_ptr, 

19 num_tokens, 

20 topk, 

21 hidden_size, 

22 input_stride_token, 

23 input_stride_topk, 

24 input_stride_hidden, 

25 output_stride_token, 

26 output_stride_hidden, 

27 BLOCK_SIZE: tl.constexpr, 

28): 

29 token_idx = tl.program_id(0) 

30 block_idx = tl.program_id(1) 

31 hidden_start = block_idx * BLOCK_SIZE 

32 hidden_offsets = hidden_start + tl.arange(0, BLOCK_SIZE) 

33 hidden_mask = hidden_offsets < hidden_size 

34 if token_idx >= num_tokens: 

35 return 

36 acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) 

37 input_base = input_ptr + token_idx * input_stride_token 

38 

39 for expert_idx in range(topk): 

40 expert_ptr = input_base + expert_idx * input_stride_topk 

41 expert_data = tl.load(expert_ptr + hidden_offsets, mask=hidden_mask, other=0.0) 

42 acc += expert_data 

43 output_ptr_pos = output_ptr + token_idx * output_stride_token + hidden_offsets 

44 

45 tl.store( 

46 output_ptr_pos, 

47 acc.to(tl.float16) if input_ptr.dtype.element_ty == tl.float16 else acc, 

48 mask=hidden_mask, 

49 ) 

50 

51 

52def moe_sum( 

53 input: torch.Tensor, 

54 output: torch.Tensor, 

55): 

56 num_tokens, topk, hidden_size = input.shape 

57 input_strides = input.stride() 

58 output_strides = output.stride() 

59 grid = lambda meta: (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) 

60 moe_sum_kernel[grid]( 

61 input, 

62 output, 

63 num_tokens, 

64 topk, 

65 hidden_size, 

66 input_strides[0], 

67 input_strides[1], 

68 input_strides[2], 

69 output_strides[0], 

70 output_strides[1], 

71 )