# src/flag_gems/runtime/__init__.py

from . import backend, common, error
from .backend.device import DeviceDetector
from .configloader import ConfigLoader

config_loader = ConfigLoader()
device = DeviceDetector()

"""
The dependency order of these sub-directories is strict; changing the order arbitrarily may cause errors.
"""

# torch_device_fn is like the 'torch.cuda' module object
backend.set_torch_backend_device_fn(device.vendor_name)
torch_device_fn = backend.gen_torch_device_object()

# torch_backend_device is like the 'torch.backends.cuda' module object
torch_backend_device = backend.get_torch_backend_device_fn()
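
# Illustrative note (assumption, not executed as part of this module): on an
# NVIDIA system the two handles above would typically resolve to torch.cuda and
# torch.backends.cuda, so vendor-neutral code elsewhere can write, e.g.,
#     torch_device_fn.synchronize()
#     props = torch_device_fn.get_device_properties(0)
# without hard-coding the torch.cuda namespace.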


def get_tuned_config(op_name):
    """Return the tuned configuration registered for ``op_name``."""
    return config_loader.get_tuned_config(op_name)


def get_heuristic_config(op_name):
    """Return the heuristics configuration registered for ``op_name``."""
    return config_loader.get_heuristics_config(op_name)
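
# Sketch of how these getters are typically consumed (hypothetical op name
# "add"; the real call sites live in the operator modules, not here):
#     configs = get_tuned_config("add")
#     heuristics = get_heuristic_config("add")
# The returned entries are typically fed to kernel tuning decorators
# (e.g., triton.autotune / triton.heuristics).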


def replace_customized_ops(_globals):
    """Replace entries in ``_globals`` with vendor- or arch-specific operator implementations."""
    event = backend.BackendArchEvent()
    arch_specialization_operators = event.get_arch_ops() if event.has_arch else None
    backend_customization_operators = backend.get_current_device_extend_op(
        device.vendor_name
    )
    if device.vendor != common.vendors.NVIDIA:
        try:
            for fn_name, fn in backend_customization_operators:
                _globals[fn_name] = fn
        except RuntimeError as e:
            error.customized_op_replace_error(e)
    if arch_specialization_operators:
        try:
            for fn_name, fn in arch_specialization_operators:
                _globals[fn_name] = fn
        except RuntimeError as e:
            error.customized_op_replace_error(e)
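
# Sketch of the intended call pattern (an assumption about the call site, based
# on the ``_globals`` parameter): the package that exports the default operator
# set passes its own namespace so vendor- or arch-specific implementations can
# shadow the defaults, e.g.
#     replace_customized_ops(globals())
# invoked once at import time.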


__all__ = [
    "backend", "common", "error", "config_loader", "device",
    "torch_device_fn", "torch_backend_device",
    "get_tuned_config", "get_heuristic_config", "replace_customized_ops",
]