Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/moe_align_block_size.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

# Child logger under the package-wide "flag_gems" logger; lstrip(".") guards
# against a leading dot in __name__ (relevant only for relative-style names).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

9 

10 

def ceil_div(a, b):
    """Return ceil(a / b) using integer arithmetic (assumes b > 0)."""
    # Negated floor division is the classic ceiling trick: for b > 0,
    # -(-a // b) == (a + b - 1) // b for every integer a.
    return -(-a // b)

13 

14 

def round_up(x: int, y: int) -> int:
    """Round *x* up to the nearest multiple of *y* (assumes y > 0)."""
    # Ceiling division via negated floor division, then scale back up.
    return -(-x // y) * y

17 

18 

@triton.jit(do_not_specialize=["numel", "tokens_per_thread"])
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel,
    tokens_per_thread,
):
    # Stage 1: per-program expert histogram.
    # Each program scans its own contiguous slice of `topk_ids` (length
    # `tokens_per_thread`, clamped by `numel`) and counts how many entries
    # were routed to each expert.
    #
    # `tokens_cnts` is a flattened (num_experts + 1, num_experts) matrix.
    # Program `pid` writes only row `pid + 1`, so the scalar
    # load/add/store below needs no atomics; row 0 stays zero and serves
    # as the prefix-sum base for stage 2.
    pid = tl.program_id(0)

    start_idx = pid * tokens_per_thread

    # Byte offset of this program's private row in the counts matrix.
    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)

38 

39 

@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    # Stage 2: in-place prefix sum down each expert column.
    # Program `pid` owns column `pid` of the (num_experts + 1, num_experts)
    # counts matrix built by stage 1. After this kernel, row i holds the
    # running total of tokens for expert `pid` over programs 0..i-1, so the
    # final row (i == num_experts) is the expert's grand total.
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)

52 

53 

@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    # Stage 3: block-padded cumulative sum over experts (single program).
    # Reads each expert's grand total from the last row of the counts
    # matrix (offset num_experts * num_experts, filled by stage 2), rounds
    # it up to a multiple of `block_size`, and accumulates. cumsum[e] ends
    # up as the padded start offset of expert e's bucket (cumsum[0] is the
    # caller-initialized zero), and the final running total — the padded
    # token count — is written to `total_tokens_post_pad_ptr`.
    last_cumsum = 0
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        # tl.cdiv(...) * block_size rounds the expert's count up to a block.
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)

69 

70 

@triton.jit(do_not_specialize=["numel", "tokens_per_thread"])
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel,
    tokens_per_thread,
):
    # Stage 4: scatter. Program `pid` handles expert `pid` for the block
    # labels and its own token slice for the final placement.
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    # Label every block inside expert `pid`'s padded range with `pid`.
    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    # Re-scan this program's slice of topk_ids (same partition as stage 1)
    # and place each token at cumsum[expert] + rank-within-expert. Row `pid`
    # of the counts matrix is reused as this program's per-expert rank
    # counter: after stage 2 it already holds the number of tokens that
    # earlier programs (rows < pid) contributed to each expert.
    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)

99 

100 

def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Run the four-stage Triton pipeline that groups token indices by
    expert and pads each expert's bucket to a multiple of ``block_size``.

    All results are written in place into ``sorted_token_ids``,
    ``expert_ids`` and ``num_tokens_post_pad``; nothing is returned.
    """
    total = topk_ids.numel()
    per_program = ceil_div(total, num_experts)
    launch_grid = (num_experts,)

    # Scratch histogram: row 0 stays zero (prefix-sum base), rows
    # 1..num_experts receive the per-program expert counts from stage 1.
    counts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    # offsets[e] becomes the padded start offset of expert e's bucket.
    offsets = torch.zeros(
        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
    )

    # Stage 1: per-program histogram of expert assignments.
    moe_align_block_size_stage1[launch_grid](
        topk_ids,
        counts,
        num_experts,
        total,
        per_program,
    )
    # Stage 2: prefix sum down each expert column, in place.
    moe_align_block_size_stage2[launch_grid](
        counts,
        num_experts,
    )
    # Stage 3 (single program): block-padded cumulative offsets per expert.
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        counts,
        offsets,
        num_experts,
        block_size,
    )
    # Stage 4: label blocks with expert ids and scatter token indices.
    moe_align_block_size_stage4[launch_grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        counts,
        offsets,
        num_experts,
        block_size,
        total,
        per_program,
    )

146 

147 

def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Group the token indices in ``topk_ids`` by expert, padding every
    expert's group to a multiple of ``block_size``.

    Returns ``(sorted_ids, expert_ids, num_tokens_post_pad)`` where
    ``sorted_ids`` holds token indices ordered by expert (padding slots
    carry the sentinel ``topk_ids.numel()``), ``expert_ids`` labels each
    size-``block_size`` block with its expert, and ``num_tokens_post_pad``
    is the single-element padded total.
    """
    logger.debug("GEMS MOE ALIGN BLOCK SIZE")

    num_entries = topk_ids.numel()
    # Worst case: every expert wastes (block_size - 1) slots on padding.
    max_num_tokens_padded = num_entries + num_experts * (block_size - 1)
    if pad_sorted_ids:
        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)

    # Pre-fill with `num_entries`: an out-of-range index that marks padding
    # slots the Triton pipeline never overwrites.
    sorted_ids = torch.full(
        (max_num_tokens_padded,),
        num_entries,
        dtype=torch.int32,
        device=topk_ids.device,
    )
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.zeros(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=topk_ids.device)

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    # Translate physical expert ids to logical ones when a map is supplied.
    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad