Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/clamp.py: 0%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

# Child of the package-wide "flag_gems" logger, named after this module, so the
# "GEMS CLAMP*" debug messages emitted below are filtered together with the rest
# of flag_gems logging.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

9 

10 

@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")])
@triton.jit
def clamp_func_tensor(x, mini, maxi):
    """Elementwise ``min(maxi, max(mini, x))`` with tensor-valued bounds.

    The upper bound is applied last, so when ``mini > maxi`` the result is
    ``maxi``.
    """
    # Upcast only inputs narrower than 32 bits (fp16, bf16, small ints) to
    # float32. Forcing 32/64-bit inputs through float32 would round float64
    # values and integers above 2**24 (e.g. int64 16777217 -> 16777216).
    # `primitive_bitwidth` is a compile-time property, so this `if` is
    # resolved while the kernel is traced and adds no runtime branch.
    # NOTE(review): assumes the Kunlunxin Triton backend can lower min/max on
    # 64-bit dtypes -- confirm on hardware.
    if x.dtype.primitive_bitwidth < 32:
        x = x.to(tl.float32)
    # NOTE(review): no `propagate_nan` is passed, so NaN handling follows the
    # backend default and may differ from torch.clamp -- confirm.
    return tl.minimum(maxi, tl.maximum(mini, x))

15 

16 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_min_tensor(x, mini):
    """Elementwise ``max(mini, x)`` with a tensor-valued lower bound."""
    # Upcast only inputs narrower than 32 bits (fp16, bf16, small ints) to
    # float32. Forcing 32/64-bit inputs through float32 would round float64
    # values and integers above 2**24. The condition is a compile-time
    # property of the dtype, so it is resolved while tracing.
    # NOTE(review): assumes the Kunlunxin Triton backend can lower max on
    # 64-bit dtypes -- confirm on hardware.
    if x.dtype.primitive_bitwidth < 32:
        x = x.to(tl.float32)
    # NOTE(review): no `propagate_nan` is passed; NaN handling is the backend
    # default and may differ from torch.clamp_min -- confirm.
    return tl.maximum(mini, x)

21 

22 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_max_tensor(x, maxi):
    """Elementwise ``min(maxi, x)`` with a tensor-valued upper bound."""
    # Upcast only inputs narrower than 32 bits (fp16, bf16, small ints) to
    # float32. Forcing 32/64-bit inputs through float32 would round float64
    # values and integers above 2**24. The condition is a compile-time
    # property of the dtype, so it is resolved while tracing.
    # NOTE(review): assumes the Kunlunxin Triton backend can lower min on
    # 64-bit dtypes -- confirm on hardware.
    if x.dtype.primitive_bitwidth < 32:
        x = x.to(tl.float32)
    # NOTE(review): no `propagate_nan` is passed; NaN handling is the backend
    # default and may differ from torch.clamp_max -- confirm.
    return tl.minimum(maxi, x)

27 

28 

def clamp_tensor(A, mini=None, maxi=None):
    """Clamp ``A`` elementwise using tensor-valued bounds (out of place).

    Picks the kernel that matches which bounds were supplied.

    Args:
        A: Input tensor.
        mini: Lower bound, or ``None`` for no lower bound.
        maxi: Upper bound, or ``None`` for no upper bound.

    Returns:
        Whatever the selected ``pointwise_dynamic`` kernel returns.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP TENSOR")
    has_lower = mini is not None
    has_upper = maxi is not None
    if has_lower and has_upper:
        return clamp_func_tensor(A, mini, maxi)
    if has_lower:
        return clamp_func_min_tensor(A, mini)
    if has_upper:
        return clamp_func_max_tensor(A, maxi)
    raise ValueError("At least one of mini or maxi must not be None")

39 

40 

def clamp_tensor_(A, mini=None, maxi=None):
    """In-place variant of ``clamp_tensor``.

    Each kernel receives ``out0=A`` as its output buffer.

    Args:
        A: Tensor to clamp; also used as the output buffer.
        mini: Lower bound, or ``None`` for no lower bound.
        maxi: Upper bound, or ``None`` for no upper bound.

    Returns:
        Whatever the selected ``pointwise_dynamic`` kernel returns.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP_ TENSOR")
    has_lower = mini is not None
    has_upper = maxi is not None
    if has_lower and has_upper:
        return clamp_func_tensor(A, mini, maxi, out0=A)
    if has_lower:
        return clamp_func_min_tensor(A, mini, out0=A)
    if has_upper:
        return clamp_func_max_tensor(A, maxi, out0=A)
    raise ValueError("At least one of mini or maxi must not be None")

51 

52 

@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def clamp_func(x, mini, maxi):
    """Elementwise ``min(maxi, max(mini, x))`` with non-tensor (scalar) bounds.

    The upper bound is applied last, so when ``mini > maxi`` the result is
    ``maxi``.
    """
    # Upcast only inputs narrower than 32 bits (fp16, bf16, small ints) to
    # float32. Forcing 32/64-bit inputs through float32 would round float64
    # values and integers above 2**24 (e.g. int64 16777217 -> 16777216).
    # `primitive_bitwidth` is a compile-time property, so this `if` is
    # resolved while the kernel is traced and adds no runtime branch.
    # NOTE(review): assumes the Kunlunxin Triton backend can lower min/max on
    # 64-bit dtypes -- confirm on hardware.
    if x.dtype.primitive_bitwidth < 32:
        x = x.to(tl.float32)
    # NOTE(review): no `propagate_nan` is passed, so NaN handling follows the
    # backend default and may differ from torch.clamp -- confirm.
    return tl.minimum(maxi, tl.maximum(mini, x))

59 

60 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_min(x, mini):
    """Elementwise ``max(mini, x)`` with a non-tensor (scalar) lower bound."""
    # Upcast only inputs narrower than 32 bits (fp16, bf16, small ints) to
    # float32. Forcing 32/64-bit inputs through float32 would round float64
    # values and integers above 2**24. The condition is a compile-time
    # property of the dtype, so it is resolved while tracing.
    # NOTE(review): assumes the Kunlunxin Triton backend can lower max on
    # 64-bit dtypes -- confirm on hardware.
    if x.dtype.primitive_bitwidth < 32:
        x = x.to(tl.float32)
    # NOTE(review): no `propagate_nan` is passed; NaN handling is the backend
    # default and may differ from torch.clamp_min -- confirm.
    return tl.maximum(mini, x)

65 

66 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_max(x, maxi):
    """Elementwise ``min(maxi, x)`` with a non-tensor (scalar) upper bound."""
    # Upcast only inputs narrower than 32 bits (fp16, bf16, small ints) to
    # float32. Forcing 32/64-bit inputs through float32 would round float64
    # values and integers above 2**24. The condition is a compile-time
    # property of the dtype, so it is resolved while tracing.
    # NOTE(review): assumes the Kunlunxin Triton backend can lower min on
    # 64-bit dtypes -- confirm on hardware.
    if x.dtype.primitive_bitwidth < 32:
        x = x.to(tl.float32)
    # NOTE(review): no `propagate_nan` is passed; NaN handling is the backend
    # default and may differ from torch.clamp_max -- confirm.
    return tl.minimum(maxi, x)

71 

72 

def clamp_min(A, mini):
    """Clamp ``A`` from below by a non-tensor bound (out of place).

    Args:
        A: Input tensor.
        mini: Lower bound; must not be ``None``.

    Returns:
        Whatever ``clamp_func_min`` returns.

    Raises:
        ValueError: If ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP MIN")
    if mini is not None:
        return clamp_func_min(A, mini)
    raise ValueError("Mini must not be None")

78 

79 

def clamp_min_(A, mini):
    """In-place variant of ``clamp_min``; ``A`` is passed as ``out0``.

    Args:
        A: Tensor to clamp; also used as the output buffer.
        mini: Lower bound; must not be ``None``.

    Returns:
        Whatever ``clamp_func_min`` returns.

    Raises:
        ValueError: If ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP_ MIN")
    if mini is not None:
        return clamp_func_min(A, mini, out0=A)
    raise ValueError("Mini must not be None")

85 

86 

def clamp(A, mini=None, maxi=None):
    """Clamp ``A`` elementwise to non-tensor (scalar) bounds (out of place).

    Picks the kernel that matches which bounds were supplied.

    Args:
        A: Input tensor.
        mini: Scalar lower bound, or ``None`` for no lower bound.
        maxi: Scalar upper bound, or ``None`` for no upper bound.

    Returns:
        Whatever the selected ``pointwise_dynamic`` kernel returns.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP")
    has_lower = mini is not None
    has_upper = maxi is not None
    if has_lower and has_upper:
        return clamp_func(A, mini, maxi)
    if has_lower:
        return clamp_func_min(A, mini)
    if has_upper:
        return clamp_func_max(A, maxi)
    raise ValueError("At least one of mini or maxi must not be None")

97 

98 

def clamp_(A, mini=None, maxi=None):
    """In-place variant of ``clamp``.

    Each kernel receives ``out0=A`` as its output buffer.

    Args:
        A: Tensor to clamp; also used as the output buffer.
        mini: Scalar lower bound, or ``None`` for no lower bound.
        maxi: Scalar upper bound, or ``None`` for no upper bound.

    Returns:
        Whatever the selected ``pointwise_dynamic`` kernel returns.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP_")
    has_lower = mini is not None
    has_upper = maxi is not None
    if has_lower and has_upper:
        return clamp_func(A, mini, maxi, out0=A)
    if has_lower:
        return clamp_func_min(A, mini, out0=A)
    if has_upper:
        return clamp_func_max(A, maxi, out0=A)
    raise ValueError("At least one of mini or maxi must not be None")