Coverage for src/flag_gems/fused/FLA/utils.py: 81%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1# This file contains code copied from the flash-linear-attention project. 

2# The original source code was licensed under the MIT license and included 

3# the following copyright notice: 

4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

5# ruff: noqa: E501 

6 

7import contextlib 

8import functools 

9import os 

10from collections.abc import Callable 

11from typing import Any 

12 

13import torch 

14import triton 

15 

16from flag_gems import runtime 

17from flag_gems.utils.device_info import get_device_capability 

18 

# environment settings
# NOTE(review): these flags are read once at import time; their consumers live
# in the FLA kernels elsewhere, so the notes below are inferred from the
# variable names — confirm against the code that reads them.
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))  # presumably a GDN recompute-suppression level; 0 = default/off
FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"  # presumably pins the BT block size in GDN kernels

# Opt-in flag: presumably enables CUDA-graph capture paths when set to "1".
use_cuda_graph = os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"

24 

25 

def _detect_nvidia_hopper() -> bool:
    """Return True if current device is NVIDIA and SM major version >= 9.

    We rely on `runtime.device.vendor_name` and `get_device_capability()` which
    already handle errors and fallbacks elsewhere.
    """
    # A missing vendor_name attribute is treated as "not NVIDIA".
    vendor = getattr(runtime.device, "vendor_name", "")
    if "nvidia" not in vendor.lower():
        return False
    sm_major, _sm_minor = get_device_capability()
    # Hopper starts at compute capability 9.x.
    return sm_major >= 9

37 

38 

# Evaluated once at import time: True only on NVIDIA devices with SM major >= 9.
is_nvidia_hopper = _detect_nvidia_hopper()

# TMA tensor descriptors require Hopper-class hardware plus a Triton build that
# exposes a tensor-descriptor constructor; the attribute was renamed across
# Triton versions, so both spellings are probed.
is_tma_supported = is_nvidia_hopper and (
    hasattr(triton.language, "_experimental_make_tensor_descriptor")
    or hasattr(triton.language, "make_tensor_descriptor")
)

45 

46 

def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.

    This decorator stores the outputs of the decorated function for the most
    recent sets of input tensors. The cache holds at most ``cache_size`` (8)
    entries; when it is full, the least recently used entry is evicted.
    Inputs are matched by object identity (``is``), not by value equality, so
    only calls that reuse the exact same argument objects hit the cache.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with LRU caching.
    """

    # Each entry is (args, kwargs, result); ordered least- to most-recently used.
    cache_entries: list[tuple[tuple, dict, Any]] = []
    cache_size = 8

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries
        for i, entry in enumerate(cache_entries):
            last_args, last_kwargs, last_result = entry
            if (
                len(args) == len(last_args)
                and len(kwargs) == len(last_kwargs)
                and all(a is b for a, b in zip(args, last_args))
                and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                )
            ):
                # Cache hit: move this entry to the most-recently-used slot.
                cache_entries = (
                    cache_entries[:i]
                    + cache_entries[i + 1 :]
                    + [(args, kwargs, last_result)]
                )
                return last_result

        result = fn(*args, **kwargs)

        # Cache miss: evict the least recently used entry once full.
        if len(cache_entries) >= cache_size:
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper

94 

95 

def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Normalize every tensor argument to a contiguous layout; non-tensors
        # pass through untouched.
        guarded_args = tuple(
            a.contiguous() if isinstance(a, torch.Tensor) else a for a in args
        )
        guarded_kwargs = {
            k: (v.contiguous() if isinstance(v, torch.Tensor) else v)
            for k, v in kwargs.items()
        }

        # Locate the first tensor argument (positional first, then keyword) to
        # decide which device context the call should run under.
        anchor = next((a for a in args if isinstance(a, torch.Tensor)), None)
        if anchor is None:
            anchor = next(
                (v for v in kwargs.values() if isinstance(v, torch.Tensor)), None
            )

        if anchor is None:
            # No tensor inputs: nothing to pin the device to.
            device_ctx = contextlib.nullcontext()
        else:
            device_ctx = runtime.torch_device_fn.device(anchor.device)

        with device_ctx:
            return fn(*guarded_args, **guarded_kwargs)

    return wrapper

131 

132 

def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    """Heuristically check whether the device has ample shared memory per SM.

    ``arch`` and ``tensor_idx`` are accepted for API compatibility but are not
    consulted here; the decision is based solely on the queried device
    properties. Returns False whenever the properties cannot be determined.
    """
    from flag_gems.utils.device_info import get_device_properties

    props = get_device_properties()
    if props is None:
        return False

    # The property name differs across torch versions/drivers; probe the
    # known spellings in order of preference.
    for attr_name in ("max_shared_memory_per_multiprocessor", "max_shared_memory"):
        shared_mem = getattr(props, attr_name, None)
        if shared_mem is not None:
            # Use the AMPERE threshold used in the original project as heuristic.
            return shared_mem >= 166_000

    # Unknown property layout: conservatively report insufficient memory.
    return False