Coverage for src/flag_gems/fused/moe_sum.py: 50%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

10@triton.autotune( 

11 configs=[ 

12 triton.Config({"BLOCK_SIZE": 128}, num_warps=2), 

13 triton.Config({"BLOCK_SIZE": 256}, num_warps=4), 

14 triton.Config({"BLOCK_SIZE": 512}, num_warps=8), 

15 triton.Config({"BLOCK_SIZE": 1024}, num_warps=8), 

16 ], 

17 key=["hidden_size", "topk"], 

18) 

19@triton.jit 

20def moe_sum_kernel( 

21 input_ptr, 

22 output_ptr, 

23 num_tokens, 

24 topk, 

25 hidden_size, 

26 input_stride_token, 

27 input_stride_topk, 

28 input_stride_hidden, 

29 output_stride_token, 

30 output_stride_hidden, 

31 BLOCK_SIZE: tl.constexpr, 

32): 

33 token_idx = tl.program_id(0) 

34 block_idx = tl.program_id(1) 

35 hidden_start = block_idx * BLOCK_SIZE 

36 hidden_offsets = hidden_start + tl.arange(0, BLOCK_SIZE) 

37 hidden_mask = hidden_offsets < hidden_size 

38 if token_idx >= num_tokens: 

39 return 

40 acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) 

41 input_base = input_ptr + token_idx * input_stride_token 

42 

43 for expert_idx in range(topk): 

44 expert_ptr = input_base + expert_idx * input_stride_topk 

45 expert_data = tl.load(expert_ptr + hidden_offsets, mask=hidden_mask, other=0.0) 

46 acc += expert_data 

47 output_ptr_pos = output_ptr + token_idx * output_stride_token + hidden_offsets 

48 

49 tl.store( 

50 output_ptr_pos, 

51 acc.to(tl.float16) if input_ptr.dtype.element_ty == tl.float16 else acc, 

52 mask=hidden_mask, 

53 ) 

54 

55 

56def moe_sum( 

57 input: torch.Tensor, 

58 output: torch.Tensor, 

59): 

60 logger.debug("GEMS MOE SUM") 

61 num_tokens, topk, hidden_size = input.shape 

62 input_strides = input.stride() 

63 output_strides = output.stride() 

64 grid = lambda meta: (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) 

65 moe_sum_kernel[grid]( 

66 input, 

67 output, 

68 num_tokens, 

69 topk, 

70 hidden_size, 

71 input_strides[0], 

72 input_strides[1], 

73 input_strides[2], 

74 output_strides[0], 

75 output_strides[1], 

76 )