Coverage for src/flag_gems/runtime/common.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1from enum import Enum 

2 

3 

class vendors(Enum):
    """Vendor identifiers, each bound to a distinct integer value."""

    NVIDIA = 0
    CAMBRICON = 1
    METAX = 2
    ILUVATAR = 3
    MTHREADS = 4
    KUNLUNXIN = 5
    HYGON = 6
    AMD = 7
    AIPU = 8
    ASCEND = 9
    TSINGMICRO = 10
    SUNRISE = 11
    ENFLAME = 12
    SPACEMIT = 13
    THEAD = 14

    @classmethod
    def get_all_vendors(cls) -> dict:
        """Return a new dict mapping each lowercased member name to its member.

        Keys appear in member definition order, e.g. ``"nvidia"`` maps to
        ``vendors.NVIDIA``. A fresh dict is built on every call, so callers
        may mutate the result freely.
        """
        # Iterating an Enum class yields its members in definition order.
        return {member.name.lower(): member for member in cls}

27 

28 

# Immutable set of vendors grouped under the UNSUPPORT_FP64 name.
# NOTE(review): presumably the vendors whose backends lack float64 support --
# confirm against the code that consumes this set.
UNSUPPORT_FP64 = frozenset(
    (
        vendors.AIPU,
        vendors.ASCEND,
        vendors.CAMBRICON,
        vendors.ENFLAME,
        vendors.ILUVATAR,
        vendors.KUNLUNXIN,
        vendors.MTHREADS,
        vendors.SUNRISE,
        vendors.SPACEMIT,
        vendors.TSINGMICRO,
    )
)

43 

# Immutable set of vendors grouped under the UNSUPPORT_BF16 name.
# NOTE(review): presumably the vendors whose backends lack bfloat16 support --
# confirm against the code that consumes this set.
UNSUPPORT_BF16 = frozenset(
    (
        vendors.AIPU,
        vendors.SUNRISE,
        vendors.SPACEMIT,
    )
)

51 

# Immutable set of vendors grouped under the UNSUPPORT_INT64 name.
# NOTE(review): presumably the vendors whose backends lack int64 support --
# confirm against the code that consumes this set.
UNSUPPORT_INT64 = frozenset(
    (
        vendors.AIPU,
        vendors.ENFLAME,
        vendors.SPACEMIT,
        vendors.SUNRISE,
        vendors.TSINGMICRO,
    )
)

61 

# Per-operator list of strategy names; every entry is "align32" or "default".
# Each list has the same length as the operator's entry in OP_KEY_ORDERS.
# NOTE(review): presumably the i-th strategy applies to the i-th key -- confirm
# against the code that consumes both tables.
# Every value is a freshly built list, so no two operators share a list object.
DEFAULT_STRATEGIES = {
    "addmm": ["align32"] * 3,
    "addmm_sqmma": ["align32"] * 3,
    "baddbmm": ["align32"] * 3,
    "bmm": ["align32"] * 5,
    "bmm_sqmma": ["align32"] * 3,
    "gemv": ["align32"] * 3 + ["default"],
    "mm": ["align32"] * 5,
    "mm_general_tma": ["align32"] * 5 + ["default"],
    "mv": ["align32"] * 2,
    "sparse_attention": ["align32"] * 3,
    "w8a8_block_fp8_general": ["align32"] * 5,
    "w8a8_block_fp8_general_splitk": ["align32"] * 5,
    "w8a8_block_fp8_general_tma": ["align32"] * 5 + ["default"],
    "mm_splitk": ["align32"] * 5,
}

104 

# For each operator name, the ordered list of key names associated with it.
# The table below is walked in order, so the resulting dict keeps this exact
# key order, and every operator receives its own fresh list.
OP_KEY_ORDERS = {
    op_name: list(key_names)
    for op_name, key_names in (
        ("addmm", ("M", "N", "K")),
        ("addmm_sqmma", ("M", "N", "K")),
        ("bmm", ("M", "N", "K", "stride_am", "stride_bk")),
        ("bmm_sqmma", ("M", "N", "K")),
        ("baddbmm", ("M", "N", "K")),
        ("gemv", ("M", "K", "stride_am", "stride_bk")),
        ("mm", ("M", "N", "K", "stride_am", "stride_bk")),
        ("mm_general_tma", ("M", "N", "K", "stride_am", "stride_bk", "dtype")),
        ("mv", ("M", "N")),
        ("sparse_attention", ("topk", "H_ACTUAL", "D")),
        ("w8a8_block_fp8_general", ("M", "N", "K", "stride_am", "stride_bk")),
        ("w8a8_block_fp8_general_splitk", ("M", "N", "K", "stride_am", "stride_bk")),
        (
            "w8a8_block_fp8_general_tma",
            ("M", "N", "K", "stride_am", "stride_bk", "dtype"),
        ),
        ("mm_splitk", ("M", "N", "K", "stride_am", "stride_bk")),
    )
}

121 

122 

# Mapping from lowercase vendor name to the name of a torch attribute used for
# quick detection. Only some vendors have an entry here.
# NOTE(review): presumably a vendor is detected by checking whether ``torch``
# exposes the mapped attribute -- confirm against the code that reads this map.
_VENDOR_TORCH_ATTR = {
    "ascend": "npu",
    "cambricon": "mlu",
    "enflame": "gcu",
    "hygon": "__hcu_version__",
    "iluvatar": "corex",
    "mthreads": "musa",
    "sunrise": "ptpu",
}

133 

# Names bound by a star import of this module. "_VENDOR_TORCH_ATTR" is listed
# explicitly because a star import skips underscore-prefixed names unless they
# appear in ``__all__``.
__all__ = [
    "vendors",
    "UNSUPPORT_FP64",
    "UNSUPPORT_BF16",
    "UNSUPPORT_INT64",
    "DEFAULT_STRATEGIES",
    "OP_KEY_ORDERS",
    "_VENDOR_TORCH_ATTR",
]