Coverage for src/flag_gems/utils/device_info.py: 69%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import warnings
2from dataclasses import dataclass
3from functools import lru_cache
5from flag_gems.runtime import torch_device_fn
@dataclass(eq=True, frozen=True)
class DeviceInfo:
    """Immutable snapshot of basic hardware parameters for one device.

    Because the dataclass is frozen and compares by value, instances cannot be
    mutated after construction and are hashable.
    """

    # Index of the device these values describe.
    device_id: int
    # L2 cache size of the device; presumably in bytes — TODO confirm per backend.
    l2_cache_size: int
    # Multiprocessor (SM / compute unit) count of the device.
    sm_count: int
@lru_cache(maxsize=1)
def get_device_id() -> int:
    """Return the index of the currently selected device.

    The result is memoized by ``lru_cache``, so whatever device is current on
    the first call is reported for every later call (until ``cache_clear()``).
    If the runtime query raises, a warning is emitted and device 0 is assumed.
    """
    try:
        current = torch_device_fn.current_device()
    except Exception:
        # Best-effort: never let a missing/odd backend break callers.
        warnings.warn(
            "[device_info] Failed to get current device, fallback to device_id=0."
        )
        current = 0
    return current
@lru_cache(maxsize=1)
def get_device_properties():
    """Return the runtime's device-properties object for the current device.

    The device index comes from ``get_device_id()``. The result is memoized by
    ``lru_cache``. If the runtime query raises, a warning is emitted and
    ``None`` is returned; callers must handle that case.
    """
    idx = get_device_id()
    try:
        props = torch_device_fn.get_device_properties(idx)
    except Exception:
        # Best-effort: callers fall back to defaults when this is None.
        warnings.warn(
            f"[device_info] Failed to get device properties for device_id={idx}, fallback to None."
        )
        props = None
    return props
@lru_cache(maxsize=1)
def get_device_capability() -> tuple[int, int]:
    """Return the (major, minor) capability of the current device.

    The device index comes from ``get_device_id()``; the result is memoized by
    ``lru_cache``.

    Returns:
        Whatever ``torch_device_fn.get_device_capability`` returns, or ``(0, 0)``
        (with a warning) when that call raises or returns ``None``.
    """
    device_id = get_device_id()
    try:
        result = torch_device_fn.get_device_capability(device_id)
    except Exception:
        warnings.warn(
            f"[device_info] Failed to get device capability for device_id={device_id} "
            "using torch_device_fn, fallback to (0, 0)."
        )
        return (0, 0)
    # Handled outside the ``try`` on purpose: if warnings are escalated to
    # errors (e.g. ``-W error``), this warning must surface as-is instead of
    # being caught above and misreported as a failed capability query.
    if result is None:
        warnings.warn(
            "[device_info] torch_device_fn.get_device_capability returned None "
            f"for device_id={device_id}, fallback to (0, 0)."
        )
        return (0, 0)
    return result
@lru_cache(maxsize=1)
def get_device_info() -> DeviceInfo:
    """Collect and memoize basic hardware parameters of the current device.

    Reads the L2 cache size and multiprocessor count from the object returned
    by ``get_device_properties()``, accepting either attribute spelling
    (``L2_cache_size`` / ``l2_cache_size`` and
    ``multi_processor_count`` / ``multiProcessorCount``). Any value that cannot
    be obtained is replaced, with a warning, by an A100 default: an L2 cache
    size of 40 * 1024 * 1024 and an SM count of 108.

    Returns:
        DeviceInfo for the device index reported by ``get_device_id()``.
    """
    props = get_device_properties()
    l2_cache_size = None
    sm_count = None
    if props is not None:
        # Backends spell this attribute differently. Keep the first spelling that
        # is actually populated, so an attribute that exists but holds None does
        # not hide the other alias.
        for attr in ("L2_cache_size", "l2_cache_size"):
            l2_cache_size = getattr(props, attr, None)
            if l2_cache_size is not None:
                break
        # `or` moves on to the second spelling when the first is missing, None, or 0.
        sm_count = getattr(props, "multi_processor_count", None) or getattr(
            props, "multiProcessorCount", None
        )
    if l2_cache_size is None:
        warnings.warn(
            "[device_info] Failed to get l2_cache_size, fallback to 40MB (A100 default)."
        )
        # default L2 cache size to 40MB for A100
        l2_cache_size = 40 * 1024 * 1024
    if sm_count is None:
        warnings.warn(
            "[device_info] Failed to get sm_count, fallback to 108 (A100 default)."
        )
        # default sm_count to 108 for A100
        sm_count = 108
    return DeviceInfo(
        device_id=get_device_id(),
        l2_cache_size=l2_cache_size,
        sm_count=sm_count,
    )
def get_l2_cache_size() -> int:
    """Return the L2 cache size recorded in ``get_device_info()``.

    This may be the A100 fallback value if the device did not report one.
    """
    info = get_device_info()
    return info.l2_cache_size
def get_sm_count() -> int:
    """Return the multiprocessor count recorded in ``get_device_info()``.

    This may be the A100 fallback value if the device did not report one.
    """
    info = get_device_info()
    return info.sm_count