Coverage for src/flag_gems/utils/device_info.py: 69%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import warnings 

2from dataclasses import dataclass 

3from functools import lru_cache 

4 

5from flag_gems.runtime import torch_device_fn 

6 

7 

@dataclass(frozen=True)
class DeviceInfo:
    """Immutable snapshot of a few hardware parameters of one device.

    Being a frozen dataclass, instances compare by value, are hashable, and
    reject attribute assignment after construction.
    """

    # Index of the device these values were read from.
    device_id: int
    # Size of the device's L2 cache (this module's fallback value is in bytes).
    l2_cache_size: int
    # Number of streaming multiprocessors reported for the device.
    sm_count: int

13 

14 

@lru_cache(maxsize=1)
def get_device_id() -> int:
    """Return the index of the current device.

    The backend is queried once via ``torch_device_fn.current_device()`` and
    the result is cached, so later device switches are not reflected here.
    If the query raises, a warning is emitted and device 0 is assumed.
    """
    try:
        current = torch_device_fn.current_device()
    except Exception:
        warnings.warn(
            "[device_info] Failed to get current device, fallback to device_id=0."
        )
        current = 0
    return current

24 

25 

@lru_cache(maxsize=1)
def get_device_properties():
    """Return the backend's property object for the current device, or None.

    Queried once through ``torch_device_fn.get_device_properties`` and cached.
    When the backend call raises, a warning is emitted and None is returned
    (and cached) instead of propagating the error.
    """
    device_id = get_device_id()
    try:
        props = torch_device_fn.get_device_properties(device_id)
    except Exception:
        warnings.warn(
            f"[device_info] Failed to get device properties for device_id={device_id}, fallback to None."
        )
        props = None
    return props

36 

37 

@lru_cache(maxsize=1)
def get_device_capability() -> tuple[int, int]:
    """Return the capability of the current device as a 2-tuple of ints.

    The pair is presumably ``(major, minor)`` as reported by the backend —
    TODO confirm against ``torch_device_fn``. The value is queried once via
    ``torch_device_fn.get_device_capability`` and cached. ``(0, 0)`` is
    returned, with a warning, when the query raises or returns None.
    """
    device_id = get_device_id()
    # Only the backend call lives inside ``try``: if warnings are escalated to
    # errors (e.g. ``python -W error``), the "returned None" warning below must
    # not be swallowed and misreported as a failed capability query.
    try:
        result = torch_device_fn.get_device_capability(device_id)
    except Exception:
        warnings.warn(
            f"[device_info] Failed to get device capability for device_id={device_id} "
            f"using torch_device_fn, fallback to (0, 0)."
        )
        return (0, 0)
    if result is None:
        warnings.warn(
            f"[device_info] torch_device_fn.get_device_capability returned None "
            f"for device_id={device_id}, fallback to (0, 0)."
        )
        return (0, 0)
    return result

56 

57 

# Fallbacks used when the backend does not report a value; they correspond to
# an NVIDIA A100 (40 MB L2 cache, 108 SMs), as the warning messages state.
_DEFAULT_L2_CACHE_SIZE = 40 * 1024 * 1024
_DEFAULT_SM_COUNT = 108


def _read_l2_cache_size(props):
    """Return the L2 cache size exposed by ``props``, or None if absent.

    ``L2_cache_size`` takes priority over ``l2_cache_size``; whichever attribute
    exists first is returned as-is, even if its value is None or 0.
    """
    if props is None:
        return None
    for attr in ("L2_cache_size", "l2_cache_size"):
        if hasattr(props, attr):
            return getattr(props, attr)
    return None


def _read_sm_count(props):
    """Return the multiprocessor count exposed by ``props``, or None.

    A falsy ``multi_processor_count`` (missing, None or 0) falls through to
    ``multiProcessorCount``.
    """
    if props is None:
        return None
    return getattr(props, "multi_processor_count", None) or getattr(
        props, "multiProcessorCount", None
    )


@lru_cache(maxsize=1)
def get_device_info() -> DeviceInfo:
    """Collect and cache the L2 cache size and SM count of the current device.

    Values are read from ``get_device_properties()`` under the alternative
    attribute spellings probed by the helpers above. Any value that cannot be
    read is replaced by its A100 default and a warning is emitted. The result
    is computed once and cached.
    """
    props = get_device_properties()
    l2_cache_size = _read_l2_cache_size(props)
    sm_count = _read_sm_count(props)

    if l2_cache_size is None:
        warnings.warn(
            "[device_info] Failed to get l2_cache_size, fallback to 40MB (A100 default)."
        )
        l2_cache_size = _DEFAULT_L2_CACHE_SIZE
    if sm_count is None:
        warnings.warn(
            "[device_info] Failed to get sm_count, fallback to 108 (A100 default)."
        )
        sm_count = _DEFAULT_SM_COUNT

    return DeviceInfo(
        device_id=get_device_id(),
        l2_cache_size=l2_cache_size,
        sm_count=sm_count,
    )

89 

90 

def get_l2_cache_size() -> int:
    """Return the cached L2 cache size of the current device.

    Delegates to ``get_device_info()``; units follow what the backend reports
    (the module's fallback value is expressed in bytes).
    """
    info = get_device_info()
    return info.l2_cache_size

93 

94 

def get_sm_count() -> int:
    """Return the cached multiprocessor (SM) count of the current device.

    Delegates to ``get_device_info()``, which substitutes a default when the
    backend does not report the value.
    """
    info = get_device_info()
    return info.sm_count