Coverage for src/flag_gems/fused/topk_softmax.py: 12%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

# Fused softmax + top-k selection over the expert axis of gating logits.
#
# Each program handles BLOCK_SIZE_ROWS consecutive rows (tokens).
# For every row it:
#   1. loads the logits of experts [start_expert, end_expert) from a row-major
#      buffer with row stride `num_experts`;
#   2. computes a numerically-stable softmax in float32;
#   3. writes the k largest probabilities (descending) to `output_ptr`, their
#      tile-local expert positions (i.e. relative to start_expert) to
#      `indices_ptr`, and `ki * num_rows + row` to `source_rows_ptr`.
#      All three outputs are addressed as flat (num_rows, k) row-major arrays.
# When `renormalize` is set, the stored weights are divided by the sum of the
# selected probabilities (+1e-8).
@triton.jit
def topk_gating_softmax_kernel(
    input_ptr,
    finished_ptr,  # interface reserved, not yet used
    output_ptr,
    indices_ptr,
    source_rows_ptr,
    num_rows,
    k,
    num_experts,
    start_expert,
    end_expert,
    renormalize: tl.constexpr,
    INDEX_TY: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_EXPERTS: tl.constexpr,
):
    pid = tl.program_id(0)
    rows = pid * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
    valid_rows = rows < num_rows

    # Tile-local expert positions. argmax indices come back in this frame,
    # so exclusion masks must compare against these, not absolute columns.
    local_cols = tl.arange(0, BLOCK_SIZE_EXPERTS)
    cols = start_expert + local_cols
    valid_cols = cols < end_expert

    logits = tl.load(
        input_ptr + rows[:, None] * num_experts + cols[None, :],
        mask=valid_rows[:, None] & valid_cols[None, :],
        other=-float("inf"),
    ).to(tl.float32)

    row_max = tl.max(logits, axis=1)[:, None]
    exp_vals = tl.exp(logits - row_max)
    probs = exp_vals / (tl.sum(exp_vals, axis=1)[:, None] + 1e-8)
    # Padding columns come out of the softmax as exp(-inf) == 0, which could
    # tie with (or beat) real experts; make them unselectable.
    probs = tl.where(valid_cols[None, :], probs, -float("inf"))

    # Renormalization divisor. It is computed up front in registers, so each
    # weight can be stored already normalized. The previous approach stored,
    # re-loaded and rescaled global memory within one program without any
    # cross-thread ordering guarantee.
    norm = tl.full([BLOCK_SIZE_ROWS], 1.0, dtype=tl.float32)
    if renormalize:
        selected_sum = tl.zeros([BLOCK_SIZE_ROWS], dtype=tl.float32)
        remaining = probs
        for ki in range(k):
            top_val, top_arg = tl.max(remaining, axis=1, return_indices=True)
            selected_sum += top_val
            remaining = tl.where(
                local_cols[None, :] == top_arg[:, None], -float("inf"), remaining
            )
        norm = selected_sum + 1e-8

    out_base = rows * k
    for ki in range(k):
        curr_max, curr_arg = tl.max(probs, axis=1, return_indices=True)

        weight = curr_max
        if renormalize:
            weight = curr_max / norm

        tl.store(output_ptr + out_base + ki, weight, mask=valid_rows)
        tl.store(indices_ptr + out_base + ki, curr_arg.to(INDEX_TY), mask=valid_rows)
        tl.store(
            source_rows_ptr + out_base + ki,
            (ki * num_rows + rows).to(tl.int32),
            mask=valid_rows,
        )

        # Exclude the expert just chosen from the remaining rounds.
        probs = tl.where(
            local_cols[None, :] == curr_arg[:, None], -float("inf"), probs
        )

64 

65 

def topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> None:
    """Write top-k softmax routing results for ``gating_output`` into the given buffers.

    Args:
        topk_weights: Output buffer; receives the k largest softmax
            probabilities of each row in descending order. Its last dimension
            defines k. When ``renormalize`` is true, the weights are divided by
            their sum (+1e-8).
        topk_indices: Output buffer; receives the expert index of each selected
            weight. Its dtype must be int32, int64 or uint32.
        token_expert_indices: Output buffer; entry ``(t, j)`` receives
            ``j * num_tokens + t`` (computed as int32).
        gating_output: 2-D ``(num_tokens, num_experts)`` logits.
        renormalize: Whether to rescale the selected weights to sum to ~1.

    Raises:
        TypeError: If ``topk_indices`` has an unsupported dtype.
        ValueError: If k exceeds ``num_experts``, or an output buffer is not
            contiguous.

    The kernel addresses every output as a flat row-major ``(num_tokens, k)``
    array. Presumably callers allocate exactly that shape; TODO confirm.
    """
    num_tokens, num_experts = gating_output.shape
    topk = topk_weights.size(-1)
    assert topk <= 32

    index_types = {torch.int32: tl.int32, torch.int64: tl.int64}
    # torch.uint32 only exists in newer PyTorch releases. Referencing it
    # directly would raise AttributeError instead of the intended TypeError.
    torch_uint32 = getattr(torch, "uint32", None)
    if torch_uint32 is not None:
        index_types[torch_uint32] = tl.uint32
    index_ty = index_types.get(topk_indices.dtype)
    if index_ty is None:
        raise TypeError("topk_indices must be int32/int64/uint32")

    if topk > num_experts:
        # More picks than experts would select padding columns and emit
        # out-of-range / duplicate indices.
        raise ValueError(
            f"topk ({topk}) must not exceed the number of experts ({num_experts})"
        )
    for name, out in (
        ("topk_weights", topk_weights),
        ("topk_indices", topk_indices),
        ("token_expert_indices", token_expert_indices),
    ):
        # The kernel writes at offset row * k + ki; a strided view would be
        # corrupted silently.
        if not out.is_contiguous():
            raise ValueError(f"{name} must be contiguous")

    if num_tokens == 0:
        return

    # The kernel reads logits at row * num_experts + col (row-major).
    # This is a no-op for already-contiguous inputs.
    gating_output = gating_output.contiguous()

    # One tile must span every expert. Capping it (previously at 1024) would
    # silently leave trailing experts out of the softmax and top-k.
    BLOCK_SIZE_EXPERTS = max(triton.next_power_of_2(num_experts), 32)
    # Pack rows so a program handles roughly this many logits.
    max_block_elems = 1024
    BLOCK_SIZE_ROWS = max(max_block_elems // BLOCK_SIZE_EXPERTS, 1)

    # If num_experts > 128, intra-warp shuffling is forced for reduction,
    # which requires the warp layout to be confined to a single row.
    # Consequently, in the TTGIR, the second dimension of warpsPerCTA is fixed to 1.
    if num_experts > 128:
        BLOCK_SIZE_ROWS = 1
        num_warps = 1
    else:
        num_warps = 4

    grid = (triton.cdiv(num_tokens, BLOCK_SIZE_ROWS),)
    topk_gating_softmax_kernel[grid](
        input_ptr=gating_output,
        finished_ptr=None,
        output_ptr=topk_weights,
        indices_ptr=topk_indices,
        source_rows_ptr=token_expert_indices,
        num_rows=num_tokens,
        k=topk,
        num_experts=num_experts,
        start_expert=0,
        end_expert=num_experts,
        renormalize=renormalize,
        INDEX_TY=index_ty,
        BLOCK_SIZE_ROWS=BLOCK_SIZE_ROWS,
        BLOCK_SIZE_EXPERTS=BLOCK_SIZE_EXPERTS,
        num_warps=num_warps,
    )