Coverage for src/flag_gems/utils/triton_lang_extension.py: 48%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1""" 

2Triton device functions. 

3 

4Custom triton device functions that we need to use. 

5 

6NOTE: 

7Do not try to add triton builtin-style functions(functions with an ir builder in its 

8arguments) here. We only define device-functions(triton.jit decorated functions with 

9return statement) here. 

10 

11These functions can be used in kernel programming and are not bound to any grid. 

12""" 

13 

14import triton 

15from triton import language as tl 

16 

17from flag_gems.utils.triton_lang_helper import use_tl_extra 

18 

19 

@triton.jit
def program_id(
    axis: int,
) -> tl.tensor:
    """Return ``tl.program_id(axis)`` converted to ``tl.int64``.

    Args:
        axis: Launch-grid axis forwarded unchanged to ``tl.program_id``.

    Returns:
        The program id along ``axis`` as a 64-bit integer value.
    """
    # NOTE(review): the widening is presumably there so that offsets derived
    # from the id do not overflow 32 bits -- confirm with the callers.
    pid = tl.program_id(axis)
    return pid.to(tl.int64)

25 

26 

@triton.jit
def num_programs(
    axis: int,
) -> tl.tensor:
    """Return ``tl.num_programs(axis)`` converted to ``tl.int64``.

    Args:
        axis: Launch-grid axis forwarded unchanged to ``tl.num_programs``.

    Returns:
        The number of programs along ``axis`` as a 64-bit integer value.
    """
    # NOTE(review): widened to int64 presumably for the same overflow reasons
    # as ``program_id`` -- confirm with the callers.
    count = tl.num_programs(axis)
    return count.to(tl.int64)

32 

33 

@triton.jit
def promote_to_tensor(x):
    """Return ``x`` as a Triton tensor by adding a one-element ``int1`` zero.

    Args:
        x: Value to promote.

    Returns:
        The result of ``x + tl.zeros((1,), tl.int1)``.
    """
    # Adding a zero tensor lets the addition's type promotion do the
    # conversion for us.
    zero = tl.zeros((1,), tl.int1)
    return x + zero

38 

39 

@triton.jit
def is_floating(x):
    """Report whether ``x`` has a floating-point dtype once promoted to a tensor.

    Args:
        x: Value to inspect; it is passed through ``promote_to_tensor`` first.

    Returns:
        The result of ``dtype.is_floating()`` on the promoted value.
    """
    promoted = promote_to_tensor(x)
    return promoted.dtype.is_floating()

43 

44 

@triton.jit
def minimum_with_index_tie_break_right(a_value, a_index, b_value, b_index):
    """Combine two (value, index) pairs, keeping the smaller value.

    For floating-point values a NaN is selected over any non-NaN value, and
    two NaNs are treated as equal. When the values are equal, the pair with
    the higher index is kept.

    Args:
        a_value: Value of the first candidate.
        a_index: Index associated with ``a_value``.
        b_value: Value of the second candidate.
        b_index: Index associated with ``b_value``.

    Returns:
        A ``(value, index)`` tuple holding the selected candidate.
    """
    mask = a_value < b_value
    equal = a_value == b_value
    if is_floating(a_value):
        # NaN is the only value that compares unequal to itself.
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        # Use the element-wise operators (`&`, `~`) on the boolean masks
        # instead of Python's `and` / `not`, which are meant for scalars.
        mask |= a_isnan & (~b_isnan)
        # Consider NaNs as equal
        equal |= a_isnan & b_isnan

    # Prefer highest index if values are equal
    mask |= equal & (a_index > b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)

59 

60 

@triton.jit
def maximum_with_index_tie_break_right(a_value, a_index, b_value, b_index):
    """Combine two (value, index) pairs, keeping the larger value.

    For floating-point values a NaN is selected over any non-NaN value, and
    two NaNs are treated as equal. When the values are equal, the pair with
    the higher index is kept.

    Args:
        a_value: Value of the first candidate.
        a_index: Index associated with ``a_value``.
        b_value: Value of the second candidate.
        b_index: Index associated with ``b_value``.

    Returns:
        A ``(value, index)`` tuple holding the selected candidate.
    """
    mask = a_value > b_value
    equal = a_value == b_value
    if is_floating(a_value):
        # NaN is the only value that compares unequal to itself.
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        # Use the element-wise operators (`&`, `~`) on the boolean masks
        # instead of Python's `and` / `not`, which are meant for scalars.
        mask |= a_isnan & (~b_isnan)
        # Consider NaNs as equal
        equal |= a_isnan & b_isnan

    # Prefer highest index if values are equal
    mask |= equal & (a_index > b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)

75 

76 

@use_tl_extra
@triton.jit
def div_rn(x, y):
    """Default ``div_rn``: divide and round the quotient to the nearest integer.

    Computes ``floor(x / y + 0.5)``.

    NOTE(review): ``use_tl_extra`` presumably substitutes a backend-provided
    ``div_rn`` when one is available, leaving this body as the fallback --
    confirm against ``triton_lang_helper``.
    """
    shifted = x / y + 0.5
    return tl.floor(shifted)

83 

84 

@use_tl_extra
@triton.jit
def div_rz(x, y):
    """Default ``div_rz``: divide and round the quotient toward zero.

    Non-negative quotients are floored and negative quotients are ceiled.

    NOTE(review): ``use_tl_extra`` presumably substitutes a backend-provided
    ``div_rz`` when one is available -- confirm against ``triton_lang_helper``.
    """
    quotient = x / y
    rounded_down = tl.floor(quotient)
    rounded_up = tl.ceil(quotient)
    return tl.where(quotient >= 0, rounded_down, rounded_up)

91 

92 

@use_tl_extra
@triton.jit
def fmod(x, y):
    """Default ``fmod``: remainder of ``x / y`` computed as ``x - y * div_rz(x, y)``.

    NOTE(review): the result depends on whichever ``div_rz`` is in scope after
    ``use_tl_extra`` is applied -- confirm it yields a truncated quotient.
    """
    scaled = y * div_rz(x, y)
    return x - scaled

99 

100 

@use_tl_extra
@triton.jit
def trunc(x):
    """Default ``trunc``: round ``x`` toward zero.

    Non-negative inputs are floored and negative inputs are ceiled.

    NOTE(review): ``use_tl_extra`` presumably substitutes a backend-provided
    ``trunc`` when one is available -- confirm against ``triton_lang_helper``.
    """
    rounded_down = tl.floor(x)
    rounded_up = tl.ceil(x)
    return tl.where(x >= 0, rounded_down, rounded_up)