Coverage for src/flag_gems/ops/clamp.py: 76%
75 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
8logger = logging.getLogger(__name__)
@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")])
@triton.jit
def clamp_func_tensor(x, mini, maxi):
    # Element-wise clamp of `x` between `mini` and `maxi`.
    # No `is_tensor` list is given to the decorator here, unlike the
    # scalar-bound kernels in this file; presumably all three operands are
    # tensors -- TODO confirm against pointwise_dynamic.
    #
    # `x` is cast to float32 before comparing.
    # NOTE(review): this cast may lose precision for 64-bit inputs; verify
    # how pointwise_dynamic's "DEFAULT" promotion handles the output dtype.
    x_f32 = x.to(tl.float32)
    # Lower bound first, upper bound last: when mini > maxi the upper bound
    # wins and the result is `maxi`.
    lower_bounded = tl.maximum(mini, x_f32)
    return tl.minimum(maxi, lower_bounded)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_min_tensor(x, mini):
    # Element-wise lower-bound clamp: max(mini, x).
    # No `is_tensor` list is passed to the decorator; presumably both
    # operands are tensors -- TODO confirm against pointwise_dynamic.
    # `x` is cast to float32 before the comparison.
    x_f32 = x.to(tl.float32)
    return tl.maximum(mini, x_f32)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_max_tensor(x, maxi):
    # Element-wise upper-bound clamp: min(maxi, x).
    # No `is_tensor` list is passed to the decorator; presumably both
    # operands are tensors -- TODO confirm against pointwise_dynamic.
    # `x` is cast to float32 before the comparison.
    x_f32 = x.to(tl.float32)
    return tl.minimum(maxi, x_f32)
def clamp_tensor(A, mini=None, maxi=None):
    """Clamp ``A`` using the ``*_tensor`` kernels and return the result.

    Dispatches on which bounds are supplied:

    * both ``mini`` and ``maxi`` -> ``clamp_func_tensor``
    * only ``mini``              -> ``clamp_func_min_tensor``
    * only ``maxi``              -> ``clamp_func_max_tensor``

    Raises:
        ValueError: if both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP TENSOR")

    has_lower = mini is not None
    has_upper = maxi is not None

    if has_lower and has_upper:
        return clamp_func_tensor(A, mini, maxi)
    if has_lower:
        return clamp_func_min_tensor(A, mini)
    if has_upper:
        return clamp_func_max_tensor(A, maxi)

    raise ValueError("At least one of mini or maxi must not be None")
def clamp_tensor_(A, mini=None, maxi=None):
    """Variant of the tensor-bound clamp that passes ``A`` as ``out0``.

    Same dispatch as the out-of-place version, but every kernel call
    receives ``out0=A`` (presumably so the result is written back into
    ``A`` -- confirm against ``pointwise_dynamic``). The kernel call's
    return value is returned unchanged.

    Raises:
        ValueError: if both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP_ TENSOR")

    has_lower = mini is not None
    has_upper = maxi is not None

    if has_lower and has_upper:
        return clamp_func_tensor(A, mini, maxi, out0=A)
    if has_lower:
        return clamp_func_min_tensor(A, mini, out0=A)
    if has_upper:
        return clamp_func_max_tensor(A, maxi, out0=A)

    raise ValueError("At least one of mini or maxi must not be None")
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def clamp_func(x, mini, maxi):
    # Element-wise clamp of `x` between `mini` and `maxi`.
    # is_tensor=[True, False, False]: presumably marks `x` as a tensor and
    # both bounds as non-tensor scalars -- TODO confirm against
    # pointwise_dynamic.
    #
    # `x` is cast to float32 before comparing.
    # NOTE(review): this cast may lose precision for 64-bit inputs; verify
    # how pointwise_dynamic's "DEFAULT" promotion handles the output dtype.
    x_f32 = x.to(tl.float32)
    # Lower bound first, upper bound last: when mini > maxi the upper bound
    # wins and the result is `maxi`.
    lower_bounded = tl.maximum(mini, x_f32)
    return tl.minimum(maxi, lower_bounded)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_min(x, mini):
    # Element-wise lower-bound clamp: max(mini, x).
    # is_tensor=[True, False]: presumably `x` is a tensor and `mini` a
    # non-tensor scalar -- TODO confirm against pointwise_dynamic.
    # `x` is cast to float32 before the comparison.
    x_f32 = x.to(tl.float32)
    return tl.maximum(mini, x_f32)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_max(x, maxi):
    # Element-wise upper-bound clamp: min(maxi, x).
    # is_tensor=[True, False]: presumably `x` is a tensor and `maxi` a
    # non-tensor scalar -- TODO confirm against pointwise_dynamic.
    # `x` is cast to float32 before the comparison.
    x_f32 = x.to(tl.float32)
    return tl.minimum(maxi, x_f32)
def clamp_min(A, mini):
    """Apply a lower bound ``mini`` to ``A`` via ``clamp_func_min``.

    Raises:
        ValueError: if ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP MIN")
    if mini is not None:
        return clamp_func_min(A, mini)
    raise ValueError("Mini must not be None")
def clamp_min_(A, mini):
    """Apply a lower bound ``mini`` to ``A``, passing ``A`` as ``out0``.

    ``out0=A`` presumably makes the kernel write its result back into
    ``A`` -- confirm against ``pointwise_dynamic``. The kernel call's
    return value is returned unchanged.

    Raises:
        ValueError: if ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP_ MIN")
    if mini is not None:
        return clamp_func_min(A, mini, out0=A)
    raise ValueError("Mini must not be None")
def clamp(A, mini=None, maxi=None):
    """Clamp ``A`` using the scalar-bound kernels and return the result.

    Dispatches on which bounds are supplied:

    * both ``mini`` and ``maxi`` -> ``clamp_func``
    * only ``mini``              -> ``clamp_func_min``
    * only ``maxi``              -> ``clamp_func_max``

    Raises:
        ValueError: if both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP")

    has_lower = mini is not None
    has_upper = maxi is not None

    if has_lower and has_upper:
        return clamp_func(A, mini, maxi)
    if has_lower:
        return clamp_func_min(A, mini)
    if has_upper:
        return clamp_func_max(A, maxi)

    raise ValueError("At least one of mini or maxi must not be None")
def clamp_(A, mini=None, maxi=None):
    """Variant of the scalar-bound clamp that passes ``A`` as ``out0``.

    Same dispatch as the out-of-place version, but every kernel call
    receives ``out0=A`` (presumably so the result is written back into
    ``A`` -- confirm against ``pointwise_dynamic``). The kernel call's
    return value is returned unchanged.

    Raises:
        ValueError: if both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP_")

    has_lower = mini is not None
    has_upper = maxi is not None

    if has_lower and has_upper:
        return clamp_func(A, mini, maxi, out0=A)
    if has_lower:
        return clamp_func_min(A, mini, out0=A)
    if has_upper:
        return clamp_func_max(A, maxi, out0=A)

    raise ValueError("At least one of mini or maxi must not be None")