CellModules
DatabaseManager.py
import sqlite3
import pandas as pd
import json
import re


class DatabaseManager:
    """
    Manager for interacting with a SQLite database. This class provides methods to
    add data, query data, and manage tables in the database.

    Parameters
    ----------
    - **db_filename** : str
        The filename of the SQLite database.

    Methods
    -------
    - **add_conditions** : Adds a DataFrame of conditions to the 'condition' table in the database.
    - **add_parameters** : Adds a list of parameters to the 'parameters' table in the database.
    - **add_data** : Adds a DataFrame of data to a specified table in the database.
    - **add_error** : Adds a DataFrame of error data to the 'error' table in the database.
    - **get** : Retrieves records from a table based on filter conditions.
    - **get_parameters** : Retrieves parameters based on filter conditions.
    - **get_conditions** : Retrieves records from the 'condition' table based on filter conditions.
    - **getByID** : Retrieves records from a table based on ID_PARAMETER and/or ID_CONDITION.
    - **getPareto** : Retrieves the Pareto front for given errors.
    - **getBestScore** : Retrieves the best scores for a given error.
    - **getBest** : Retrieves the best records from a table based on a given error.
    - **iterOn** : Iterates over a table, grouped by one or more columns.
    - **query** : Executes a raw SQL query and returns the result as a DataFrame.
    - **summarize** : Returns a summary of the database.

    Example
    -------
    ```python
    conditions = pd.DataFrame([
        [0.5, 1],
        [0.2, 2],
        [0.2, 5]
    ], columns=['info1', 'info2'])
    parameters_list = LHSIterator(param, 10)

    with DatabaseManager("example.db") as db:
        db.add_parameters(parameters_list)
        db.add_conditions(conditions)

        for states, other in MultiSimu(run_simu, parameters_list,
                                       conditions=conditions.values.tolist(),
                                       replicat=5, withTqdm=True, batch_size_level='param'):
            db.add_data('states', states.reset_index())
            db.add_data('otherData', other.reset_index())
            error = your_error_function(states)
            db.add_error(error)
    ```
    """

    def __init__(self, db_filename: str):
        self.db_filename = db_filename
        # Cache mapping table name -> set of column names, filled lazily by _get_table_columns().
        self._table_columns_cache = {}

    def open(self):
        """
        Opens a connection to the database.
        """
        self.conn = sqlite3.connect(self.db_filename)
        cursor = self.conn.cursor()
        # Write-Ahead Logging allows concurrent readers while writing.
        cursor.execute("PRAGMA journal_mode=WAL;")
        # NORMAL synchronisation is a good speed/safety trade-off when WAL is enabled.
        cursor.execute("PRAGMA synchronous = NORMAL;")
        # A negative cache_size is expressed in KiB, i.e. roughly a 50 MB page cache.
        cursor.execute("PRAGMA cache_size = -50000;")
        self._check_tables()
        return self

    def close(self):
        """
        Closes the connection to the database.
        """
        self.conn.close()

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def add_conditions(self, conditions_df: pd.DataFrame, id_column: str = None):
        """
        Adds a DataFrame of conditions to the 'condition' table in the database.

        Parameters
        ----------
        - **conditions_df** : pd.DataFrame
            The DataFrame containing the conditions to add.
        - **id_column** : str, optional
            The column to use as the ID_CONDITION. If not provided, the index will be used.
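
        Example
        -------
        A minimal sketch, assuming `db` is an open DatabaseManager; the column names
        below are illustrative, not part of the API:
        ```python
        conditions = pd.DataFrame(
            [[37.0, [0, 10, 20]],
             [25.0, [0, 5]]],
            columns=['temperature', 'timepoints'])  # list cells are stored as JSON strings

        db.add_conditions(conditions)                      # index becomes ID_CONDITION
        # db.add_conditions(conditions, id_column='myID')  # or promote an existing column
        ```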
        """
        # Work on a copy so the caller's DataFrame is not modified in place.
        conditions_df = conditions_df.copy()
        for i in range(conditions_df.shape[1]):
            # Serialize iterable cells (lists, tuples, ...) as JSON; leave plain strings untouched.
            if hasattr(conditions_df.iloc[0, i], '__iter__') and not isinstance(conditions_df.iloc[0, i], str):
                conditions_df.iloc[:, i] = conditions_df.iloc[:, i].apply(lambda x: json.dumps(x))
        if id_column is None:
            conditions_df = conditions_df.reset_index().rename(columns={'index': 'ID_CONDITION'})
        elif id_column != 'ID_CONDITION':
            conditions_df = conditions_df.rename(columns={id_column: 'ID_CONDITION'})

        with self.conn:
            conditions_df.to_sql('condition', self.conn, if_exists='fail', index=False)
        self._check_tables()
        self._create_index('condition', 'ID_CONDITION')

    def add_error(self, data_df: pd.DataFrame):
        """
        Adds a DataFrame of error data to the 'error' table in the database.

        Parameters
        ----------
        - **data_df** : pd.DataFrame
            The DataFrame containing the error data to add.
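
        Example
        -------
        A minimal sketch, assuming `db` is an open DatabaseManager; the error column
        name `rmse` is illustrative:
        ```python
        error = pd.DataFrame({'ID_PARAMETER': [0, 1],
                              'ID_CONDITION': [0, 0],
                              'rmse': [0.12, 0.34]})
        db.add_error(error)  # appended to the 'error' table, indexed on the ID columns
        ```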
        """
        already_exists = self.available_tables and 'error' in self.available_tables
        with self.conn:
            data_df.to_sql('error', self.conn, if_exists='append', index=False)
        if not already_exists:
            self._check_tables()
        if 'ID_PARAMETER' in data_df.columns:
            self._create_index('error', 'ID_PARAMETER')
        if 'ID_CONDITION' in data_df.columns:
            self._create_index('error', 'ID_CONDITION')

    def add_parameters(self, parameters_list: list):
        """
        Adds a list of parameters to the 'parameters' table in the database.

        Parameters
        ----------
        - **parameters_list** : list
            The list of parameters to add.
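
        Example
        -------
        A minimal sketch; the parameter names are illustrative. Nested dictionaries are
        flattened with a '$' separator (e.g. column 'growth$rate') and lists are stored
        as JSON strings:
        ```python
        parameters_list = [
            {'growth': {'rate': 0.1, 'lag': 2.0}, 'seed': [1, 2, 3]},
            {'growth': {'rate': 0.2, 'lag': 1.5}, 'seed': [4, 5, 6]},
        ]
        db.add_parameters(parameters_list)  # row index becomes ID_PARAMETER
        ```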
        """
        flattened_parameters = [self._flatten_dict(p) for p in parameters_list]
        parameters_df = pd.DataFrame(flattened_parameters)
        parameters_df = parameters_df.reset_index().rename(columns={'index': 'ID_PARAMETER'})
        with self.conn:
            parameters_df.to_sql('parameters', self.conn, if_exists='fail', index=False)
        self._check_tables()
        self._create_index('parameters', 'ID_PARAMETER')

    def add_data(self, table_name: str, data_df: pd.DataFrame):
        """
        Adds a DataFrame of data to a specified table in the database.

        Parameters
        ----------
        - **table_name** : str
            The name of the table to add data to.
        - **data_df** : pd.DataFrame
            The DataFrame containing the data to add.
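
        Example
        -------
        A minimal sketch, as in the class-level example; `states` stands for any
        simulation output DataFrame carrying ID_PARAMETER / ID_CONDITION columns:
        ```python
        db.add_data('states', states.reset_index())
        db.add_data('otherData', other.reset_index())
        ```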
        """
        already_exists = self.available_tables and table_name in self.available_tables
        with self.conn:
            data_df.to_sql(table_name, self.conn, if_exists='append', index=False)
        if not already_exists:
            self._check_tables()

        if 'ID_PARAMETER' in data_df.columns:
            self._create_index(table_name, 'ID_PARAMETER')
        if 'ID_CONDITION' in data_df.columns:
            self._create_index(table_name, 'ID_CONDITION')

    def _parse_json_conditions(self, filter_conditions):
        # Rewrite dotted paths such as "parameters.growth.rate" into the flattened
        # column reference "parameters.growth$rate" produced by _flatten_dict().
        json_path_regex = re.compile(r"(\w+)\.(\w+(\.\w+)*)")

        def replace_json_path(match):
            table_alias = match.group(1)
            json_path = match.group(2).replace('.', '$')
            return f"{table_alias}.{json_path}"

        converted_conditions = json_path_regex.sub(replace_json_path, filter_conditions)
        return converted_conditions

    def get_parameters(self, filter_conditions: str = None, raw: bool = False) -> pd.DataFrame:
        """
        Retrieves the parameters based on filter conditions.

        Parameters
        ----------
        - **filter_conditions** : str, optional
            The SQL filter conditions to apply.
        - **raw** : bool, optional
            If True, returns the raw DataFrame without JSON parsing.

        Returns
        -------
        - **pd.DataFrame** : The DataFrame containing the retrieved records.
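
        Example
        -------
        A minimal sketch; the nested parameter name `growth.rate` is illustrative:
        ```python
        best = db.get_parameters("parameters.growth.rate > 0.1")            # nested dicts rebuilt
        flat = db.get_parameters("parameters.growth.rate > 0.1", raw=True)  # flat '.'-separated columns
        ```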
        """
        return self.get('parameters', filter_conditions, rawParameters=raw)

    def get_conditions(self, filter_conditions: str = None) -> pd.DataFrame:
        """
        Retrieves records from the 'condition' table based on filter conditions.

        Parameters
        ----------
        - **filter_conditions** : str, optional
            The SQL filter conditions to apply.

        Returns
        -------
        - **pd.DataFrame** : The DataFrame containing the retrieved records.
        """
        return self.get('condition', filter_conditions)

    def _build_clauses(self, table: str, filter_conditions: str = None, alias_prefix: str = 'main_table'):
        """
        Build JOIN and WHERE clauses based on filter_conditions, similar to get().
        Returns: join_sql, where_sql, alias_dict
        """
        alias_dict = {table: alias_prefix}
        alias_count = 0
        join_clauses = []
        referenced = set()

        if filter_conditions:
            cond = self._parse_json_conditions(filter_conditions)
            for tbl in re.findall(r'(\w+)\.', cond):
                referenced.add(tbl)
            filter_conditions = cond

        main_cols = self._get_table_columns(table)
        # Build a LEFT JOIN for each table referenced in the filter, linked on the
        # first shared ID column (ID_PARAMETER first, then ID_CONDITION).
        for ref in referenced:
            if ref in self.available_tables and ref != table:
                alias = f"t_{ref}_{alias_count}"
                alias_count += 1
                alias_dict[ref] = alias
                ref_cols = self._get_table_columns(ref)
                if 'ID_PARAMETER' in main_cols and 'ID_PARAMETER' in ref_cols:
                    join_clauses.append(f"LEFT JOIN {ref} AS {alias} ON {alias_prefix}.ID_PARAMETER = {alias}.ID_PARAMETER")
                elif 'ID_CONDITION' in main_cols and 'ID_CONDITION' in ref_cols:
                    join_clauses.append(f"LEFT JOIN {ref} AS {alias} ON {alias_prefix}.ID_CONDITION = {alias}.ID_CONDITION")
                elif 'ID_PARAMETER' in ref_cols:
                    join_clauses.append(f"LEFT JOIN {ref} AS {alias} ON {alias}.ID_PARAMETER = {alias_prefix}.ID_PARAMETER")
                elif 'ID_CONDITION' in ref_cols:
                    join_clauses.append(f"LEFT JOIN {ref} AS {alias} ON {alias}.ID_CONDITION = {alias_prefix}.ID_CONDITION")

        # Replace table names in the filter with their aliases.
        where_sql = ''
        if filter_conditions:
            for tbl, al in alias_dict.items():
                filter_conditions = re.sub(rf"\b{tbl}\.", f"{al}.", filter_conditions)
            where_sql = f"WHERE {filter_conditions}"

        join_sql = ' '.join(join_clauses)
        return join_sql, where_sql, alias_dict

    def get(self, table: str, filter_conditions: str = None, rawParameters: bool = False) -> pd.DataFrame:
        """
        Retrieves records from a table based on filter conditions.

        Parameters
        ----------
        - **table** : str
            The name of the table to retrieve records from.
        - **filter_conditions** : str, optional
            The SQL filter conditions to apply.
        - **rawParameters** : bool, optional
            If True, returns the raw DataFrame without JSON parsing.

        Returns
        -------
        - **pd.DataFrame** : The DataFrame containing the retrieved records.
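
        Example
        -------
        A minimal sketch; the table and column names are illustrative. The filter may
        reference other tables, which are joined automatically on the shared ID columns:
        ```python
        df = db.get('states', "states.time > 10")
        df = db.get('states', "error.rmse < 0.2 AND condition.info2 = 5")
        ```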
        """
        if table not in self.available_tables:
            raise ValueError(f"Table '{table}' does not exist.")
        join_sql, where_sql, _ = self._build_clauses(table, filter_conditions)
        query = f"SELECT DISTINCT main_table.* FROM {table} AS main_table {join_sql} {where_sql}".strip()
        df = pd.read_sql(query, self.conn)

        if table == 'condition':
            for col in df.columns:
                if df[col].dtype == object and df[col].str.startswith(('[', '{')).any():
                    df[col] = df[col].apply(lambda x: json.loads(x) if isinstance(x, str) else x)
        if table == 'parameters' and not rawParameters:
            return pd.DataFrame([self._unflat(row.to_dict()) for _, row in df.iterrows()])
        elif table == 'parameters' and rawParameters:
            return df.rename(columns=lambda x: x.replace('$', '.'))
        return df

    def iterOn(self, table: str, group, filter_conditions: str = None):
        """
        Iterates over a table, yielding one (key, DataFrame) pair per distinct value
        (or tuple of values) of the column(s) given in `group`. The optional
        `filter_conditions` restrict which keys are iterated.
        """
        # normalize group
        group_cols = (group,) if isinstance(group, str) else tuple(group)
        join_sql, where_sql, _ = self._build_clauses(table, filter_conditions)
        # distinct keys
        cols = ', '.join(f"main_table.\"{c}\"" for c in group_cols)
        distinct_sql = f"SELECT DISTINCT {cols} FROM {table} AS main_table {join_sql} {where_sql}".strip()

        all_cols = list(self._get_table_columns(table))
        cols_sql = ', '.join([f'"{c}"' for c in all_cols if c not in group_cols])
        where_clauses = ' AND '.join(f'{table}."{c}" = ?' for c in group_cols)
        iter_query = f"SELECT {cols_sql} FROM {table} WHERE {where_clauses}"

        for key in self.conn.execute(distinct_sql):
            df_part = pd.read_sql(iter_query, self.conn, params=list(key))
            yield key, df_part

    def getBestScore(self, errorName: str, n: int = 1, forEachCondition: bool = False) -> pd.DataFrame:
        """
        Retrieves the best scores for a given error.

        Parameters
        ----------
        - **errorName** : str
            The name of the error column to base the scores on.
        - **n** : int, optional
            The number of top scores to retrieve.
        - **forEachCondition** : bool, optional
            Whether to retrieve the best scores for each condition separately.

        Returns
        -------
        - **pd.DataFrame** : The DataFrame containing the best scores.
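
        Example
        -------
        A minimal sketch; the error column name `rmse` is illustrative:
        ```python
        db.getBestScore('rmse', n=3)                         # 3 lowest errors averaged over conditions
        db.getBestScore('rmse', n=1, forEachCondition=True)  # best parameter set per condition
        ```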
        """
        if forEachCondition:
            query = f'''WITH ranked_errors AS (
                            SELECT
                                ID_PARAMETER,
                                ID_CONDITION,
                                "{errorName}",
                                ROW_NUMBER() OVER (PARTITION BY ID_CONDITION ORDER BY "{errorName}" ASC, ID_PARAMETER ASC) AS rn
                            FROM error)
                        SELECT
                            ID_PARAMETER,
                            ID_CONDITION,
                            "{errorName}"
                        FROM
                            ranked_errors
                        WHERE
                            rn <= {n};'''
        else:
            query = f'''SELECT ID_PARAMETER, AVG("{errorName}") AS avg_error
                        FROM error
                        GROUP BY ID_PARAMETER
                        ORDER BY avg_error ASC
                        LIMIT {n};'''
        return pd.read_sql(query, self.conn)

    def getBest(self, table: str, errorName: str, n: int = 1, forEachCondition: bool = False) -> pd.DataFrame:
        """
        Retrieves the best records from a table based on a given error.

        Parameters
        ----------
        - **table** : str
            The name of the table to retrieve records from.
        - **errorName** : str
            The name of the error column to base the records on.
        - **n** : int, optional
            The number of top records to retrieve.
        - **forEachCondition** : bool, optional
            Whether to retrieve the best records for each condition separately.

        Returns
        -------
        - **pd.DataFrame** : The DataFrame containing the best records.
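
        Example
        -------
        A minimal sketch; `rmse` is an illustrative error column:
        ```python
        best_states = db.getBest('states', 'rmse', n=5)  # rows for the 5 parameter sets with lowest total error
        best_params = db.getBest('parameters', 'rmse')   # the single best parameter set
        ```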
        """
        if forEachCondition:
            table_columns = self._get_table_columns(table)
            query = f'''WITH ranked_errors AS (
                            SELECT
                                ID_PARAMETER,
                                ID_CONDITION,
                                "{errorName}",
                                ROW_NUMBER() OVER (PARTITION BY ID_CONDITION ORDER BY "{errorName}" ASC) AS rn
                            FROM
                                error
                        )
                        SELECT *
                        FROM {table} t
                        WHERE {'(t.ID_PARAMETER, t.ID_CONDITION)' if 'ID_CONDITION' in table_columns else 't.ID_PARAMETER'} IN (
                            SELECT {'ID_PARAMETER, ID_CONDITION' if 'ID_CONDITION' in table_columns else 'ID_PARAMETER'}
                            FROM ranked_errors
                            WHERE rn <= {n}
                        );'''
        else:
            query = f'''SELECT *
                        FROM {table}
                        WHERE ID_PARAMETER IN (
                            SELECT ID_PARAMETER
                            FROM error
                            GROUP BY ID_PARAMETER
                            ORDER BY SUM("{errorName}") ASC
                            LIMIT {n}
                        );'''
        return pd.read_sql(query, self.conn)

    def getByID(self, table: str, ID_PARAMETER: int = None, ID_CONDITION: int = None) -> pd.DataFrame:
        """
        Retrieves records from a table based on ID_PARAMETER and/or ID_CONDITION.

        Parameters
        ----------
        - **table** : str
            The name of the table to retrieve records from.
        - **ID_PARAMETER** : int or list of int, optional
            The ID_PARAMETER to filter by.
        - **ID_CONDITION** : int or list of int, optional
            The ID_CONDITION to filter by.

        Returns
        -------
        - **pd.DataFrame** : The DataFrame containing the retrieved records.
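
        Example
        -------
        A minimal sketch:
        ```python
        db.getByID('states', ID_PARAMETER=3)
        db.getByID('error', ID_PARAMETER=[0, 1], ID_CONDITION=2)
        ```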
        """
        whereP = ''
        whereC = ''
        if isinstance(ID_PARAMETER, list):
            whereP = f'ID_PARAMETER IN ({",".join(map(str, ID_PARAMETER))})'
        elif isinstance(ID_PARAMETER, int):
            whereP = f'ID_PARAMETER = {ID_PARAMETER}'
        elif ID_PARAMETER is not None:
            raise ValueError("ID_PARAMETER must be an integer or a list of integers.")
        if isinstance(ID_CONDITION, list):
            whereC = f'ID_CONDITION IN ({",".join(map(str, ID_CONDITION))})'
        elif isinstance(ID_CONDITION, int):
            whereC = f'ID_CONDITION = {ID_CONDITION}'
        elif ID_CONDITION is not None:
            raise ValueError("ID_CONDITION must be an integer or a list of integers.")

        where_clauses = [clause for clause in [whereP, whereC] if clause]
        # Omit the WHERE clause entirely when no ID filter was given.
        where = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ''

        query = f'''SELECT *
                    FROM {table}
                    {where};'''
        return pd.read_sql(query, self.conn)

    def getPareto(self, errorNameList: list, forEachCondition: bool = False) -> pd.DataFrame:
        """
        Retrieves the Pareto front for given errors.

        Parameters
        ----------
        - **errorNameList** : list
            The list of error column names to consider.
        - **forEachCondition** : bool, optional
            Whether to retrieve the Pareto front for each condition separately.

        Returns
        -------
        - **pd.DataFrame** : The DataFrame containing the Pareto front.
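
        Example
        -------
        A minimal sketch; `rmse` and `bias` are illustrative error columns:
        ```python
        front = db.getPareto(['rmse', 'bias'])
        front_per_condition = db.getPareto(['rmse', 'bias'], forEachCondition=True)
        ```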
        """
        if forEachCondition:
            query = f'''WITH avg_errors AS (
                            SELECT
                                e.ID_PARAMETER,
                                e.ID_CONDITION,
                                {', '.join(['AVG(e."' + en + '") AS ' + en for en in errorNameList])}
                            FROM
                                error e
                            GROUP BY
                                e.ID_PARAMETER, e.ID_CONDITION
                        )
                        SELECT t1.*
                        FROM avg_errors t1
                        LEFT JOIN avg_errors t2
                            ON t1.ID_CONDITION = t2.ID_CONDITION
                            AND t1.ID_PARAMETER != t2.ID_PARAMETER
                            AND {' AND '.join(['t1."' + en + '" >= t2."' + en + '"' for en in errorNameList])}
                            AND ({' OR '.join(['t1."' + en + '" > t2."' + en + '"' for en in errorNameList])})
                        WHERE t2.ID_PARAMETER IS NULL;'''
        else:
            query = f'''WITH avg_errors AS (
                            SELECT
                                e.ID_PARAMETER,
                                {', '.join(['AVG(e."' + en + '") AS ' + en for en in errorNameList])}
                            FROM
                                error e
                            GROUP BY
                                e.ID_PARAMETER
                        )
                        SELECT t1.*
                        FROM avg_errors t1
                        LEFT JOIN avg_errors t2
                            ON t1.ID_PARAMETER != t2.ID_PARAMETER
                            AND {' AND '.join(['t1."' + en + '" >= t2."' + en + '"' for en in errorNameList])}
                            AND ({' OR '.join(['t1."' + en + '" > t2."' + en + '"' for en in errorNameList])})
                        WHERE t2.ID_PARAMETER IS NULL;'''

        return pd.read_sql(query, self.conn)

    def query(self, query: str) -> pd.DataFrame:
        """
        Executes a raw SQL query and returns the result as a DataFrame.

        Parameters
        ----------
        - **query** : str
            The SQL query to execute.

        Returns
        -------
        - **pd.DataFrame** : The DataFrame containing the query results.
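
        Example
        -------
        A minimal sketch:
        ```python
        counts = db.query("SELECT ID_PARAMETER, COUNT(*) AS n FROM error GROUP BY ID_PARAMETER")
        ```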
        """
        df = pd.read_sql(query, self.conn)
        return df

    def __getitem__(self, table: str) -> pd.DataFrame:
        """
        Returns the full contents of a table as a DataFrame, e.g. `db['error']`.
        """
        df = pd.read_sql(f"SELECT * FROM {table}", self.conn)
        df.columns = [c.replace('$', '.') for c in df.columns.values]
        return df

    def summarize(self) -> str:
        """
        Returns a summary of the database.

        Returns
        -------
        - **str** : The summary of the database.
        """
        summary = []
        summary.append("Database Summary:")
        for table in self.available_tables:
            df = pd.read_sql(f"SELECT * FROM {table} LIMIT 5", self.conn)
            count_query = f"SELECT COUNT(*) FROM {table}"
            count = self.conn.execute(count_query).fetchone()[0]
            summary.append(f"\nTable: {table}")
            summary.append(f"Number of rows: {count}")
            summary.append("Preview:")
            summary.append(df.head().to_string())
        return "\n".join(summary)

    def __str__(self) -> str:
        return self.summarize()

    def _check_tables(self):
        query = "SELECT name FROM sqlite_master WHERE type='table';"
        tables = self.conn.execute(query).fetchall()
        self.available_tables = {table[0] for table in tables}

    def _get_table_columns(self, table_name):
        if table_name in self._table_columns_cache:
            return self._table_columns_cache[table_name]

        # Otherwise, query SQLite and cache the result.
        query = f"PRAGMA table_info({table_name})"
        cols = {col[1] for col in self.conn.execute(query).fetchall()}
        self._table_columns_cache[table_name] = cols
        return cols

    def _create_index(self, table_name, column_name):
        try:
            query = f"CREATE INDEX IF NOT EXISTS idx_{table_name}_{column_name} ON {table_name} ({column_name});"
            self.conn.execute(query)
        except sqlite3.Error as e:
            print(f"Failed to create index on {table_name}({column_name}): {e}")

    def _flatten_dict(self, d, parent_key='', sep='$'):
        items = []
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                items.extend(self._flatten_dict(v, new_key, sep=sep).items())
            elif isinstance(v, list):
                items.append((new_key, json.dumps(v)))  # Convert lists to JSON strings
            else:
                items.append((new_key, v))
        return dict(items)

    def _unflat(self, d, sep='$'):
        result_dict = {}
        for key, value in d.items():
            parts = key.split(sep)
            node = result_dict
            for part in parts[:-1]:
                if part not in node:
                    node[part] = {}
                node = node[part]
            # Convert back JSON strings to lists if needed
            try:
                if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
                    value = json.loads(value)
            except (TypeError, json.JSONDecodeError):
                pass
            node[parts[-1]] = value
        return result_dict