Coverage for src/flag_gems/runtime/backend/_tsingmicro/__init__.py: 18%

28 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from typing import Optional, Tuple, Union 

5 

6import torch 

7import torch_txda # noqa: F401 

8from backend_utils import VendorInfoBase # noqa: E402 

9 

10 

@dataclass
class TxdaDeviceProperties:
    """Static description of a Tsingmicro TXDA device.

    The field set and ``__repr__`` layout are modelled on ``torch.cuda``'s
    ``_CudaDeviceProperties``. Following that convention, ``total_memory`` and
    ``L2_cache_size`` hold sizes in **bytes**. ``__repr__`` renders them in
    whole megabytes, the same way torch's repr does.
    """

    name: str
    major: int
    minor: int
    total_memory: int  # bytes
    multi_processor_count: int
    uuid: str
    L2_cache_size: int  # bytes

    def __repr__(self) -> str:
        # Sizes are stored in bytes but displayed in MB (integer division),
        # matching the torch repr this format was copied from.
        bytes_per_mb = 1024 * 1024
        return (
            f"TxdaDeviceProperties(name='{self.name}', major={self.major}, "
            f"minor={self.minor}, total_memory={self.total_memory // bytes_per_mb}MB, "
            f"multi_processor_count={self.multi_processor_count}, "
            f"uuid={self.uuid}, L2_cache_size={self.L2_cache_size // bytes_per_mb}MB)"
        )

28 

29 

def get_device_properties(
    device: torch.device | str | int | None = None,
) -> TxdaDeviceProperties:
    """Return hard-coded device properties for a Tsingmicro TX81 device.

    Args:
        device: Accepted for signature compatibility with
            ``torch.cuda.get_device_properties``; it is ignored, and every
            device reports the same values.

    Returns:
        A ``TxdaDeviceProperties`` whose ``total_memory`` (64 GiB) and
        ``L2_cache_size`` (3 MiB) are expressed in bytes.
    """
    return TxdaDeviceProperties(
        name="TX81",
        major=8,
        minor=1,
        # ``**`` is exponentiation; ``^`` would be bitwise XOR in Python.
        total_memory=64 * 1024**3,  # 64 GiB, in bytes
        multi_processor_count=16,
        uuid="",
        L2_cache_size=3 * 1024**2,  # 3 MiB, in bytes
    )

42 

43 

def get_device_capability(
    device: Optional[Union[torch.device, str, int]] = None
) -> Tuple[int, int]:
    """Report the ``(major, minor)`` compute capability for TXDA devices.

    ``device`` exists only so the signature matches
    ``torch.cuda.get_device_capability``; it is ignored, and the same pair is
    returned for every device.

    NOTE(review): this returns (8, 0), while ``get_device_properties`` reports
    major=8, minor=1. Confirm whether the mismatch is intentional (e.g. to hit
    sm80-style code paths).
    """
    major, minor = 8, 0
    return major, minor

48 

49 

# Attach the fallbacks above to ``torch.txda`` (presumably registered as a
# side effect of ``import torch_txda`` — confirm). An existing attribute is
# never overwritten, so any native torch_txda implementation takes precedence.
if not hasattr(torch.txda, "get_device_properties"):
    torch.txda.get_device_properties = get_device_properties

if not hasattr(torch.txda, "get_device_capability"):
    torch.txda.get_device_capability = get_device_capability

55 

# Vendor descriptor for this backend. It is presumably read by FlagGems'
# backend loader (VendorInfoBase comes from backend_utils) — confirm there.
vendor_info = VendorInfoBase(
    vendor_name="tsingmicro",
    device_name="txda",
    device_query_cmd="tsm_smi",
    # torch_txda appears to plug into PyTorch as the out-of-tree
    # PrivateUse1 dispatch key — verify against torch_txda.
    dispatch_key="PrivateUse1",
)

# Empty: this backend declares no customized-unused operators.
CUSTOMIZED_UNUSED_OPS = ()

# ``__all__`` must contain real attribute names. The literal "*" is not an
# attribute of this module, so ``__all__ = ["*"]`` made
# ``from <this module> import *`` raise AttributeError.
__all__ = [
    "TxdaDeviceProperties",
    "get_device_properties",
    "get_device_capability",
    "vendor_info",
    "CUSTOMIZED_UNUSED_OPS",
]