Coverage for src/flag_gems/fused/FLA/triton_ops_helper.py: 60% (15 statements)
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501

import os

import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice
def get_exp():
    """Return the exp implementation (fast or accurate) based on the FLA_USE_FAST_OPS env flag."""
    return (
        tldevice.fast_expf if os.environ.get("FLA_USE_FAST_OPS", "0") == "1" else tl.exp
    )
# Default exported exp to be imported by kernels.
exp = get_exp()
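The exported `exp` is intended to be imported by Triton kernels, with the fast/accurate choice fixed once at import time by the FLA_USE_FAST_OPS environment variable. As a minimal sketch (not part of the module being covered; the kernel name, launcher, block size, and import path are assumptions for illustration), a hypothetical element-wise kernel could use it like this:

import torch
import triton
import triton.language as tl

from flag_gems.fused.FLA.triton_ops_helper import exp  # assumed import path


@triton.jit
def exp_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    # `exp` resolves to tldevice.fast_expf or tl.exp, chosen when the helper module was imported.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, exp(x), mask=mask)


def launch_exp(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical host-side launcher for the sketch above.
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    exp_kernel[grid](x, y, x.numel(), BLOCK=1024)
    return y

Setting FLA_USE_FAST_OPS=1 before the helper module is first imported switches such kernels to the faster, lower-accuracy libdevice expf.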
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
    # For Triton 3.3.x
    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
    # For Triton 3.4.x and later
    make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
32 """
33 Fallback implementation when TMA is not supported.
34 Returns None to indicate TMA descriptors are unavailable.
35 Just make triton compiler happy.
36 """
38 @triton.jit
39 def make_tensor_descriptor(
40 base,
41 shape,
42 strides,
43 block_shape,
44 _builder=None,
45 ):
46 return None
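How kernels use the selected make_tensor_descriptor depends on which branch above was taken. Below is a hedged sketch assuming Triton 3.4+, where the name resolves to tl.make_tensor_descriptor (the TMA descriptor API); the kernel name, tile sizes, and import path are made up for illustration. On builds that hit the fallback branch, the stub simply returns None, so TMA-based code paths must not be taken there.

import triton
import triton.language as tl

from flag_gems.fused.FLA.triton_ops_helper import make_tensor_descriptor  # assumed path


@triton.jit
def copy_tile_kernel(src_ptr, dst_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Describe the 2-D, row-major source and destination tensors.
    src_desc = make_tensor_descriptor(src_ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N])
    dst_desc = make_tensor_descriptor(dst_ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N])
    # Move one (BLOCK_M, BLOCK_N) tile through the descriptors.
    tile = src_desc.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
    dst_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], tile)

In real use, on-device descriptor creation also requires a host-side workspace allocator to be registered with Triton (via triton.set_allocator in recent releases) before launching such a kernel.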