Coverage for src/flag_gems/fused/FLA/utils.py: 81%
68 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1# This file contains code copied from the flash-linear-attention project.
2# The original source code was licensed under the MIT license and included
3# the following copyright notice:
4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
5# ruff: noqa: E501
7import contextlib
8import functools
9import os
10from collections.abc import Callable
11from typing import Any
13import torch
14import triton
16from flag_gems import runtime
17from flag_gems.utils.device_info import get_device_capability
# Environment settings (read once at import time).
# Suppression level for GDN recompute; non-zero values reduce recomputation.
SUPPRESS_LEVEL = int(os.environ.get("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
# When "1", fixes the BT block size in the GDN kernels.
FLA_GDN_FIX_BT = os.environ.get("FLA_GDN_FIX_BT", "0") == "1"

# When "1", enables CUDA-graph capture in the FLA kernels.
use_cuda_graph = os.getenv("FLA_USE_CUDA_GRAPH", "0") == "1"
def _detect_nvidia_hopper() -> bool:
    """Detect whether the active device is an NVIDIA GPU of Hopper class.

    Hopper (and newer) architectures report an SM major version of 9 or above.
    Vendor detection goes through `runtime.device.vendor_name` and the
    capability query through `get_device_capability()`; both already handle
    errors and fallbacks elsewhere.
    """
    vendor = getattr(runtime.device, "vendor_name", "")
    if "nvidia" not in vendor.lower():
        return False
    sm_major = get_device_capability()[0]
    return sm_major >= 9
# Device feature flags, computed once at import time.
is_nvidia_hopper = _detect_nvidia_hopper()

# TMA requires a Hopper-class GPU plus a Triton build exposing a tensor
# descriptor API (the attribute name varies across Triton versions).
is_tma_supported = is_nvidia_hopper and any(
    hasattr(triton.language, name)
    for name in ("_experimental_make_tensor_descriptor", "make_tensor_descriptor")
)
def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.

    This decorator stores the outputs of the decorated function for the most
    recent sets of input tensors, comparing every positional and keyword
    argument by identity (``is``), which is cheap and avoids tensor-value
    comparisons. The cache holds up to 8 entries, evicted in LRU order:
    a cache hit moves its entry to the most-recently-used slot, and when the
    cache is full the least recently used entry is dropped.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with LRU caching.
    """

    # Each entry is (args, kwargs, result); index 0 is least recently used.
    cache_entries: list[tuple[tuple, dict, Any]] = []
    cache_size = 8

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries
        for i, (last_args, last_kwargs, last_result) in enumerate(cache_entries):
            if (
                len(args) == len(last_args)
                and len(kwargs) == len(last_kwargs)
                and all(a is b for a, b in zip(args, last_args))
                and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                )
            ):
                # Cache hit: move the entry to the most-recently-used position.
                cache_entries = (
                    cache_entries[:i]
                    + cache_entries[i + 1 :]
                    + [(args, kwargs, last_result)]
                )
                return last_result

        result = fn(*args, **kwargs)

        if len(cache_entries) >= cache_size:
            # Evict the least recently used entry.
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper
96def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
97 """
98 A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
99 """
101 @functools.wraps(fn)
102 def wrapper(*args, **kwargs):
103 contiguous_args = (
104 i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
105 )
106 contiguous_kwargs = {
107 k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
108 for k, v in kwargs.items()
109 }
111 tensor = None
112 for arg in args:
113 if isinstance(arg, torch.Tensor):
114 tensor = arg
115 break
116 if tensor is None:
117 for value in kwargs.values():
118 if isinstance(value, torch.Tensor):
119 tensor = value
120 break
122 if tensor is not None:
123 ctx = runtime.torch_device_fn.device(tensor.device)
124 else:
125 ctx = contextlib.nullcontext()
127 with ctx:
128 return fn(*contiguous_args, **contiguous_kwargs)
130 return wrapper
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    """Heuristically check whether the device has Ampere-class shared memory.

    Returns True when the reported per-SM shared memory is at least 166 KB,
    the AMPERE threshold used by the original project. The `arch` and
    `tensor_idx` parameters are accepted for API compatibility but unused.
    """
    from flag_gems.utils.device_info import get_device_properties

    props = get_device_properties()
    if props is None:
        return False

    # Property names differ across torch versions/drivers; try common ones.
    for attr in ("max_shared_memory_per_multiprocessor", "max_shared_memory"):
        shared_mem = getattr(props, attr, None)
        if shared_mem is not None:
            return shared_mem >= 166_000

    # No known property found; fall back to a conservative default.
    return False