Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/topk_softmax.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def topk_gating_softmax_kernel(
    input_ptr,  # (num_rows, num_experts) gating logits, row stride = num_experts
    finished_ptr,  # interface reserved, not yet used
    output_ptr,  # (num_rows, k) out: top-k softmax probabilities
    indices_ptr,  # (num_rows, k) out: selected global expert ids (INDEX_TY)
    source_rows_ptr,  # (num_rows, k) out: source-row ids, ki * num_rows + row (int32)
    num_rows,
    k,
    num_experts,
    start_expert,  # first expert column scanned (inclusive)
    end_expert,  # last expert column scanned (exclusive)
    INDEX_TY: tl.constexpr,  # element type for indices_ptr (tl.int32 / tl.int64)
    BLOCK_SIZE_ROWS: tl.constexpr,  # rows handled per program
    BLOCK_SIZE_EXPERTS: tl.constexpr,  # expert-tile width (padded)
):
    # Fused softmax + top-k selection over the expert dimension.
    # Each program owns one tile of BLOCK_SIZE_ROWS rows and scans the expert
    # range [start_expert, end_expert) in a single tile of BLOCK_SIZE_EXPERTS
    # columns — experts beyond that tile width are never visited, so the
    # launcher must size BLOCK_SIZE_EXPERTS to cover the whole range.
    pid = tl.program_id(0)
    rows = tl.arange(0, BLOCK_SIZE_ROWS) + pid * BLOCK_SIZE_ROWS
    valid_rows = rows < num_rows

    cols = start_expert + tl.arange(0, BLOCK_SIZE_EXPERTS)
    valid_cols = cols < end_expert

    # Out-of-range rows/columns load as -inf so they can never win max/argmax.
    logits = tl.load(
        input_ptr + rows[:, None] * num_experts + cols[None, :],
        mask=valid_rows[:, None] & valid_cols[None, :],
        other=-float("inf"),
    )

    # Numerically stable softmax: subtract the per-row max before exp.
    row_max = tl.max(logits, axis=1)[:, None]
    exp_vals = tl.exp(logits - row_max)
    # NOTE(review): the 1e-8 epsilon guards against a zero denominator (e.g. a
    # row whose columns are all masked); it also makes each row sum to slightly
    # less than 1 — presumably intentional, confirm with the MoE consumer.
    probs = exp_vals / (tl.sum(exp_vals, axis=1)[:, None] + 1e-8)

    # Iterative top-k: take the row-wise argmax, record it, then knock the
    # winning column out with -inf so the next pass finds the runner-up.
    for ki in range(k):
        curr_max = tl.max(probs, axis=1)
        # argmax is tile-local; add start_expert to get the global expert id.
        curr_arg = tl.argmax(probs, axis=1) + start_expert

        # All stores are masked on valid_rows, so the NaNs/-infs computed for
        # padding rows never reach memory.
        tl.store(output_ptr + rows * k + ki, curr_max, mask=valid_rows)
        tl.store(indices_ptr + rows * k + ki, curr_arg.to(INDEX_TY), mask=valid_rows)
        # Source-row layout used downstream: slot ki of row r -> ki * num_rows + r.
        tl.store(
            source_rows_ptr + rows * k + ki,
            (ki * num_rows + rows).to(tl.int32),
            mask=valid_rows,
        )

        probs = tl.where(
            cols[None, :] == (curr_arg[:, None] - start_expert), -float("inf"), probs
        )

54 

55 

def topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Fused softmax + top-k expert selection for MoE gating, written in place.

    For each token (row of ``gating_output``) the kernel computes a softmax
    over the expert logits and selects the ``k`` largest probabilities.

    Args:
        topk_weights: (num_tokens, k) float buffer; receives the top-k
            softmax probabilities per token.
        topk_indices: (num_tokens, k) int32 or int64 buffer; receives the
            selected expert ids.
        token_expert_indices: (num_tokens, k) buffer; receives source-row
            ids laid out as ``ki * num_tokens + token`` (stored as int32).
        gating_output: (num_tokens, num_experts) gating logits.

    Raises:
        TypeError: if ``topk_indices`` is neither int32 nor int64.
        ValueError: if ``k`` exceeds 32, or ``num_experts`` exceeds the
            1024-column limit a single kernel tile can scan.
    """
    num_tokens, num_experts = gating_output.shape
    topk = topk_weights.size(-1)
    # Raise (not assert) so the validation survives `python -O`.
    if topk > 32:
        raise ValueError(f"topk must be <= 32, got {topk}")
    if num_experts > 1024:
        # The kernel scans all experts in one tile of at most 1024 columns;
        # larger expert counts would silently be truncated to the first 1024.
        raise ValueError(f"num_experts must be <= 1024, got {num_experts}")

    if topk_indices.dtype == torch.int32:
        index_ty = tl.int32
    elif topk_indices.dtype == torch.int64:
        index_ty = tl.int64
    else:
        # NOTE: uint32 support is pending (the tl.uint32 branch was disabled);
        # the message below states only what is actually accepted.
        raise TypeError("topk_indices must be int32/int64")

    # One program covers BLOCK_SIZE_ROWS rows x BLOCK_SIZE_EXPERTS experts.
    # Budget ~1024 lanes per program, giving the expert dimension priority;
    # the expert tile is the next power of two rounded up to a multiple of 32.
    max_total_threads = 1024
    BLOCK_SIZE_EXPERTS = ((triton.next_power_of_2(num_experts) + 31) // 32) * 32
    BLOCK_SIZE_EXPERTS = min(BLOCK_SIZE_EXPERTS, 1024)
    BLOCK_SIZE_ROWS = max(max_total_threads // BLOCK_SIZE_EXPERTS, 1)

    # 1-D grid over token-row tiles; each program handles every expert column.
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE_ROWS),)

    topk_gating_softmax_kernel[grid](
        input_ptr=gating_output,
        finished_ptr=None,
        output_ptr=topk_weights,
        indices_ptr=topk_indices,
        source_rows_ptr=token_expert_indices,
        num_rows=num_tokens,
        k=topk,
        num_experts=num_experts,
        start_expert=0,
        end_expert=num_experts,
        INDEX_TY=index_ty,
        BLOCK_SIZE_ROWS=BLOCK_SIZE_ROWS,
        BLOCK_SIZE_EXPERTS=BLOCK_SIZE_EXPERTS,
        # Kunlunxin (XPU) backend-specific launch option — presumably selects
        # a close-core tiling strategy; confirm against the XPU Triton runtime.
        isCloseCoreTiling=True,
    )