Coverage for src/flag_gems/utils/models/sql.py: 55%

102 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1from hashlib import md5 

2from itertools import chain 

3from typing import ( 

4 Any, 

5 Callable, 

6 Dict, 

7 Final, 

8 Mapping, 

9 Optional, 

10 Sequence, 

11 Tuple, 

12 Type, 

13 Union, 

14) 

15 

16import sqlalchemy 

17import sqlalchemy.ext.automap 

18import sqlalchemy.orm 

19import triton 

20from typing_extensions import override 

21 

22from .model import PersistantModel 

23from .session import RollbackSession 

24 

25 

class Base(sqlalchemy.orm.DeclarativeBase):
    """Declarative base for the table classes built at runtime in this module."""

28 

29 

class SQLPersistantModel(PersistantModel):
    """Store and look up triton configs and benchmark timings in a SQL database.

    One table is used per ``(name, key column names)`` combination; the table
    name is ``"<name>-<md5 of the concatenated key column names>"``. Mapped
    classes are cached in ``sql_model_pool`` so each table is reflected or
    created at most once per instance.

    NOTE(review): ``self.lock`` and ``self.signature`` are used below but not
    defined in this class; presumably they come from ``PersistantModel`` --
    confirm against the base class.
    """

    def __init__(self, db_url: str, *args, **kwargs) -> None:
        """Initialise the parent class and create the engine for *db_url*.

        Args:
            db_url: Database URL handed to ``sqlalchemy.create_engine``.
            *args: Forwarded unchanged to the parent initialiser.
            **kwargs: Forwarded unchanged to the parent initialiser.
        """
        super().__init__(*args, **kwargs)
        self.engine: Final[sqlalchemy.engine.Engine] = sqlalchemy.create_engine(
            db_url
        )
        # Cache of mapped classes, keyed by the hashed table name.
        self.sql_model_pool: Dict[str, Type[Base]] = {}

    @staticmethod
    def build_sql_model_by_py(
        name: str,
        keys: Mapping[str, Union[Any, Type]],
        values: Optional[Mapping[str, Union[Any, Type]]] = None,
    ) -> Type[Base]:
        """Build a mapped class named *name* from Python types or sample values.

        Args:
            name: Used both as the class name and as ``__tablename__``.
            keys: Column name -> type (or sample value); these columns form
                the primary key.
            values: Column name -> type (or sample value); non-key columns.
                ``None`` means no value columns.

        Returns:
            A new subclass of ``Base`` describing the table.
        """
        value_columns: Mapping[str, Union[Any, Type]] = (
            {} if values is None else values
        )
        # Each entry may be a type or a sample value; a value is reduced to
        # its type so it can be used inside ``Mapped[...]``.
        annotations: Dict[str, Type] = {
            column: sqlalchemy.orm.Mapped[
                spec if isinstance(spec, type) else type(spec)
            ]
            for column, spec in chain(keys.items(), value_columns.items())
        }
        key_cols: Dict[str, sqlalchemy.orm.MappedColumn] = {
            column: sqlalchemy.orm.mapped_column(primary_key=True) for column in keys
        }
        value_cols: Dict[str, sqlalchemy.orm.MappedColumn] = {
            column: sqlalchemy.orm.mapped_column(primary_key=False)
            for column in value_columns
        }
        # Value columns are merged last, so a name present in both mappings
        # ends up as a non-key column (same precedence as before).
        namespace: Dict[str, Any] = {
            "__annotations__": annotations,
            "__tablename__": name,
            **(key_cols | value_cols),
        }
        model_cls: Type[Base] = type(name, (Base,), namespace)
        return model_cls

    @staticmethod
    def build_sql_model_by_db(
        name: str,
        engine: sqlalchemy.engine.Engine,
    ) -> Optional[Type[Base]]:
        """Reflect the database behind *engine* and return the class for *name*.

        Args:
            name: Name to look up among the automapped classes.
            engine: Engine whose schema is reflected via automap.

        Returns:
            The automapped class registered under *name*, or ``None`` when
            automap produced no class with that name.
        """
        auto_base: sqlalchemy.ext.automap.AutomapBase = (
            sqlalchemy.ext.automap.automap_base()
        )
        auto_base.prepare(engine)
        model_cls: Optional[Type[Base]] = auto_base.classes.get(name)
        return model_cls

    @staticmethod
    def get_key_dict(
        keys: Sequence[Union[bool, int, float, str]],
    ) -> Dict[str, Union[bool, int, float, str]]:
        """Name positional key values ``key_0``, ``key_1``, ... in order."""
        return {f"key_{i}": v for i, v in enumerate(keys)}

    @staticmethod
    def get_config_dict(
        config: triton.Config,
    ) -> Dict[str, Union[bool, int, float, str]]:
        """Return the scalar entries of ``config.all_kwargs()``.

        Only values that are ``int``, ``float`` or ``str`` are kept (``bool``
        passes the check because it is a subclass of ``int``); every other
        entry is dropped.
        """
        return {
            k: v
            for k, v in config.all_kwargs().items()
            if isinstance(v, (int, float, str))
        }

    def get_sql_model(
        self,
        name: str,
        keys: Optional[Mapping[str, Union[Any, Type]]] = None,
        values: Optional[Mapping[str, Union[Any, Type]]] = None,
    ) -> Optional[Type[Base]]:
        """Return the mapped class for *name* and *keys*, creating it if possible.

        Lookup order: the in-memory pool, then reflection of the existing
        database, then (only when both *keys* and *values* are non-empty)
        building the class from Python types and issuing ``CREATE TABLE``.

        Args:
            name: Logical table name; the real table name adds a hash suffix.
            keys: Primary-key column name -> type or sample value. Only the
                column names contribute to the hash suffix.
            values: Non-key column name -> type or sample value.

        Returns:
            The mapped class, or ``None`` when the table does not exist and
            there is not enough information (*keys* or *values* empty) to
            create it.
        """
        key_columns: Mapping[str, Union[Any, Type]] = {} if keys is None else keys
        value_columns: Mapping[str, Union[Any, Type]] = (
            {} if values is None else values
        )
        with self.lock:
            # md5 only derives a stable table-name suffix here, so it is
            # flagged as non-security use to keep working on FIPS-restricted
            # builds; the digest itself is unchanged.
            digest: str = md5(
                "".join(key_columns.keys()).encode(), usedforsecurity=False
            ).hexdigest()
            table_name: str = "{}-{}".format(name, digest)

            model_cls: Optional[Type[Base]] = self.sql_model_pool.get(table_name)
            if model_cls is not None:
                return model_cls

            model_cls = SQLPersistantModel.build_sql_model_by_db(
                table_name, self.engine
            )
            if model_cls is not None:
                self.sql_model_pool[table_name] = model_cls
                return model_cls

            # Without both key and value columns there is nothing to create.
            if not key_columns or not value_columns:
                return None

            model_cls = SQLPersistantModel.build_sql_model_by_py(
                table_name, key_columns, value_columns
            )
            with self.engine.begin() as conn:
                conn.execute(
                    sqlalchemy.schema.CreateTable(
                        model_cls.__table__, if_not_exists=True
                    )
                )
            self.sql_model_pool[table_name] = model_cls
            return model_cls

    @override
    def get_config(
        self, name: str, keys: Sequence[Union[bool, int, float, str]]
    ) -> Optional[triton.Config]:
        """Load the stored config for *name* and *keys*.

        Returns:
            A ``triton.Config`` rebuilt from the non-key columns of the row,
            or ``None`` when either the table or the row is missing.
        """
        key_dict: Dict[
            str, Union[bool, int, float, str]
        ] = SQLPersistantModel.get_key_dict(keys)
        config_cls: Optional[Type[Base]] = self.get_sql_model(name, key_dict)
        if config_cls is None:
            return None
        # NOTE(review): RollbackSession is defined elsewhere; presumably it is
        # a Session that rolls back on exit -- confirm.
        with RollbackSession(self.engine) as session:
            obj: Optional[Base] = session.get(config_cls, key_dict)
            if obj is None:
                return None
            # Everything except the primary-key columns belongs to the config.
            obj_dict: Dict[str, Union[bool, int, float, str]] = {
                col.key: getattr(obj, col.key)
                for col in sqlalchemy.inspect(obj).mapper.columns
                if col.key not in key_dict
            }
            # Columns named like a parameter of ``self.signature`` are passed
            # as keyword arguments; the rest form the positional kwargs dict.
            kwargs: Dict[str, Union[bool, int, float, str]] = {
                k: v for k, v in obj_dict.items() if k not in self.signature.parameters
            }
            config_dict: Dict[str, int] = {
                k: v for k, v in obj_dict.items() if k in self.signature.parameters
            }
            return triton.Config(kwargs, **config_dict)

    @override
    def get_benchmark(
        self,
        name: str,
        keys: Sequence[Union[bool, int, float, str]],
        config: triton.Config,
    ) -> Optional[Tuple[float, float, float]]:
        """Load the stored timings for *name*, *keys* and *config*.

        The row is addressed by the key columns plus the scalar entries of
        *config*.

        Returns:
            ``(p50, p20, p80)`` as stored in the row, or ``None`` when either
            the table or the row is missing.
        """
        key_dict: Dict[str, Union[bool, int, float, str]] = {
            **SQLPersistantModel.get_key_dict(keys),
            **SQLPersistantModel.get_config_dict(config),
        }
        benchmark_cls: Optional[Type[Base]] = self.get_sql_model(name, key_dict)
        if benchmark_cls is None:
            return None
        with RollbackSession(self.engine) as session:
            obj: Optional[Base] = session.get(benchmark_cls, key_dict)
            if obj is None:
                return None
            p50: float = obj.p50
            p20: float = obj.p20
            p80: float = obj.p80
            return (p50, p20, p80)

    def put_config(
        self,
        name: str,
        keys: Sequence[Union[bool, int, float, str]],
        config: Union[triton.Config, Dict[str, Union[bool, int, float, str]]],
    ) -> None:
        """Store *config* under *name* and *keys*.

        A ``triton.Config`` is first reduced with ``get_config_dict``; a dict
        is stored as given. The row is written with ``session.merge`` followed
        by ``commit``. Nothing is written when ``get_sql_model`` returns
        ``None`` (for example when the config dict is empty).
        """
        config_values: Dict[str, Union[bool, int, float, str]] = (
            SQLPersistantModel.get_config_dict(config)
            if isinstance(config, triton.Config)
            else config
        )
        key_dict: Dict[
            str, Union[bool, int, float, str]
        ] = SQLPersistantModel.get_key_dict(keys)
        config_cls: Optional[Type[Base]] = self.get_sql_model(
            name,
            {k: type(v) for k, v in key_dict.items()},
            {k: type(v) for k, v in config_values.items()},
        )
        if config_cls is None:
            return
        with RollbackSession(self.engine) as session:
            obj: Base = config_cls(**key_dict, **config_values)
            session.merge(obj)
            session.commit()

    def put_benchmark(
        self,
        name: str,
        keys: Sequence[Union[bool, int, float, str]],
        config: Union[triton.Config, Dict[str, Union[bool, int, float, str]]],
        benchmark: Tuple[float, float, float],
    ) -> None:
        """Store the timings *benchmark* for *name*, *keys* and *config*.

        Args:
            name: Logical table name.
            keys: Positional key values, stored as ``key_<i>`` columns.
            config: A ``triton.Config`` (reduced with ``get_config_dict``) or
                an already flattened dict; its entries join the primary key.
            benchmark: ``(p50, p20, p80)``, stored in columns of those names.

        The row is written with ``session.merge`` followed by ``commit``.
        Nothing is written when ``get_sql_model`` returns ``None``.
        """
        key_dict: Dict[
            str, Union[bool, int, float, str]
        ] = SQLPersistantModel.get_key_dict(keys)
        config_values: Dict[str, Union[bool, int, float, str]] = (
            SQLPersistantModel.get_config_dict(config)
            if isinstance(config, triton.Config)
            else config
        )
        p50, p20, p80 = benchmark
        timings: Dict[str, float] = {"p50": p50, "p20": p20, "p80": p80}
        benchmark_cls: Optional[Type[Base]] = self.get_sql_model(
            name,
            key_dict | config_values,
            timings,
        )
        if benchmark_cls is None:
            return
        with RollbackSession(self.engine) as session:
            obj: Base = benchmark_cls(**key_dict, **config_values, **timings)
            session.merge(obj)
            session.commit()