Coverage for src/flag_gems/runtime/backend/device.py: 90%
83 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import os
2import shlex
3import subprocess
4import threading
5from queue import Queue
7import torch # noqa: F401
9from .. import backend, error
10from ..common import vendors
# Capability tables: a vendor listed here is reported by DeviceDetector as
# NOT supporting the corresponding dtype (see the support_* flags).

# Vendors without float64 support.
UNSUPPORT_FP64 = [
    vendors.CAMBRICON,
    vendors.ILUVATAR,
    vendors.KUNLUNXIN,
    vendors.MTHREADS,
    vendors.AIPU,
    vendors.ASCEND,
    vendors.TSINGMICRO,
    vendors.SUNRISE,
    vendors.ENFLAME,
]

# Vendors without bfloat16 support.
UNSUPPORT_BF16 = [
    vendors.AIPU,
    vendors.SUNRISE,
]

# Vendors without int64 support.
UNSUPPORT_INT64 = [
    vendors.AIPU,
    vendors.TSINGMICRO,
    vendors.SUNRISE,
    vendors.ENFLAME,
]
# A singleton class to manage device context.
class DeviceDetector(object):
    """Process-wide singleton that detects the accelerator vendor and device.

    Detection runs on the first successful construction only. Later
    ``DeviceDetector(...)`` calls return the same instance and ignore their
    arguments. If detection raises, the instance is rolled back to the
    un-initialised state so that a later construction retries detection
    instead of returning an object missing ``info``/``vendor_name``/etc.
    """

    _instance = None

    # Vendor name -> attribute whose presence on ``torch`` (or ``torch_npu``)
    # identifies that vendor's PyTorch extension.
    _QUICK_CMD_FLAGS = {
        "cambricon": "mlu",
        "mthreads": "musa",
        "iluvatar": "corex",
        "ascend": "npu",
        "sunrise": "ptpu",
        "enflame": "gcu",
    }

    def __new__(cls, *args, **kwargs):
        # Always hand back the single shared instance; __init__ guards
        # against re-running detection on it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, vendor_name=None):
        """Detect the device on first construction.

        Args:
            vendor_name: optional vendor name (e.g. ``'nvidia'``) to use when
                the ``GEMS_VENDOR`` environment variable is not set. Ignored
                if it is not a known vendor, and ignored on every
                construction after the first successful one.
        """
        if hasattr(self, "initialized"):
            return
        # The flag is set before detection (as before) so that any
        # re-entrant construction during detection does not restart it.
        self.initialized = True
        try:
            self._detect(vendor_name)
        except BaseException:
            # Roll back so the next construction retries detection rather
            # than returning a partially populated singleton.
            del self.initialized
            raise

    def _detect(self, vendor_name):
        """Populate every device attribute on ``self``."""
        all_vendors = vendors.get_all_vendors()
        # Names of every known vendor; used to validate name overrides.
        self.vendor_list = all_vendors.keys()

        # Vendor info object (a dataclass) for the resolved vendor.
        self.info = self.get_vendor(vendor_name)

        # vendor_name is like 'nvidia', device_name is like 'cuda'.
        self.vendor_name = self.info.vendor_name
        self.name = self.info.device_name
        self.vendor = all_vendors[self.vendor_name]
        # Fall back to the upper-cased device name when the vendor does not
        # declare an explicit dispatch key.
        self.dispatch_key = (
            self.name.upper()
            if self.info.dispatch_key is None
            else self.info.dispatch_key
        )
        self.device_count = backend.gen_torch_device_object(
            self.vendor_name
        ).device_count()
        self.support_fp64 = self.vendor not in UNSUPPORT_FP64
        self.support_bf16 = self.vendor not in UNSUPPORT_BF16
        self.support_int64 = self.vendor not in UNSUPPORT_INT64

    def get_vendor(self, vendor_name=None):
        """Resolve the vendor info object for this machine.

        Resolution order (first hit wins):
          1. ``GEMS_VENDOR`` environment variable, if it names a known vendor.
          2. The ``vendor_name`` argument, if it names a known vendor.
          3. Vendor-specific attributes on ``torch`` / ``torch_npu``.
          4. ``_get_vendor_from_lib`` (not implemented; always falls through).
          5. Per-vendor system query commands (``_get_vendor_from_sys``).

        Requires ``self.vendor_list`` to be set.
        """
        name = self._get_vendor_from_env()
        if name is None:
            name = self._get_vendor_from_arg(vendor_name)
        if name is None:
            name = self._get_vendor_from_quick_cmd()
        if name is not None:
            return backend.get_vendor_info(name)
        try:
            # Reserved for a vendor query provided by torch or triton.
            return self._get_vendor_from_lib()
        except Exception:
            return self._get_vendor_from_sys()

    def _get_vendor_from_arg(self, vendor_name):
        """Return ``vendor_name`` if it is a known vendor name, else None."""
        return vendor_name if vendor_name in self.vendor_list else None

    def _get_vendor_from_quick_cmd(self):
        """Guess the vendor from attributes exposed by torch / torch_npu.

        Returns the vendor name, or None if no known attribute is found.
        """
        for name, flag in self._QUICK_CMD_FLAGS.items():
            if hasattr(torch, flag):
                return name
        try:
            import torch_npu
        except Exception:
            # torch_npu is optional; failing to import it only means this
            # probe cannot identify a vendor.
            return None
        for name, flag in self._QUICK_CMD_FLAGS.items():
            if hasattr(torch_npu, flag):
                return name
        return None

    def _get_vendor_from_env(self):
        """Return ``$GEMS_VENDOR`` if it names a known vendor, else None."""
        vendor_from_env = os.environ.get("GEMS_VENDOR")
        return vendor_from_env if vendor_from_env in self.vendor_list else None

    @staticmethod
    def _probe_vendor_infos(vendor_infos):
        """Run every info's ``device_query_cmd`` concurrently.

        Returns the first info, in the iteration order of ``vendor_infos``,
        whose command exits with status 0, or None if none do. Commands that
        cannot be parsed or launched count as failures.
        """
        infos = list(vendor_infos)
        hits = Queue()

        def run_query(index, info):
            try:
                completed = subprocess.run(
                    shlex.split(info.device_query_cmd),
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    check=False,
                )
            except Exception:
                # Missing binary, bad quoting, etc.: vendor not present.
                return
            if completed.returncode == 0:
                hits.put(index)

        threads = [
            threading.Thread(target=run_query, args=(index, info))
            for index, info in enumerate(infos)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # All producers have finished, so draining the queue is race-free.
        matched = []
        while not hits.empty():
            matched.append(hits.get())
        # Pick by input order, not completion order, so the result is
        # deterministic when several vendors' commands succeed.
        return infos[min(matched)] if matched else None

    def _get_vendor_from_sys(self):
        """Detect the vendor by running each vendor's system query command."""
        info = self._probe_vendor_infos(backend.get_vendor_infos())
        if info is None:
            error.device_not_found()
        return info

    def get_vendor_name(self):
        """Return the detected vendor name (e.g. ``'nvidia'``)."""
        return self.vendor_name

    def _get_vendor_from_lib(self):
        """Placeholder for a vendor query offered by triton or torch.

        Neither library provides one yet, so this always raises and
        ``get_vendor`` falls back to the system probe.
        """
        raise RuntimeError("The method is not implemented")