Coverage for src/flag_gems/fused/FLA/triton_ops_helper.py: 60%

15 statements  


# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501

import os

import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice


def get_exp():
    """Return exp implementation (fast or accurate) based on env flag."""
    return (
        tldevice.fast_expf if os.environ.get("FLA_USE_FAST_OPS", "0") == "1" else tl.exp
    )


# Default exported exp to be imported by kernels.
exp = get_exp()
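
# Example (illustrative): the flag is read once, at import time, so it must be
# set before this module is first imported:
#
#   FLA_USE_FAST_OPS=1 python run.py   -> exp is tldevice.fast_expf (fast, less accurate)
#   python run.py                      -> exp is tl.exp (the accurate default)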

if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
    # For Triton 3.3.x
    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
    # For Triton 3.4.x and later
    make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
    # Fallback for Triton builds without TMA support: returns None to signal
    # that TMA descriptors are unavailable, while keeping the Triton compiler
    # happy.

    @triton.jit
    def make_tensor_descriptor(
        base,
        shape,
        strides,
        block_shape,
        _builder=None,
    ):
        return None
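
For context, here is a minimal sketch of how a kernel consumes the exported `exp`. It assumes the module is importable as `flag_gems.fused.FLA.triton_ops_helper`; the kernel, shapes, and launch grid are hypothetical, not part of this file. Because `exp` is bound at import time, the Triton JIT inlines whichever implementation `FLA_USE_FAST_OPS` selected.

import torch
import triton
import triton.language as tl

from flag_gems.fused.FLA.triton_ops_helper import exp


@triton.jit
def exp_kernel(x_ptr, y_ptr, N, BLOCK: tl.constexpr):
    # exp resolves to tl.exp or tldevice.fast_expf, per FLA_USE_FAST_OPS.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, exp(x), mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
exp_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK=256)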