Coverage for src/flag_gems/runtime/backend/device.py: 90%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import os 

2import shlex 

3import subprocess 

4import threading 

5from queue import Queue 

6 

7import torch # noqa: F401 

8 

9from .. import backend, error 

10from ..common import vendors 

11 

# Vendors for which DeviceDetector reports `support_fp64 = False`.
UNSUPPORT_FP64 = [
    vendors.CAMBRICON,
    vendors.ILUVATAR,
    vendors.KUNLUNXIN,
    vendors.MTHREADS,
    vendors.AIPU,
    vendors.ASCEND,
    vendors.TSINGMICRO,
    vendors.SUNRISE,
    vendors.ENFLAME,
]

# Vendors for which DeviceDetector reports `support_bf16 = False`.
UNSUPPORT_BF16 = [vendors.AIPU, vendors.SUNRISE]

# Vendors for which DeviceDetector reports `support_int64 = False`.
UNSUPPORT_INT64 = [
    vendors.AIPU,
    vendors.TSINGMICRO,
    vendors.SUNRISE,
    vendors.ENFLAME,
]

33 

34 

35# A singleton class to manage device context. 

class DeviceDetector(object):
    """Singleton that detects the device vendor and records its properties.

    The first successful construction resolves the vendor and stores the
    vendor/device names, dispatch key, device count and dtype-support flags
    on the shared instance.  Later constructions return that same instance
    without re-running detection.
    """

    _instance = None

    def __new__(cls, *args, **kargs):
        # Arguments are accepted only so they can flow through to __init__;
        # a single shared instance is created on first use.
        if cls._instance is None:
            cls._instance = super(DeviceDetector, cls).__new__(cls)
        return cls._instance

    def __init__(self, vendor_name=None):
        """Populate the shared instance the first time it is constructed.

        Args:
            vendor_name: Optional vendor name (e.g. 'nvidia') forwarded to
                `get_vendor`.

        Raises:
            Whatever vendor detection or the backend helpers raise.  In that
            case the instance is left un-initialized so that a later
            construction retries instead of returning a half-built object.
        """
        if hasattr(self, "initialized"):
            return

        # The flag is set before detection starts, as it always was, so any
        # construction that happens while detection is running returns early.
        self.initialized = True
        try:
            # A view of all available vendor names.
            self.vendor_list = vendors.get_all_vendors().keys()

            # Vendor information object for the resolved vendor.
            self.info = self.get_vendor(vendor_name)

            # vendor_name is like 'nvidia', device_name is like 'cuda'.
            self.vendor_name = self.info.vendor_name
            self.name = self.info.device_name
            self.vendor = vendors.get_all_vendors()[self.vendor_name]

            # Fall back to the upper-cased device name when the vendor info
            # does not specify a dispatch key.
            if self.info.dispatch_key is None:
                self.dispatch_key = self.name.upper()
            else:
                self.dispatch_key = self.info.dispatch_key

            self.device_count = backend.gen_torch_device_object(
                self.vendor_name
            ).device_count()
            self.support_fp64 = self.vendor not in UNSUPPORT_FP64
            self.support_bf16 = self.vendor not in UNSUPPORT_BF16
            self.support_int64 = self.vendor not in UNSUPPORT_INT64
        except BaseException:
            # Roll the flag back: otherwise every later DeviceDetector() would
            # skip initialization and hand out an object missing attributes.
            self.__dict__.pop("initialized", None)
            raise

    def get_vendor(self, vendor_name=None):
        """Resolve the vendor information object.

        Sources are tried in this order:
          1. `vendor_name`, when given and present in `self.vendor_list`;
          2. the GEMS_VENDOR environment variable;
          3. a vendor-specific attribute on `torch` / `torch_npu`;
          4. `_get_vendor_from_lib` (currently always raises);
          5. running each vendor's device query command.

        Args:
            vendor_name: Optional explicit vendor name.  Names that are not
                in `self.vendor_list` are ignored and detection proceeds.

        Returns:
            The object returned by `backend.get_vendor_info` or by
            `_get_vendor_from_sys`.
        """
        if vendor_name is not None and vendor_name in self.vendor_list:
            return backend.get_vendor_info(vendor_name)

        vendor_from_env = self._get_vendor_from_env()
        if vendor_from_env is not None:
            return backend.get_vendor_info(vendor_from_env)

        # Try to get the vendor name from a quick special attribute like 'torch.mlu'.
        detected_name = self._get_vendor_from_quick_cmd()
        if detected_name is not None:
            return backend.get_vendor_info(detected_name)

        try:
            # Obtaining vendor info from torch or triton is reserved but not
            # implemented yet, so this currently always falls through.
            return self._get_vendor_from_lib()
        except Exception:
            return self._get_vendor_from_sys()

    def _get_vendor_from_quick_cmd(self):
        """Return a vendor name inferred from torch attributes, or None.

        A vendor matches when `torch` (or, failing that, `torch_npu`, if it
        can be imported) exposes the attribute mapped to it below.
        """
        cmd = {
            "cambricon": "mlu",
            "mthreads": "musa",
            "iluvatar": "corex",
            "ascend": "npu",
            "sunrise": "ptpu",
            "enflame": "gcu",
        }
        for vendor_name, flag in cmd.items():
            if hasattr(torch, flag):
                return vendor_name
        try:
            import torch_npu

            for vendor_name, flag in cmd.items():
                if hasattr(torch_npu, flag):
                    return vendor_name
        except Exception:
            # Best effort: torch_npu is optional and may be missing or fail
            # to load.  KeyboardInterrupt/SystemExit are left to propagate.
            pass
        return None

    def _get_vendor_from_env(self):
        """Return the GEMS_VENDOR value if it names a known vendor, else None."""
        device_from_env = os.environ.get("GEMS_VENDOR")
        if device_from_env in self.vendor_list:
            return device_from_env
        return None

    def _get_vendor_from_sys(self):
        """Detect the vendor by running every vendor's device query command.

        Each command runs in its own thread.  A vendor is a candidate when
        its command exits with return code 0.  If several commands succeed,
        the one whose result was queued first is returned.

        Returns:
            The vendor info whose query command succeeded.  When none does,
            `error.device_not_found()` is called and its outcome applies.
        """
        vendor_infos = backend.get_vendor_infos()
        result_single_info = Queue()

        def runcmd(single_info):
            device_query_cmd = single_info.device_query_cmd
            if not device_query_cmd:
                # Nothing to run.  This also keeps None away from shlex.split,
                # which reads stdin (and may block) on Python < 3.12.
                return
            try:
                cmd_args = shlex.split(device_query_cmd)
                result = subprocess.run(cmd_args, capture_output=True, text=True)
                if result.returncode == 0:
                    result_single_info.put(single_info)
            except Exception:
                # A missing or unrunnable tool just means "not this vendor".
                pass

        threads = []
        for single_info in vendor_infos:
            # Get the vendor information by running system commands.
            thread = threading.Thread(target=runcmd, args=(single_info,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        if result_single_info.empty():
            error.device_not_found()
        else:
            return result_single_info.get()

    def get_vendor_name(self):
        """Return the resolved vendor name (e.g. 'nvidia')."""
        return self.vendor_name

    def _get_vendor_from_lib(self):
        """Placeholder for vendor detection through triton or torch.

        Raises:
            RuntimeError: always, because neither library provides such an
                interface yet.
        """
        # Reserved interface, e.g.:
        #   try:
        #       return triton.get_vendor_info()
        #   except Exception:
        #       return torch.get_vendor_info()
        raise RuntimeError("The method is not implemented")