Coverage for src/flag_gems/utils/limits.py: 21%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import triton 

2from triton import language as tl 

3 

4 

@triton.jit
def get_dtype_max(dtype: tl.constexpr):
    """Get a value which is greater than all other values of that dtype.

    Args:
        dtype: the element type, passed as a ``tl.constexpr`` whose
            ``.value`` is the underlying dtype object.

    Returns:
        A ``tl.constexpr`` value:
        - ``float("inf")`` when the dtype is floating point;
        - ``2 ** (width - 1) - 1`` when the dtype is a signed integer;
        - ``2 ** width - 1`` when the dtype is an unsigned integer,
        where ``width`` is the dtype's ``int_bitwidth``.

    Note:
        There is no explicit return for a dtype that matches none of the
        three checks below.
    """
    # extract the tl.dtype from tl.constexpr so as to use its methods
    dtype_ = dtype.value
    if dtype_.is_floating():
        # positive infinity serves as the upper bound for floating dtypes
        value: tl.constexpr = float("inf")
        return value
    if dtype_.is_int_signed():
        width: tl.constexpr = dtype_.int_bitwidth
        # one bit is taken by the sign, hence (width - 1) magnitude bits
        value: tl.constexpr = 2 ** (width - 1) - 1
        return value
    if dtype_.is_int_unsigned():
        width: tl.constexpr = dtype_.int_bitwidth
        # all `width` bits set
        value: tl.constexpr = 2**width - 1
        return value

21 

22 

@triton.jit
def get_dtype_min(dtype: tl.constexpr):
    """Get a value which is less than all other values of that dtype.

    The parameter is annotated ``tl.constexpr`` to match ``get_dtype_max``;
    the body reads ``dtype.value``, so it already expects a compile-time
    dtype wrapped in ``tl.constexpr``.

    Args:
        dtype: the element type, passed as a ``tl.constexpr`` whose
            ``.value`` is the underlying dtype object.

    Returns:
        A ``tl.constexpr`` value:
        - ``float("-inf")`` when the dtype is floating point;
        - ``-1 * 2 ** (width - 1)`` when the dtype is a signed integer,
          where ``width`` is the dtype's ``int_bitwidth``;
        - ``0`` when the dtype is an unsigned integer.

    Note:
        There is no explicit return for a dtype that matches none of the
        three checks below.
    """
    # extract the tl.dtype from tl.constexpr so as to use its methods
    dtype_ = dtype.value
    if dtype_.is_floating():
        # negative infinity serves as the lower bound for floating dtypes
        value: tl.constexpr = float("-inf")
        return value
    if dtype_.is_int_signed():
        width: tl.constexpr = dtype_.int_bitwidth
        # one bit is taken by the sign, hence (width - 1) magnitude bits
        value: tl.constexpr = -1 * 2 ** (width - 1)
        return value
    if dtype_.is_int_unsigned():
        # unsigned integers bottom out at zero regardless of bit width
        value: tl.constexpr = 0
        return value