Coverage for src/flag_gems/runtime/backend/_cambricon/ops/clamp.py: 0%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

# Per-module logger nested under the shared "flag_gems" logger, so records
# emitted here propagate to whatever handlers are configured for the package.
_gems_root_logger = logging.getLogger("flag_gems")
logger = _gems_root_logger.getChild(__name__.lstrip("."))

9 

10 

# Per-element body for the two-sided clamp whose bounds are tensor arguments
# (is_tensor marks x, mini and maxi as tensors; `inplace` is a non-tensor flag
# that this body never reads). Computes min(maxi, max(mini, x)) with x upcast
# to float32 first.
# NOTE(review): because x is upcast to float32, float64 inputs and integers
# larger than 2**24 may lose precision before the result is written back —
# confirm this is intended for the Cambricon backend.
@pointwise_dynamic(
    is_tensor=[True, True, True, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def clamp_func_tensor(x, mini, maxi, inplace):
    x_f32 = x.to(tl.float32)
    floored = tl.maximum(mini, x_f32)
    return tl.minimum(maxi, floored)

17 

18 

# Per-element body for a lower-bound-only clamp whose bound is a tensor
# argument: returns max(mini, x) with x upcast to float32. The non-tensor
# `inplace` flag is accepted but not read.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_min_tensor(x, mini, inplace):
    x_f32 = x.to(tl.float32)
    return tl.maximum(mini, x_f32)

23 

24 

# Per-element body for an upper-bound-only clamp whose bound is a tensor
# argument: returns min(maxi, x) with x upcast to float32. The non-tensor
# `inplace` flag is accepted but not read.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_max_tensor(x, maxi, inplace):
    x_f32 = x.to(tl.float32)
    return tl.minimum(maxi, x_f32)

29 

30 

def clamp_tensor(A, mini=None, maxi=None):
    """Clamp ``A`` out of place using bounds passed as tensor arguments.

    Chooses the two-sided kernel when both bounds are present, otherwise the
    matching one-sided kernel. The trailing ``False`` fills the kernels'
    ``inplace`` parameter.

    Args:
        A: Input tensor.
        mini: Lower bound (tensor argument), or ``None`` for no lower bound.
        maxi: Upper bound (tensor argument), or ``None`` for no upper bound.

    Returns:
        Whatever the selected ``pointwise_dynamic`` kernel returns.

    Raises:
        ValueError: If neither ``mini`` nor ``maxi`` is provided.
    """
    logger.debug("GEMS_CAMBRICON CLAMP TENSOR")
    has_lower = mini is not None
    has_upper = maxi is not None
    if has_lower and has_upper:
        return clamp_func_tensor(A, mini, maxi, False)
    if has_upper:
        return clamp_func_max_tensor(A, maxi, False)
    if has_lower:
        return clamp_func_min_tensor(A, mini, False)
    raise ValueError("At least one of mini or maxi must not be None")

41 

42 

def clamp_tensor_(A, mini=None, maxi=None):
    """Clamp ``A`` in place using bounds passed as tensor arguments.

    The chosen kernel is called with ``out0=A``, so the result is written into
    ``A``. The ``True`` argument fills the kernels' ``inplace`` parameter.

    Args:
        A: Tensor to clamp; also used as the output buffer.
        mini: Lower bound (tensor argument), or ``None`` for no lower bound.
        maxi: Upper bound (tensor argument), or ``None`` for no upper bound.

    Returns:
        Whatever the selected ``pointwise_dynamic`` kernel returns.

    Raises:
        ValueError: If neither ``mini`` nor ``maxi`` is provided.
    """
    logger.debug("GEMS_CAMBRICON CLAMP_ TENSOR")
    if mini is not None and maxi is not None:
        return clamp_func_tensor(A, mini, maxi, True, out0=A)
    if mini is not None:
        return clamp_func_min_tensor(A, mini, True, out0=A)
    if maxi is not None:
        return clamp_func_max_tensor(A, maxi, True, out0=A)
    raise ValueError("At least one of mini or maxi must not be None")

53 

54 

# Per-element body for the two-sided clamp with non-tensor (scalar) bounds
# (is_tensor marks only x as a tensor). Computes min(maxi, max(mini, x)) with
# x upcast to float32; the `inplace` flag is accepted but not read.
# NOTE(review): the float32 upcast may lose precision for float64 or large
# integer inputs — confirm this is intended for the Cambricon backend.
@pointwise_dynamic(
    is_tensor=[True, False, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def clamp_func(x, mini, maxi, inplace):
    x_f32 = x.to(tl.float32)
    floored = tl.maximum(mini, x_f32)
    return tl.minimum(maxi, floored)

61 

62 

# Per-element body for a lower-bound-only clamp with a non-tensor (scalar)
# bound: returns max(mini, x) with x upcast to float32. `inplace` is not read.
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_min(x, mini, inplace):
    x_f32 = x.to(tl.float32)
    return tl.maximum(mini, x_f32)

69 

70 

# Per-element body for an upper-bound-only clamp with a non-tensor (scalar)
# bound: returns min(maxi, x) with x upcast to float32. `inplace` is not read.
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_max(x, maxi, inplace):
    x_f32 = x.to(tl.float32)
    return tl.minimum(maxi, x_f32)

77 

78 

def clamp_min(A, mini):
    """Clamp ``A`` from below out of place, using a non-tensor bound ``mini``.

    Args:
        A: Input tensor.
        mini: Lower bound; must not be ``None``.

    Returns:
        Whatever ``clamp_func_min`` returns.

    Raises:
        ValueError: If ``mini`` is ``None``.
    """
    logger.debug("GEMS_CAMBRICON CLAMP MIN")
    if mini is not None:
        return clamp_func_min(A, mini, False)
    raise ValueError("Mini must not be None")

84 

85 

def clamp_min_(A, mini):
    """Clamp ``A`` from below in place (result written via ``out0=A``).

    Args:
        A: Tensor to clamp; also used as the output buffer.
        mini: Lower bound; must not be ``None``.

    Returns:
        Whatever ``clamp_func_min`` returns.

    Raises:
        ValueError: If ``mini`` is ``None``.
    """
    logger.debug("GEMS_CAMBRICON CLAMP_ MIN")
    if mini is not None:
        return clamp_func_min(A, mini, True, out0=A)
    raise ValueError("Mini must not be None")

91 

92 

def clamp(A, mini=None, maxi=None):
    """Clamp ``A`` out of place using non-tensor (scalar) bounds.

    Args:
        A: Input tensor.
        mini: Lower bound, or ``None`` for no lower bound.
        maxi: Upper bound, or ``None`` for no upper bound.

    Returns:
        Whatever the selected ``pointwise_dynamic`` kernel returns.

    Raises:
        ValueError: If neither ``mini`` nor ``maxi`` is provided.
    """
    logger.debug("GEMS_CAMBRICON CLAMP")
    lower_given = mini is not None
    upper_given = maxi is not None
    if not (lower_given or upper_given):
        raise ValueError("At least one of mini or maxi must not be None")
    if lower_given and upper_given:
        return clamp_func(A, mini, maxi, False)
    if lower_given:
        return clamp_func_min(A, mini, False)
    return clamp_func_max(A, maxi, False)

103 

104 

def clamp_(A, mini=None, maxi=None):
    """Clamp ``A`` in place using non-tensor (scalar) bounds.

    The chosen kernel is called with ``out0=A``, so the result lands in ``A``.

    Args:
        A: Tensor to clamp; also used as the output buffer.
        mini: Lower bound, or ``None`` for no lower bound.
        maxi: Upper bound, or ``None`` for no upper bound.

    Returns:
        Whatever the selected ``pointwise_dynamic`` kernel returns.

    Raises:
        ValueError: If neither ``mini`` nor ``maxi`` is provided.
    """
    logger.debug("GEMS_CAMBRICON CLAMP_")
    if mini is None:
        if maxi is None:
            raise ValueError("At least one of mini or maxi must not be None")
        return clamp_func_max(A, maxi, True, out0=A)
    if maxi is None:
        return clamp_func_min(A, mini, True, out0=A)
    return clamp_func(A, mini, maxi, True, out0=A)