Coverage for src/flag_gems/utils/triton_lang_extension.py: 48%
56 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1"""
2Triton device functions.
4Custom triton device functions that we need to use.
6NOTE:
7Do not try to add triton builtin-style functions (functions with an IR builder in their
8arguments) here. We only define device functions (triton.jit-decorated functions with
9a return statement) here.
11These functions can be used in kernel programming and are not bound to any grid.
12"""
14import triton
15from triton import language as tl
17from flag_gems.utils.triton_lang_helper import use_tl_extra
@triton.jit
def program_id(
    axis: int,
) -> tl.tensor:
    """Return ``tl.program_id(axis)`` converted to ``tl.int64``."""
    pid = tl.program_id(axis)
    return pid.to(tl.int64)
@triton.jit
def num_programs(
    axis: int,
) -> tl.tensor:
    """Return ``tl.num_programs(axis)`` converted to ``tl.int64``."""
    count = tl.num_programs(axis)
    return count.to(tl.int64)
@triton.jit
def promote_to_tensor(x):
    """Return ``x`` added to a one-element ``tl.int1`` zero tensor.

    The addition is what turns ``x`` into a tensor: adding a tensor operand
    makes the result a tensor.
    """
    # Addition promotes to tensor for us
    zero = tl.zeros((1,), tl.int1)
    return x + zero
@triton.jit
def is_floating(x):
    """Report whether the dtype of ``x`` (after tensor promotion) is floating."""
    # Promote first so that `.dtype` is available on the result.
    promoted = promote_to_tensor(x)
    return promoted.dtype.is_floating()
@triton.jit
def minimum_with_index_tie_break_right(a_value, a_index, b_value, b_index):
    """Combine two (value, index) pairs, keeping the smaller value.

    The ``a`` pair is selected when ``a_value < b_value``, or when the two
    values compare equal and ``a_index > b_index``. For floating-point values,
    a NaN ``a_value`` is selected over a non-NaN ``b_value``, and two NaNs are
    treated as equal so that the index rule decides between them.

    Returns:
        The selected ``(value, index)`` pair, each picked with ``tl.where``.
    """
    pick_a = a_value < b_value
    tied = a_value == b_value
    if is_floating(a_value):
        # A value that compares unequal to itself is NaN.
        a_nan = a_value != a_value
        b_nan = b_value != b_value
        # A NaN on the `a` side wins over a non-NaN `b`.
        pick_a = pick_a | (a_nan and not b_nan)
        # Consider NaNs as equal
        tied = tied | (a_nan and b_nan)

    # Prefer highest index if values are equal
    pick_a = pick_a | (tied & (a_index > b_index))
    out_value = tl.where(pick_a, a_value, b_value)
    out_index = tl.where(pick_a, a_index, b_index)
    return out_value, out_index
@triton.jit
def maximum_with_index_tie_break_right(a_value, a_index, b_value, b_index):
    """Combine two (value, index) pairs, keeping the larger value.

    The ``a`` pair is selected when ``a_value > b_value``, or when the two
    values compare equal and ``a_index > b_index``. For floating-point values,
    a NaN ``a_value`` is selected over a non-NaN ``b_value``, and two NaNs are
    treated as equal so that the index rule decides between them.

    Returns:
        The selected ``(value, index)`` pair, each picked with ``tl.where``.
    """
    pick_a = a_value > b_value
    tied = a_value == b_value
    if is_floating(a_value):
        # A value that compares unequal to itself is NaN.
        a_nan = a_value != a_value
        b_nan = b_value != b_value
        # A NaN on the `a` side wins over a non-NaN `b`.
        pick_a = pick_a | (a_nan and not b_nan)
        # Consider NaNs as equal
        tied = tied | (a_nan and b_nan)

    # Prefer highest index if values are equal
    pick_a = pick_a | (tied & (a_index > b_index))
    out_value = tl.where(pick_a, a_value, b_value)
    out_index = tl.where(pick_a, a_index, b_index)
    return out_value, out_index
@use_tl_extra
@triton.jit
def div_rn(x, y):
    """div_rn default - round to nearest

    Computes ``x / y`` and returns ``floor(x / y + 0.5)``, so the result is
    an integral value and exact halves round upward.

    NOTE(review): `use_tl_extra` presumably substitutes a backend-provided
    implementation when one is available - confirm in triton_lang_helper.
    NOTE(review): this default rounds the quotient to an integer; confirm
    that matches what callers expect from the substituted implementation.
    """
    quotient = x / y
    shifted = quotient + 0.5
    return tl.floor(shifted)
@use_tl_extra
@triton.jit
def div_rz(x, y):
    """div_rz default - round toward zero

    Computes ``x / y`` and drops its fractional part: ``floor`` is used for
    non-negative quotients and ``ceil`` for negative ones.

    NOTE(review): `use_tl_extra` presumably substitutes a backend-provided
    implementation when one is available - confirm in triton_lang_helper.
    """
    quotient = x / y
    non_negative = quotient >= 0
    rounded_down = tl.floor(quotient)
    rounded_up = tl.ceil(quotient)
    return tl.where(non_negative, rounded_down, rounded_up)
@use_tl_extra
@triton.jit
def fmod(x, y):
    """fmod default - floating point modulo

    Returns ``x - y * div_rz(x, y)``, i.e. what is left of ``x`` after
    subtracting ``y`` times the quotient produced by ``div_rz``.

    NOTE(review): `use_tl_extra` presumably substitutes a backend-provided
    implementation when one is available - confirm in triton_lang_helper.
    """
    whole = div_rz(x, y)
    scaled = y * whole
    return x - scaled
@use_tl_extra
@triton.jit
def trunc(x):
    """trunc default - truncate to integer

    Drops the fractional part of ``x``: ``floor`` is used where ``x >= 0``
    and ``ceil`` elsewhere.

    NOTE(review): `use_tl_extra` presumably substitutes a backend-provided
    implementation when one is available - confirm in triton_lang_helper.
    """
    non_negative = x >= 0
    rounded_down = tl.floor(x)
    rounded_up = tl.ceil(x)
    return tl.where(non_negative, rounded_down, rounded_up)