Coverage for src/flag_gems/utils/limits.py: 21%
29 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import triton
2from triton import language as tl
@triton.jit
def get_dtype_max(dtype: tl.constexpr):
    """Return a value greater than or equal to every value of ``dtype``.

    Args:
        dtype: compile-time constant wrapping a ``tl.dtype``.

    Returns:
        ``2**w - 1`` for unsigned integer dtypes of bit width ``w``,
        ``2**(w - 1) - 1`` for signed integer dtypes of bit width ``w``, and
        ``float("inf")`` for floating-point dtypes. A dtype that is none of
        these matches no branch, and the function has no explicit return
        value for it.
    """
    # Unwrap the tl.dtype held inside the constexpr so its query methods
    # (is_int_unsigned, is_int_signed, is_floating, int_bitwidth) can be used.
    dt = dtype.value
    # The three dtype families tested below are disjoint, so at most one
    # branch applies. Locals stay annotated tl.constexpr, as in the original,
    # so Triton keeps them as compile-time constants.
    if dt.is_int_unsigned():
        nbits: tl.constexpr = dt.int_bitwidth
        upper: tl.constexpr = 2**nbits - 1
        return upper
    if dt.is_int_signed():
        nbits: tl.constexpr = dt.int_bitwidth
        upper: tl.constexpr = 2 ** (nbits - 1) - 1
        return upper
    if dt.is_floating():
        upper: tl.constexpr = float("inf")
        return upper
@triton.jit
def get_dtype_min(dtype: tl.constexpr):
    """Get a value that is less than or equal to every value of ``dtype``.

    Args:
        dtype: compile-time constant wrapping a ``tl.dtype``. It is annotated
            ``tl.constexpr`` to match the signature of ``get_dtype_max``.

    Returns:
        ``float("-inf")`` for floating-point dtypes, ``-2**(w - 1)`` for
        signed integer dtypes of bit width ``w``, and ``0`` for unsigned
        integer dtypes. A dtype that is none of these matches no branch, and
        the function has no explicit return value for it.
    """
    # Unwrap the tl.dtype held inside the constexpr so its query methods can
    # be called.
    dtype_ = dtype.value
    if dtype_.is_floating():
        value: tl.constexpr = float("-inf")
        return value
    if dtype_.is_int_signed():
        width: tl.constexpr = dtype_.int_bitwidth
        # Smallest value of a width-bit two's-complement signed integer.
        value: tl.constexpr = -1 * 2 ** (width - 1)
        return value
    if dtype_.is_int_unsigned():
        value: tl.constexpr = 0
        return value