Coverage for src/flag_gems/utils/triton_lang_helper.py: 65%
26 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1from flag_gems.runtime import backend
2from flag_gems.runtime.backend.device import DeviceDetector
4"""
To stay compatible with different versions of the math libraries,
tl_extra_shim is bound to whichever math library module is available.
Note that the "triton.language.extra" module is only available in
Triton 2.2 and later versions.
9"""
# Instantiate the device detector once at import time and hand its
# vendor name to the backend setup calls below.
device = DeviceDetector()
backend.set_torch_backend_device_fn(device.vendor_name)

# Choose the module bound to tl_extra_shim: first the backend's tl_extra
# module; if importing it fails, fall back to modules on triton.language.
try:
    backend.set_tl_extra_backend_module(device.vendor_name)
    tl_extra_shim = backend.get_tl_extra_backend_module()
except ImportError:
    import triton

    try:
        tl_extra_shim = triton.language.math
    except (ImportError, AttributeError):
        # Plain attribute access on a package raises AttributeError (not
        # ImportError) when the submodule is absent, so both are caught to
        # keep the libdevice fallback reachable.
        tl_extra_shim = triton.language.libdevice
def use_backend(module):
    """Build a decorator that prefers ``module``'s version of a function.

    Args:
        module: Object (typically a module) that is searched for an
            attribute with the same name as the decorated function.

    Returns:
        A decorator that takes ``func`` and returns
        ``getattr(module, func.__name__)`` when that attribute exists,
        otherwise ``func`` itself, unchanged.
    """

    def decorator(func):
        # One lookup with a default replaces the hasattr/getattr pair and
        # the blanket ``except Exception: pass`` that used to guard it.
        return getattr(module, func.__name__, func)

    return decorator
def use_tl_extra(func):
    """Apply ``use_backend`` to ``func`` using ``tl_extra_shim`` as the module."""
    shim_decorator = use_backend(tl_extra_shim)
    return shim_decorator(func)