Coverage for src/flag_gems/runtime/backend/_cambricon/ops/clamp.py: 0%
75 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
3import triton
4import triton.language as tl
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Child of the package-wide "flag_gems" logger; every clamp wrapper in this
# module emits one DEBUG record through it when it is dispatched.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@pointwise_dynamic(
    is_tensor=[True, True, True, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def clamp_func_tensor(x, mini, maxi, inplace):
    # Element-wise clamp of x into [mini, maxi] where both bounds are tensors.
    # The comparison is done after casting x to float32; a lower bound above
    # the upper bound yields maxi because the min() is applied last.
    # `inplace` is never read here. NOTE(review): presumably it only makes
    # in-place and out-of-place launches distinct — confirm.
    x_f32 = x.to(tl.float32)
    floored = tl.maximum(mini, x_f32)
    return tl.minimum(maxi, floored)
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_min_tensor(x, mini, inplace):
    # Lower-bound-only clamp: max(mini, x) with a tensor bound, computed after
    # casting x to float32. `inplace` is not read by the body.
    x_f32 = x.to(tl.float32)
    return tl.maximum(mini, x_f32)
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_max_tensor(x, maxi, inplace):
    # Upper-bound-only clamp: min(maxi, x) with a tensor bound, computed after
    # casting x to float32. `inplace` is not read by the body.
    x_f32 = x.to(tl.float32)
    return tl.minimum(maxi, x_f32)
def clamp_tensor(A, mini=None, maxi=None):
    """Clamp ``A`` element-wise against tensor bounds (out-of-place).

    Args:
        A: input tensor.
        mini: tensor lower bound, or ``None`` to leave values unbounded below.
        maxi: tensor upper bound, or ``None`` to leave values unbounded above.

    Returns:
        The result produced by the matching ``pointwise_dynamic`` kernel.

    Raises:
        ValueError: if both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS_CAMBRICON CLAMP TENSOR")
    has_lower = mini is not None
    has_upper = maxi is not None
    if not (has_lower or has_upper):
        raise ValueError("At least one of mini or maxi must not be None")
    # From here on at least one bound is present, so the checks below are
    # mutually exclusive.
    if not has_upper:
        return clamp_func_min_tensor(A, mini, False)
    if not has_lower:
        return clamp_func_max_tensor(A, maxi, False)
    return clamp_func_tensor(A, mini, maxi, False)
def clamp_tensor_(A, mini=None, maxi=None):
    """Clamp ``A`` element-wise against tensor bounds, writing into ``A``.

    Same dispatch as ``clamp_tensor``, but every kernel is called with
    ``out0=A`` so the result is stored in the input tensor.

    Args:
        A: tensor to clamp in place.
        mini: tensor lower bound, or ``None`` to leave values unbounded below.
        maxi: tensor upper bound, or ``None`` to leave values unbounded above.

    Returns:
        Whatever the kernel call returns (presumably ``A`` itself, given
        ``out0=A`` — confirm against ``pointwise_dynamic``).

    Raises:
        ValueError: if both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS_CAMBRICON CLAMP_ TENSOR")
    lower_missing = mini is None
    upper_missing = maxi is None
    if lower_missing and upper_missing:
        raise ValueError("At least one of mini or maxi must not be None")
    if upper_missing:
        return clamp_func_min_tensor(A, mini, True, out0=A)
    if lower_missing:
        return clamp_func_max_tensor(A, maxi, True, out0=A)
    return clamp_func_tensor(A, mini, maxi, True, out0=A)
@pointwise_dynamic(
    is_tensor=[True, False, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def clamp_func(x, mini, maxi, inplace):
    # Element-wise clamp of x into [mini, maxi] where both bounds are
    # non-tensor (scalar) arguments. x is cast to float32 first; min() is
    # applied last, so mini > maxi yields maxi.
    # `inplace` is not read by the body.
    x_f32 = x.to(tl.float32)
    floored = tl.maximum(mini, x_f32)
    return tl.minimum(maxi, floored)
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_min(x, mini, inplace):
    # Lower-bound-only clamp with a non-tensor bound: max(mini, x), computed
    # after casting x to float32. `inplace` is not read by the body.
    x_f32 = x.to(tl.float32)
    return tl.maximum(mini, x_f32)
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_max(x, maxi, inplace):
    # Upper-bound-only clamp with a non-tensor bound: min(maxi, x), computed
    # after casting x to float32. `inplace` is not read by the body.
    x_f32 = x.to(tl.float32)
    return tl.minimum(maxi, x_f32)
def clamp_min(A, mini):
    """Return ``A`` with every element raised to at least ``mini``.

    Args:
        A: input tensor.
        mini: non-tensor lower bound; must not be ``None``.

    Raises:
        ValueError: if ``mini`` is ``None``.
    """
    logger.debug("GEMS_CAMBRICON CLAMP MIN")
    if mini is not None:
        return clamp_func_min(A, mini, False)
    raise ValueError("Mini must not be None")
def clamp_min_(A, mini):
    """Raise every element of ``A`` to at least ``mini``, writing into ``A``.

    Args:
        A: tensor to clamp in place (passed as ``out0``).
        mini: non-tensor lower bound; must not be ``None``.

    Raises:
        ValueError: if ``mini`` is ``None``.
    """
    logger.debug("GEMS_CAMBRICON CLAMP_ MIN")
    if mini is not None:
        return clamp_func_min(A, mini, True, out0=A)
    raise ValueError("Mini must not be None")
def clamp(A, mini=None, maxi=None):
    """Clamp ``A`` element-wise against non-tensor bounds (out-of-place).

    Args:
        A: input tensor.
        mini: lower bound, or ``None`` to leave values unbounded below.
        maxi: upper bound, or ``None`` to leave values unbounded above.

    Returns:
        The result produced by the matching ``pointwise_dynamic`` kernel.

    Raises:
        ValueError: if both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS_CAMBRICON CLAMP")
    if mini is None and maxi is None:
        raise ValueError("At least one of mini or maxi must not be None")
    # At least one bound is set, so each check below isolates one case.
    if maxi is None:
        return clamp_func_min(A, mini, False)
    if mini is None:
        return clamp_func_max(A, maxi, False)
    return clamp_func(A, mini, maxi, False)
def clamp_(A, mini=None, maxi=None):
    """Clamp ``A`` element-wise against non-tensor bounds, writing into ``A``.

    Same dispatch as ``clamp``, but each kernel is called with ``out0=A``.

    Args:
        A: tensor to clamp in place.
        mini: lower bound, or ``None`` to leave values unbounded below.
        maxi: upper bound, or ``None`` to leave values unbounded above.

    Returns:
        Whatever the kernel call returns (presumably ``A`` itself, given
        ``out0=A`` — confirm against ``pointwise_dynamic``).

    Raises:
        ValueError: if both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS_CAMBRICON CLAMP_")
    lower_given = mini is not None
    upper_given = maxi is not None
    if not lower_given and not upper_given:
        raise ValueError("At least one of mini or maxi must not be None")
    if lower_given and upper_given:
        return clamp_func(A, mini, maxi, True, out0=A)
    if lower_given:
        return clamp_func_min(A, mini, True, out0=A)
    return clamp_func_max(A, maxi, True, out0=A)