Coverage for src/flag_gems/utils/triton_lang_helper.py: 65%

26 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1from flag_gems.runtime import backend 

2from flag_gems.runtime.backend.device import DeviceDetector 

3 

4""" 

5 To stay compatible with different versions of the math libraries,

6 tl_extra_shim is bound to one specific library when this module is imported.

7 And the "triton.language.extra" module is only available in 

8 Triton 2.2 and later versions. 

9""" 

10 

# Module-level setup: runs once on import and binds ``tl_extra_shim``.
device = DeviceDetector()
# Pass the detected vendor name to the backend before asking it for a module.
backend.set_torch_backend_device_fn(device.vendor_name)
try:
    # Preferred source: the backend's own "tl extra" module for this vendor.
    backend.set_tl_extra_backend_module(device.vendor_name)
    tl_extra_shim = backend.get_tl_extra_backend_module()
except ImportError:
    import triton

    # Fall back to a math library shipped with Triton itself. Reading a
    # missing attribute off a module raises AttributeError rather than
    # ImportError, so both are caught; otherwise the ``libdevice`` fallback
    # could never be reached.
    try:
        tl_extra_shim = triton.language.math
    except (ImportError, AttributeError):
        tl_extra_shim = triton.language.libdevice

23 

24 

def use_backend(module):
    """Build a decorator that prefers ``module``'s implementation of a function.

    Args:
        module: Object searched for an attribute named like the decorated
            function (looked up via ``func.__name__``).

    Returns:
        A decorator. Applied to ``func``, it returns the attribute of
        ``module`` whose name equals ``func.__name__`` when that attribute
        exists, and ``func`` itself otherwise.
    """

    def decorator(func):
        # One lookup with a default replaces the former hasattr + getattr
        # pair and its blanket ``except Exception: pass``; a missing
        # attribute still falls back to the decorated function.
        return getattr(module, func.__name__, func)

    return decorator

38 

39 

def use_tl_extra(func):
    """Decorate ``func`` with ``use_backend`` bound to ``tl_extra_shim``.

    Returns whatever ``use_backend(tl_extra_shim)`` yields for ``func``: the
    shim's same-named attribute when it has one, otherwise ``func`` itself.
    """
    shim_decorator = use_backend(tl_extra_shim)
    return shim_decorator(func)