Coverage for src/flag_gems/fused/topk_softmax.py: 15%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

@triton.jit
def topk_gating_softmax_kernel(
    input_ptr,
    finished_ptr,  # interface reserved, not yet used
    output_ptr,
    indices_ptr,
    source_rows_ptr,
    num_rows,
    k,
    num_experts,
    start_expert,
    end_expert,
    renormalize: tl.constexpr,
    INDEX_TY: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_EXPERTS: tl.constexpr,
):
    """Softmax over the expert axis followed by an iterative top-k selection.

    Each program handles ``BLOCK_SIZE_ROWS`` consecutive rows. For every row
    it loads the logits of the experts ``start_expert .. end_expert - 1``
    (at most ``BLOCK_SIZE_EXPERTS`` of them), turns them into probabilities
    and then picks the largest probability ``k`` times, masking out each
    winner before the next pick.

    Writes, for row ``r`` and pick ``ki`` (all at flat offset ``r * k + ki``):
      * ``output_ptr``      - the selected probability (divided by the sum of
        the ``k`` selected probabilities when ``renormalize`` is set),
      * ``indices_ptr``     - the winner's position inside the loaded expert
        window (i.e. relative to ``start_expert``), cast to ``INDEX_TY``,
      * ``source_rows_ptr`` - ``ki * num_rows + r`` as int32.

    The input is addressed as ``row * num_experts + expert``, so it must be
    laid out row-major with a row stride of ``num_experts``.
    """
    pid = tl.program_id(0)
    rows = tl.arange(0, BLOCK_SIZE_ROWS) + pid * BLOCK_SIZE_ROWS
    valid_rows = rows < num_rows

    # Keep the block-local column positions separately from the global expert
    # ids: ``tl.max(..., return_indices=True)`` reports block-local positions.
    local_cols = tl.arange(0, BLOCK_SIZE_EXPERTS)
    cols = start_expert + local_cols
    valid_cols = cols < end_expert

    # Padding lanes get -inf so they contribute exp(-inf) == 0 to the softmax.
    logits = tl.load(
        input_ptr + rows[:, None] * num_experts + cols[None, :],
        mask=valid_rows[:, None] & valid_cols[None, :],
        other=-float("inf"),
    ).to(tl.float32)

    # Numerically stable softmax: subtract the per-row maximum before exp.
    row_max = tl.max(logits, axis=1)[:, None]
    exp_vals = tl.exp(logits - row_max)
    probs = exp_vals / (tl.sum(exp_vals, axis=1)[:, None] + 1e-8)

    selected_sum = tl.zeros([BLOCK_SIZE_ROWS], dtype=tl.float32)
    for ki in range(k):
        curr_max, curr_arg = tl.max(probs, axis=1, return_indices=True)

        # Stores are masked by valid_rows, so out-of-range rows never write.
        tl.store(output_ptr + rows * k + ki, curr_max, mask=valid_rows)
        tl.store(indices_ptr + rows * k + ki, curr_arg.to(INDEX_TY), mask=valid_rows)
        tl.store(
            source_rows_ptr + rows * k + ki,
            (ki * num_rows + rows).to(tl.int32),
            mask=valid_rows,
        )
        if renormalize:
            selected_sum += curr_max

        # Knock the winner out of the running so the next iteration picks the
        # next-largest probability. ``curr_arg`` is a block-local position, so
        # it is compared against block-local columns; the previous comparison
        # of global ids with ``curr_arg - start_expert`` only matched when
        # ``start_expert`` was 0.
        probs = tl.where(
            local_cols[None, :] == curr_arg[:, None], -float("inf"), probs
        )

    if renormalize:
        # Second pass: rescale the stored weights so the k picks sum to ~1.
        norm = selected_sum + 1e-8
        for ki in range(k):
            idx = rows * k + ki
            val = tl.load(output_ptr + idx, mask=valid_rows)
            tl.store(output_ptr + idx, val / norm, mask=valid_rows)

68 

69 

def topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> None:
    """Compute softmax over experts and write the top-k results in place.

    Args:
        topk_weights: Output buffer; its last dimension defines ``topk``
            (at most 32). Receives the selected probabilities.
        topk_indices: Output buffer of dtype int32, int64 or uint32 that
            receives the selected expert indices.
        token_expert_indices: Output buffer that receives
            ``ki * num_tokens + token`` for every pick ``ki``.
        gating_output: 2-D ``(num_tokens, num_experts)`` tensor of logits.
        renormalize: When true, the selected weights of each token are
            divided by their sum.

    Raises:
        TypeError: If ``topk_indices`` has an unsupported dtype.

    Note:
        The kernel addresses the three output buffers as flat row-major
        ``(num_tokens, topk)`` arrays, so they are expected to be contiguous.
    """
    logger.debug("GEMS TOPK SOFTMAX")
    num_tokens, num_experts = gating_output.shape
    topk = topk_weights.size(-1)
    assert topk <= 32

    # ``torch.uint32`` is missing from older PyTorch builds; look it up lazily
    # so int64 index tensors do not trip an AttributeError there.
    uint32_dtype = getattr(torch, "uint32", None)
    if topk_indices.dtype == torch.int32:
        index_ty = tl.int32
    elif uint32_dtype is not None and topk_indices.dtype == uint32_dtype:
        index_ty = tl.uint32
    elif topk_indices.dtype == torch.int64:
        index_ty = tl.int64
    else:
        raise TypeError("topk_indices must be int32/int64/uint32")

    # The kernel indexes the logits with a row stride of ``num_experts``;
    # ``contiguous()`` is a no-op when the tensor already has that layout.
    gating_output = gating_output.contiguous()

    max_total_threads = 1024
    # One block must span every expert of a row. Do not clamp this to 1024:
    # a narrower block would silently drop the experts beyond it.
    BLOCK_SIZE_EXPERTS = ((triton.next_power_of_2(num_experts) + 31) // 32) * 32
    BLOCK_SIZE_ROWS = max_total_threads // BLOCK_SIZE_EXPERTS
    BLOCK_SIZE_ROWS = max(BLOCK_SIZE_ROWS, 1)

    # If num_experts > 128, intra-warp shuffling is forced for reduction,
    # which requires the warp layout to be confined to a single row.
    # Consequently, in the TTGIR, the second dimension of warpsPerCTA is fixed to 1.
    if num_experts > 128:
        BLOCK_SIZE_ROWS = 1
        num_warps = 1
    else:
        num_warps = 4

    grid = (triton.cdiv(num_tokens, BLOCK_SIZE_ROWS),)
    topk_gating_softmax_kernel[grid](
        input_ptr=gating_output,
        finished_ptr=None,
        output_ptr=topk_weights,
        indices_ptr=topk_indices,
        source_rows_ptr=token_expert_indices,
        num_rows=num_tokens,
        k=topk,
        num_experts=num_experts,
        start_expert=0,
        end_expert=num_experts,
        renormalize=renormalize,
        INDEX_TY=index_ty,
        BLOCK_SIZE_ROWS=BLOCK_SIZE_ROWS,
        BLOCK_SIZE_EXPERTS=BLOCK_SIZE_EXPERTS,
        num_warps=num_warps,
    )