Coverage for src/flag_gems/utils/models/sql.py: 55%
102 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1from hashlib import md5
2from itertools import chain
3from typing import (
4 Any,
5 Callable,
6 Dict,
7 Final,
8 Mapping,
9 Optional,
10 Sequence,
11 Tuple,
12 Type,
13 Union,
14)
16import sqlalchemy
17import sqlalchemy.ext.automap
18import sqlalchemy.orm
19import triton
20from typing_extensions import override
22from .model import PersistantModel
23from .session import RollbackSession
class Base(sqlalchemy.orm.DeclarativeBase):
    """Declarative base shared by the ORM classes generated in this module."""
class SQLPersistantModel(PersistantModel):
    """Store tuning configs and benchmark timings in a SQL database.

    Every logical table is represented by a dynamically created SQLAlchemy
    ORM class.  Classes are cached in ``sql_model_pool`` under the physical
    table name, which is ``"<name>-<md5 of the concatenated key names>"``.
    """

    def __init__(self, db_url: str, *args, **kwargs) -> None:
        """Create the engine for ``db_url`` and start with an empty cache.

        All remaining arguments are forwarded to ``PersistantModel.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.engine: Final[sqlalchemy.engine.Engine] = sqlalchemy.create_engine(
            db_url
        )
        # Physical table name -> mapped ORM class.
        self.sql_model_pool: Dict[str, Type[Base]] = {}

    @staticmethod
    def build_sql_model_by_py(
        name: str,
        keys: Mapping[str, Union[Any, Type]],
        values: Optional[Mapping[str, Union[Any, Type]]] = None,
    ) -> Type[Base]:
        """Build an ORM class called ``name`` from Python-side column specs.

        Args:
            name: Used both as the class name and as ``__tablename__``.
            keys: Column name -> type (or sample value); these columns become
                the primary key.
            values: Column name -> type (or sample value); these columns are
                regular, non-key columns.  ``None`` means no value columns.

        Returns:
            A new subclass of ``Base`` describing the table.
        """
        values = {} if values is None else values
        # A sample value is accepted in place of a type; it is reduced to its
        # own type before being wrapped in ``Mapped[...]``.
        annotations: Dict[str, Type] = {
            column: sqlalchemy.orm.Mapped[
                spec if isinstance(spec, type) else type(spec)
            ]
            for column, spec in chain(keys.items(), values.items())
        }
        cols: Dict[str, sqlalchemy.orm.MappedColumn] = {
            column: sqlalchemy.orm.mapped_column(primary_key=True) for column in keys
        }
        # Same precedence as a dict union: a name present in both mappings
        # ends up as a non-key column.
        cols.update(
            {
                column: sqlalchemy.orm.mapped_column(primary_key=False)
                for column in values
            }
        )
        namespace: Dict[str, Any] = {
            "__annotations__": annotations,
            "__tablename__": name,
            **cols,
        }
        model_cls: Type[Base] = type(name, (Base,), namespace)
        return model_cls

    @staticmethod
    def build_sql_model_by_db(
        name: str,
        engine: sqlalchemy.engine.Engine,
    ) -> Optional[Type[Base]]:
        """Reflect the database behind ``engine`` and look up table ``name``.

        Returns:
            The automapped class for ``name``, or ``None`` when automap did
            not produce a class with that name.
        """
        auto_base: sqlalchemy.ext.automap.AutomapBase = (
            sqlalchemy.ext.automap.automap_base()
        )
        auto_base.prepare(engine)
        model_cls: Optional[Type[Base]] = auto_base.classes.get(name)
        return model_cls

    @staticmethod
    def get_key_dict(
        keys: Sequence[Union[bool, int, float, str]],
    ) -> Dict[str, Union[bool, int, float, str]]:
        """Name positional key values ``key_0``, ``key_1``, ... in order."""
        return {f"key_{index}": value for index, value in enumerate(keys)}

    @staticmethod
    def get_config_dict(
        config: triton.Config,
    ) -> Dict[str, Union[bool, int, float, str]]:
        """Return the entries of ``config.all_kwargs()`` that are scalars.

        Only ``int``, ``float`` and ``str`` values are kept (``bool`` passes
        the check because it is a subclass of ``int``); anything else is
        dropped.
        """
        return {
            option: value
            for option, value in config.all_kwargs().items()
            if isinstance(value, (int, float, str))
        }

    def get_sql_model(
        self,
        name: str,
        keys: Optional[Mapping[str, Union[Any, Type]]] = None,
        values: Optional[Mapping[str, Union[Any, Type]]] = None,
    ) -> Optional[Type[Base]]:
        """Return the ORM class for ``name``/``keys``, creating it if possible.

        Lookup order: the in-memory pool, then reflection of the database,
        then - only when both ``keys`` and ``values`` are non-empty - a new
        class built from them, whose table is created with ``IF NOT EXISTS``.

        Args:
            name: Logical table name.
            keys: Primary-key columns; only their names contribute to the
                physical table name.
            values: Non-key columns, needed only to create a missing table.

        Returns:
            The mapped class, or ``None`` when the table is unknown and there
            is not enough information to create it.
        """
        keys = {} if keys is None else keys
        values = {} if values is None else values
        # The naming scheme must stay stable: existing tables are found by it.
        digest: str = md5("".join(keys.keys()).encode()).hexdigest()
        table_name: str = "{}-{}".format(name, digest)
        # NOTE(review): ``self.lock`` is not defined in this class; presumably
        # it is provided by ``PersistantModel`` - confirm.
        with self.lock:
            model_cls: Optional[Type[Base]] = self.sql_model_pool.get(table_name)
            if model_cls is not None:
                return model_cls
            model_cls = SQLPersistantModel.build_sql_model_by_db(
                table_name, self.engine
            )
            if model_cls is not None:
                self.sql_model_pool[table_name] = model_cls
                return model_cls
            if not keys or not values:
                return None
            model_cls = SQLPersistantModel.build_sql_model_by_py(
                table_name, keys, values
            )
            with self.engine.begin() as conn:
                conn.execute(
                    sqlalchemy.schema.CreateTable(
                        model_cls.__table__, if_not_exists=True
                    )
                )
            self.sql_model_pool[table_name] = model_cls
            return model_cls

    @override
    def get_config(
        self, name: str, keys: Sequence[Union[bool, int, float, str]]
    ) -> Optional[triton.Config]:
        """Load the stored config for ``keys`` from table ``name``.

        Returns:
            A ``triton.Config`` rebuilt from the row, or ``None`` when either
            the table or the row does not exist.
        """
        key_dict: Dict[
            str, Union[bool, int, float, str]
        ] = SQLPersistantModel.get_key_dict(keys)
        # No ``values`` are passed, so a missing table is never created here.
        config_cls: Optional[Type[Base]] = self.get_sql_model(name, key_dict)
        if config_cls is None:
            return None
        with RollbackSession(self.engine) as session:
            row: Optional[Base] = session.get(config_cls, key_dict)
            if row is None:
                return None
            stored: Dict[str, Union[bool, int, float, str]] = {
                column.key: getattr(row, column.key)
                for column in sqlalchemy.inspect(row).mapper.columns
                if column.key not in key_dict
            }
        # Columns named like an entry of ``self.signature.parameters`` are
        # passed to ``triton.Config`` as keyword arguments; every other column
        # goes into the kwargs dict given as its first positional argument.
        # NOTE(review): ``self.signature`` is not defined in this class;
        # presumably inherited from ``PersistantModel`` - confirm.
        kwargs: Dict[str, Union[bool, int, float, str]] = {}
        config_dict: Dict[str, int] = {}
        for column_name, value in stored.items():
            if column_name in self.signature.parameters:
                config_dict[column_name] = value
            else:
                kwargs[column_name] = value
        return triton.Config(kwargs, **config_dict)

    @override
    def get_benchmark(
        self,
        name: str,
        keys: Sequence[Union[bool, int, float, str]],
        config: triton.Config,
    ) -> Optional[Tuple[float, float, float]]:
        """Load the stored timings for ``keys`` + ``config`` from ``name``.

        Returns:
            ``(p50, p20, p80)`` read from the row, or ``None`` when either the
            table or the row does not exist.
        """
        # The primary key is the positional keys plus the scalar config
        # options.
        key_dict: Dict[str, Union[bool, int, float, str]] = {
            **SQLPersistantModel.get_key_dict(keys),
            **SQLPersistantModel.get_config_dict(config),
        }
        benchmark_cls: Optional[Type[Base]] = self.get_sql_model(name, key_dict)
        if benchmark_cls is None:
            return None
        with RollbackSession(self.engine) as session:
            row: Optional[Base] = session.get(benchmark_cls, key_dict)
            if row is None:
                return None
            p50: float = row.p50
            p20: float = row.p20
            p80: float = row.p80
        return (p50, p20, p80)

    def put_config(
        self,
        name: str,
        keys: Sequence[Union[bool, int, float, str]],
        config: Union[triton.Config, Dict[str, Union[bool, int, float, str]]],
    ) -> None:
        """Insert or update the config stored for ``keys`` in table ``name``.

        ``config`` may be a ``triton.Config`` (reduced with
        ``get_config_dict``) or an already flattened dict.  Nothing is written
        when ``get_sql_model`` returns ``None``.
        """
        config_values: Dict[str, Union[bool, int, float, str]] = (
            SQLPersistantModel.get_config_dict(config)
            if isinstance(config, triton.Config)
            else config
        )
        key_dict: Dict[
            str, Union[bool, int, float, str]
        ] = SQLPersistantModel.get_key_dict(keys)
        config_cls: Optional[Type[Base]] = self.get_sql_model(
            name,
            {column: type(value) for column, value in key_dict.items()},
            {column: type(value) for column, value in config_values.items()},
        )
        if config_cls is None:
            return
        with RollbackSession(self.engine) as session:
            row: Base = config_cls(**key_dict, **config_values)
            # ``merge`` gives insert-or-update semantics on the primary key.
            session.merge(row)
            session.commit()

    def put_benchmark(
        self,
        name: str,
        keys: Sequence[Union[bool, int, float, str]],
        config: Union[triton.Config, Dict[str, Union[bool, int, float, str]]],
        benchmark: Tuple[float, float, float],
    ) -> None:
        """Insert or update the timings for ``keys`` + ``config`` in ``name``.

        ``benchmark`` is unpacked as ``(p50, p20, p80)`` and stored in columns
        of the same names.  Nothing is written when ``get_sql_model`` returns
        ``None``.
        """
        key_dict: Dict[
            str, Union[bool, int, float, str]
        ] = SQLPersistantModel.get_key_dict(keys)
        config_values: Dict[str, Union[bool, int, float, str]] = (
            SQLPersistantModel.get_config_dict(config)
            if isinstance(config, triton.Config)
            else config
        )
        p50, p20, p80 = benchmark
        timings: Dict[str, float] = {"p50": p50, "p20": p20, "p80": p80}
        # Sample values are passed instead of types; ``build_sql_model_by_py``
        # derives the column types from them.
        benchmark_cls: Optional[Type[Base]] = self.get_sql_model(
            name,
            key_dict | config_values,
            timings,
        )
        if benchmark_cls is None:
            return
        with RollbackSession(self.engine) as session:
            row: Base = benchmark_cls(**key_dict, **config_values, **timings)
            # ``merge`` gives insert-or-update semantics on the primary key.
            session.merge(row)
            session.commit()