11from contextlib
import contextmanager
12from typing
import Any, Optional, Iterable
13from urllib.parse
import unquote
14import matplotlib.pyplot
as plt
19from .LHSIterator
import LHSIterator
27 data = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
28 return base64.b64encode(data).decode(
"ascii")
32 data = base64.b64decode(payload.encode(
"ascii"))
33 return pickle.loads(data)
39 Return the first 'user' frame (
not this lib,
not stdlib).
40 Works better
with contextmanager wrappers (contextlib).
42 lib_file = os.path.abspath(__file__) if "__file__" in globals()
else None
44 stdlib_dir = os.path.abspath(sysconfig.get_paths()[
"stdlib"])
45 purelib_dir = os.path.abspath(sysconfig.get_paths().
get(
"purelib",
""))
46 platlib_dir = os.path.abspath(sysconfig.get_paths().
get(
"platlib",
""))
48 def is_internal_frame(fr) -> bool:
49 filename = fr.f_code.co_filename
52 if not os.path.isabs(filename):
56 filename = os.path.abspath(filename)
59 if lib_file
is not None and filename == lib_file:
63 if filename.startswith(stdlib_dir + os.sep):
67 if purelib_dir
and filename.startswith(purelib_dir + os.sep):
69 if platlib_dir
and filename.startswith(platlib_dir + os.sep):
74 fr = inspect.currentframe()
81 if not is_internal_frame(fr):
95 return fr.f_globals, fr.f_locals
100 Robust-ish getsource for normal .py + custom notebooks
if linecache
is populated.
103 return textwrap.dedent(inspect.getsource(fn))
105 code = getattr(fn,
"__code__",
None)
109 filename = code.co_filename
110 lines = linecache.getlines(filename)
115 tree = ast.parse(src)
118 target_name = fn.__name__
119 target_lineno = getattr(code,
"co_firstlineno",
None)
121 for node
in tree.body:
122 if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
and node.name == target_name:
123 if target_lineno
is None or node.lineno == target_lineno:
124 if hasattr(node,
"end_lineno")
and node.end_lineno
is not None:
125 return textwrap.dedent(
"".join(lines[node.lineno - 1: node.end_lineno]))
126 return textwrap.dedent(
"".join(lines[node.lineno - 1:]))
137 Lightweight helper to standardize what is stored
in Optuna study/trial user_attrs.
140 - Typed helpers
for trial outputs (timeseries)
141 - Explicit function registration (suggest/run/eval)
142 - Optional context capture via `
with archiver.register(): ...`
143 to archive helper functions + runtime objects automatically
144 into study_user_attributes[
'backupContext'].
147 KEY_ENTRY_SUGGEST = "entrypoint:suggest"
148 KEY_ENTRY_RUN =
"entrypoint:run"
149 KEY_ENTRY_EVAL =
"entrypoint:eval"
150 KEY_CONTEXT =
"backupContext"
153 CONTEXT_ID =
"__default__"
157 self.backupContext: dict[str, dict[str, Any]] = {}
160 self._entrypoints_cache: dict[str, str] = {}
164 self.
_con_custom: Optional[sqlite3.Connection] =
None
166 if study
is not None:
192 "No custom DB connection yet. Call set_study(study) first."
197 if self.
study is None:
198 raise RuntimeError(
"OptunaArchive.study is None. Attach a study first (OptunaArchive(study) or set_study()).")
206 CREATE TABLE IF NOT EXISTS isicell_context (
207 context_id TEXT NOT NULL,
210 payload TEXT NOT NULL,
211 PRIMARY KEY (context_id, key)
216 CREATE TABLE IF NOT EXISTS isicell_context_entrypoints (
217 context_id TEXT NOT NULL,
218 kind TEXT NOT NULL, -- suggest | run | eval
220 PRIMARY KEY (context_id, kind)
228 Flush in-memory backupContext to custom table.
229 Upsert behavior (safe
if called multiple times).
231 if self.
_con_custom is None or not self.backupContext:
237 for name, entry
in self.backupContext.items():
238 mode = entry.get(
"mode")
239 payload = entry.get(
"payload")
243 payload_txt = json.dumps(payload)
246 payload_txt = str(payload)
249 INSERT INTO isicell_context (context_id, key, mode, payload)
251 ON CONFLICT(context_id, key) DO UPDATE SET
253 payload=excluded.payload;
254 """, (self.CONTEXT_ID, name, mode, payload_txt))
260 Flush cached entrypoints to custom table.
262 if self.
_con_custom is None or not self._entrypoints_cache:
268 for kind, name
in self._entrypoints_cache.items():
270 INSERT INTO isicell_context_entrypoints (context_id, kind, name)
272 ON CONFLICT(context_id, kind) DO UPDATE SET
274 """, (self.CONTEXT_ID, kind, name))
283 if isinstance(x, np.ndarray):
285 if isinstance(x, (np.integer, np.floating)):
287 if isinstance(x, pd.Series):
289 "__kind__":
"series",
293 if isinstance(x, pd.DataFrame):
295 "__kind__":
"dataframe",
297 "value": x.to_dict(
"split")
299 if isinstance(x, dict):
300 return {k: OptunaArchive._to_jsonable(v)
for k, v
in x.items()}
301 if isinstance(x, (list, tuple)):
302 return [OptunaArchive._to_jsonable(v)
for v
in x]
308 Context encoding policy:
309 - functions/classes/callables -> source code (mode='source')
if possible
310 - JSON-serializable (after _to_jsonable) -> mode=
'json'
311 - fallback -> pickle64b
315 return {
"mode":
"source",
"payload": src}
320 return {
"mode":
"json",
"payload": j}
330 Capture newly created symbols in caller scope
and store them
in backupContext.
335 with archiver.register():
337 def dataCatcherOnStep(...): ...
338 def run_simu(...): ...
339 def optim(trial): ...
343 before_g = set(g0.keys())
344 before_l = set(l0.keys())
350 new_names = [k
for k
in l1.keys()
if k
not in before_l]
351 new_names += [k
for k
in g1.keys()
if k
not in before_g
and k
not in new_names]
353 for name
in new_names:
354 if name.startswith(
"__"):
356 value = l1[name]
if name
in l1
else g1[name]
359 except Exception
as e:
360 raise RuntimeError(f
"Failed to archive symbol '{name}'.\nerror: {e}")
from e
373 for k, v
in mapping.items():
378 Extract sqlite file path from an Optuna study backed by RDBStorage(sqlite),
379 even
if wrapped by _CachedStorage.
381 storage = getattr(study, "_storage",
None)
383 raise RuntimeError(
"Study has no _storage attribute.")
387 while storage
is not None and id(storage)
not in visited:
388 visited.add(id(storage))
390 engine = getattr(storage,
"engine",
None)
391 if engine
is not None:
392 url = str(engine.url)
393 if not url.startswith(
"sqlite:///"):
394 raise RuntimeError(f
"Only sqlite storage is supported. Got storage URL: {url}")
396 path = url[len(
"sqlite:///"):]
399 if not os.path.isabs(path):
400 path = os.path.abspath(path)
404 if hasattr(storage,
"_backend"):
405 storage = storage._backend
407 if hasattr(storage,
"_storage"):
408 storage = storage._storage
414 "Could not access underlying RDBStorage engine from study._storage "
415 "(possibly an unsupported Optuna storage wrapper/version)."
421 - vérifie que le symbole existe dans backupContext et est source-backed
422 - stocke en cache + flush si possible
424 if name
not in self.backupContext:
426 f
"'{name}' not found in backupContext. "
427 "Define it inside `with archiver.register(): ...` first."
430 entry = self.backupContext[name]
431 if entry.get(
"mode") !=
"source":
433 f
"Entrypoint '{name}' must be a source-backed symbol, "
434 f
"got mode='{entry.get('mode')}'."
445 raise KeyError(f
"Unknown entrypoint key: {key}")
447 self._entrypoints_cache[kind] = name
468 trial.set_user_attr(key, OptunaArchive._to_jsonable(value))
486 columns: Optional[list[str]] =
None,
487 value_col: str =
"value",
488 layout: str =
"aligned",
491 Store trial-level flattened values and a study-level spec.
493 layout=
'product' : cartesian product of axes
494 layout=
'aligned' : zipped/aligned axes (same length
as values)
498 if not isinstance(axes, (list, tuple))
or len(axes) < 1:
499 raise ValueError(
"axes must be a non-empty list of axis names.")
502 if len(columns) != len(axes):
503 raise ValueError(
"columns and axes must have the same length.")
504 if layout
not in (
"product",
"aligned"):
505 raise ValueError(
"layout must be 'product' or 'aligned'.")
508 if f
"axis:{ax}" not in self.
study.user_attrs:
509 raise KeyError(f
"Missing study axis 'axis:{ax}'. Call study_add_axis('{ax}', ...) first.")
511 spec_key = f
"timeseries_spec:{key}"
514 "columns": list(columns),
515 "value_col": value_col,
519 existing = self.
study.user_attrs.get(spec_key,
None)
521 self.
study.set_user_attr(spec_key, spec)
522 elif existing != spec:
523 raise ValueError(f
"Incompatible timeseries spec for key '{key}'. Existing={existing}, New={spec}")
534 if "suggest" not in self._entrypoints_cache:
535 missing.append(
"suggest entrypoint")
536 if require_run_function
and "run" not in self._entrypoints_cache:
537 missing.append(
"run entrypoint")
541 f
"Missing required archive entrypoints: {missing}. "
542 "Use set_suggest_entrypoint/set_run_entrypoint before optimize()."
548 rows = cur.execute(
"""
550 FROM isicell_context_entrypoints
551 WHERE context_id = ?;
552 """, (self.CONTEXT_ID,)).fetchall()
553 kinds_in_db = {k for k, _
in rows}
555 if "suggest" not in kinds_in_db:
556 raise RuntimeError(
"Missing 'suggest' entrypoint in isicell_context_entrypoints.")
557 if require_run_function
and "run" not in kinds_in_db:
558 raise RuntimeError(
"Missing 'run' entrypoint in isicell_context_entrypoints.")
567 Read-only Optuna SQLite helper (fast SQL) + archived function loading (exec) + replay-ready helpers.
569 Context and entrypoints are read
from custom tables:
571 - isicell_context_entrypoints
573 Study-specific metadata (axes, timeseries specs, etc.) remains
in study_user_attributes.
576 CONTEXT_ID = "__default__"
578 def __init__(self, db_path: str, readonly: bool =
True):
581 self.
con = sqlite3.connect(f
"file:{db_path}?mode=ro", uri=
True)
583 self.
con = sqlite3.connect(db_path)
604 rows = self.
con.execute(
"SELECT study_name FROM studies ORDER BY study_name;").fetchall()
605 return [r[0]
for r
in rows]
608 row = self.
con.execute(
610 SELECT studies.study_name
612 INNER JOIN studies ON studies.study_id = trials.study_id
613 WHERE trials.trial_id = ?;
618 raise KeyError(f
"trial_id {trial_id} not found")
622 rows = self.
con.execute(
624 SELECT key, value_json
625 FROM study_user_attributes
626 INNER JOIN studies ON studies.study_id = study_user_attributes.study_id
627 WHERE studies.study_name = ?;
631 return {k: json.loads(v)
for k, v
in rows}
633 def trial_attrs(self, study_name: str, key: Optional[str] =
None) -> pd.DataFrame:
635 df = pd.read_sql_query(
637 SELECT studies.study_name, trial_user_attributes.trial_id, trial_user_attributes.key, trial_user_attributes.value_json
638 FROM trial_user_attributes
639 INNER JOIN trials ON trial_user_attributes.trial_id = trials.trial_id
640 INNER JOIN studies ON studies.study_id = trials.study_id
641 WHERE trials.state="COMPLETE" AND studies.study_name = ?;
644 params=(study_name,),
647 df = pd.read_sql_query(
649 SELECT studies.study_name, trial_user_attributes.trial_id, trial_user_attributes.key, trial_user_attributes.value_json
650 FROM trial_user_attributes
651 INNER JOIN trials ON trial_user_attributes.trial_id = trials.trial_id
652 INNER JOIN studies ON studies.study_id = trials.study_id
653 WHERE trials.state="COMPLETE" AND studies.study_name = ? AND trial_user_attributes.key = ?;
656 params=(study_name, key),
660 return pd.DataFrame(columns=[
"study_name",
"trial_id",
"key",
"value"])
662 df = df.rename(columns={
"value_json":
"value"})
663 df[
"value"] = df[
"value"].map(json.loads)
669 def fitness_long(self, study_name: Optional[str] =
None) -> pd.DataFrame:
670 if study_name
is None:
671 return pd.read_sql_query(
673 SELECT studies.study_name, trial_values.trial_id, trial_values.objective, trial_values.value
675 INNER JOIN trials ON trial_values.trial_id = trials.trial_id
676 INNER JOIN studies ON studies.study_id = trials.study_id
677 WHERE trials.state="COMPLETE";
682 return pd.read_sql_query(
684 SELECT studies.study_name, trial_values.trial_id, trial_values.objective, trial_values.value
686 INNER JOIN trials ON trial_values.trial_id = trials.trial_id
687 INNER JOIN studies ON studies.study_id = trials.study_id
688 WHERE trials.state="COMPLETE" AND studies.study_name = ?;
691 params=(study_name,),
694 def fitness_wide(self, study_name: Optional[str] =
None) -> pd.DataFrame:
699 wide = res.pivot(index=[
"study_name",
"trial_id"], columns=
"objective", values=
"value").reset_index()
700 wide[
"trial_num"] = wide.groupby(
"study_name")[
"trial_id"].transform(
lambda s: s - s.min())
702 if 0
in wide.columns:
703 wide = wide.sort_values([
"study_name",
"trial_num"]).reset_index(drop=
True)
704 wide[
"best"] = wide.groupby(
"study_name")[0].cummin()
708 def params(self, study_name: Optional[str] =
None) -> pd.DataFrame:
709 if study_name
is None:
710 res = pd.read_sql_query(
712 SELECT studies.study_name, trial_params.trial_id, trial_params.param_name, trial_params.param_value
714 INNER JOIN trials ON trial_params.trial_id = trials.trial_id
715 INNER JOIN studies ON studies.study_id = trials.study_id
716 WHERE trials.state="COMPLETE";
721 res = pd.read_sql_query(
723 SELECT studies.study_name, trial_params.trial_id, trial_params.param_name, trial_params.param_value
725 INNER JOIN trials ON trial_params.trial_id = trials.trial_id
726 INNER JOIN studies ON studies.study_id = trials.study_id
727 WHERE trials.state="COMPLETE" AND studies.study_name = ?;
730 params=(study_name,),
734 return pd.DataFrame()
736 return res.pivot(index=[
"study_name",
"trial_id"], columns=
"param_name", values=
"param_value")
743 Normalize None | scalar | iterable ->
None | list[cast(x)].
744 Strings are treated
as scalars (
not iterables).
750 if isinstance(x, (str, bytes)):
752 if not isinstance(x, Iterable):
756 vals = [cast(v)
for v
in x]
762 Returns ('(?,?,?)', [..])
for SQL IN clauses.
763 values must be a non-empty list.
765 if values
is None or len(values) == 0:
766 raise ValueError(
"values must be a non-empty list")
767 return "(" +
",".join([
"?"] * len(values)) +
")", values
775 study_names = self.
_normalize_ids(study_name, cast=str, name=
"study_name")
776 trial_ids = self.
_normalize_ids(trial_ids, cast=int, name=
"trial_ids")
778 if study_names
is None:
779 if trial_ids
is None:
784 if len(trial_ids) == 0:
785 return pd.DataFrame(columns=[
"study_name",
"trial_id",
"value"])
787 rows = self.
con.execute(
789 SELECT DISTINCT studies.study_name
791 INNER JOIN studies ON studies.study_id = trials.study_id
792 WHERE trials.trial_id IN {in_clause}
793 ORDER BY studies.study_name;
797 study_names = [r[0] for r
in rows]
800 return pd.DataFrame(columns=[
"study_name",
"trial_id",
"value"])
808 base_value_col =
None
810 for sn
in study_names:
812 spec_key = f
"timeseries_spec:{key}"
813 if spec_key
not in attrs:
814 raise KeyError(f
"Missing study_user_attributes['{spec_key}'] for study '{sn}'")
816 spec = attrs[spec_key]
817 axis_names = spec[
"axes"]
818 col_names = spec[
"columns"]
819 value_col = spec.get(
"value_col",
"value")
820 layout = spec.get(
"layout",
"product")
823 for ax
in axis_names:
826 raise KeyError(f
"Missing study_user_attributes['{k}'] in study '{sn}'")
827 axis_values.append(attrs[k])
832 "columns": col_names,
833 "value_col": value_col,
835 "axis_values": axis_values,
838 if base_spec
is None:
839 base_spec = spec_full
840 base_cols = list(col_names)
841 base_value_col = value_col
844 if spec_full != base_spec:
846 f
"Timeseries spec mismatch for key '{key}' between studies. "
847 f
"Study '{sn}' has a different spec/axes."
850 specs_by_study[sn] = spec_full
856 SELECT studies.study_name, trial_user_attributes.trial_id, trial_user_attributes.value_json
857 FROM trial_user_attributes
858 INNER JOIN trials ON trial_user_attributes.trial_id = trials.trial_id
859 INNER JOIN studies ON studies.study_id = trials.study_id
860 WHERE trials.state="COMPLETE"
861 AND trial_user_attributes.key = ?
865 if study_names
is not None:
867 sql += f
" AND studies.study_name IN {in_clause}"
868 params.extend(in_vals)
870 if trial_ids
is not None:
871 if len(trial_ids) == 0:
872 return pd.DataFrame(columns=[
"study_name",
"trial_id", *base_cols, base_value_col])
874 sql += f
" AND trial_user_attributes.trial_id IN {in_clause}"
875 params.extend(in_vals)
877 raw = pd.read_sql_query(sql, self.
con, params=params)
880 return pd.DataFrame(columns=[
"study_name",
"trial_id", *base_cols, base_value_col])
887 for sn, raw_sn
in raw.groupby(
"study_name", sort=
False):
888 spec = specs_by_study[sn]
889 axis_names = spec[
"axes"]
890 col_names = spec[
"columns"]
891 value_col = spec[
"value_col"]
892 layout = spec[
"layout"]
893 axis_values = spec[
"axis_values"]
895 arr = np.asarray([json.loads(v)
for v
in raw_sn[
"value_json"]], dtype=float)
897 if layout ==
"aligned":
899 for ax, vals
in zip(axis_names, axis_values):
902 f
"Aligned layout mismatch for axis '{ax}' in study '{sn}': {len(vals)} != {n}"
906 for i, (_, tid)
in enumerate(raw_sn[[
"study_name",
"trial_id"]].to_numpy()):
907 d = pd.DataFrame({
"study_name": sn,
"trial_id": tid, value_col: arr[i]})
908 for cname, vals
in zip(col_names, axis_values):
911 out_parts.append(pd.concat(rows, ignore_index=
True))
914 if layout ==
"product":
915 if len(axis_values) == 1:
916 cols = pd.Index(axis_values[0], name=col_names[0])
918 cols = pd.MultiIndex.from_product(axis_values, names=col_names)
920 if arr.shape[1] != len(cols):
922 f
"Timeseries size mismatch for key '{key}' in study '{sn}': "
923 f
"payload length={arr.shape[1]} expected={len(cols)}"
928 index=pd.MultiIndex.from_frame(raw_sn[[
"study_name",
"trial_id"]]),
932 wide.reset_index().melt(
933 id_vars=[
"study_name",
"trial_id"],
934 var_name=col_names
if len(col_names) > 1
else col_names[0],
935 value_name=value_col,
940 raise ValueError(f
"Unknown timeseries layout '{layout}' in study '{sn}'")
943 return pd.DataFrame(columns=[
"study_name",
"trial_id", *base_cols, base_value_col])
945 return pd.concat(out_parts, ignore_index=
True)
956 df = self.
get_trial_data(key=key, study_name=study_name, trial_ids=trial_ids,raw=
True)
961 is_dict = df[
"value"].map(
lambda x: isinstance(x, dict))
963 expanded = pd.json_normalize(df[
"value"], sep=sep)
964 expanded.index = df.index
965 out = pd.concat([df[[
"study_name",
"trial_id"]], expanded], axis=1)
971 out = df[[
"study_name",
"trial_id"]].copy()
972 out[
"value"] = df[
"value"]
974 dict_rows = df[is_dict]
975 expanded = pd.json_normalize(dict_rows[
"value"], sep=sep)
976 expanded.index = dict_rows.index
978 out = pd.concat([out, expanded], axis=1)
984 study_names = self.
_normalize_ids(study_name, cast=str, name=
"study_name")
985 trial_ids = self.
_normalize_ids(trial_ids, cast=int, name=
"trial_ids")
988 SELECT studies.study_name, trial_user_attributes.trial_id, trial_user_attributes.value_json
989 FROM trial_user_attributes
990 INNER JOIN trials ON trial_user_attributes.trial_id = trials.trial_id
991 INNER JOIN studies ON studies.study_id = trials.study_id
992 WHERE trials.state="COMPLETE"
993 AND trial_user_attributes.key = ?
997 if study_names
is not None:
998 if len(study_names) == 0:
999 return pd.DataFrame(columns=[
"study_name",
"trial_id",
"value"])
1001 sql += f
" AND studies.study_name IN {in_clause}"
1002 params.extend(in_vals)
1004 if trial_ids
is not None:
1005 if len(trial_ids) == 0:
1006 return pd.DataFrame(columns=[
"study_name",
"trial_id",
"value"])
1008 sql += f
" AND trial_user_attributes.trial_id IN {in_clause}"
1009 params.extend(in_vals)
1011 df = pd.read_sql_query(sql, self.
con, params=params)
1014 return pd.DataFrame(columns=[
"study_name",
"trial_id",
"value"])
1016 df = df.rename(columns={
"value_json":
"value"})
1017 df[
"value"] = df[
"value"].map(json.loads)
1023 rows = self.
con.execute(
1027 WHERE type='table' AND name IN (
'isicell_context',
'isicell_context_entrypoints');
1030 names = {r[0] for r
in rows}
1031 return "isicell_context" in names
and "isicell_context_entrypoints" in names
1035 Returns rows: (key, mode, payload) for CONTEXT_ID.
1040 return self.
con.execute(
1042 SELECT key, mode, payload
1043 FROM isicell_context
1044 WHERE context_id = ?
1052 Returns {"suggest":
"optim",
"run":
"run_simu", ...}
1057 rows = self.
con.execute(
1060 FROM isicell_context_entrypoints
1061 WHERE context_id = ?;
1065 return {k: n
for k, n
in rows}
1072 if isinstance(v, dict)
and v.get(
"__kind__") ==
"dataframe" and v.get(
"orient") ==
"split":
1074 return pd.DataFrame(data=vv[
"data"], index=vv[
"index"], columns=vv[
"columns"])
1075 if isinstance(v, dict)
and v.get(
"__kind__") ==
"series":
1076 s = pd.Series(v[
"value"])
1077 s.name = v.get(
"name")
1083 Return a data object from the shared isicell context (json/pickle only).
1087 raise RuntimeError(
"No isicell_context table (or empty context) in this DB.")
1089 for key, mode, payload_txt
in rows:
1095 if mode ==
"pickle64b":
1098 raise TypeError(f
"'{name}' is source-backed (mode='source'), not a data object")
1100 raise KeyError(f
"'{name}' not found in isicell_context")
1104 Build exec namespace from shared isicell_context + caller namespace.
1105 Source-backed symbols are exec
'ed into the same namespace.
1106 Entrypoints are resolved from isicell_context_entrypoints.
1115 for key, mode, payload_txt
in rows:
1118 elif mode ==
"pickle64b":
1122 for key, mode, payload_txt
in rows:
1123 if mode ==
"source":
1124 exec(payload_txt, g, g)
1128 if "suggest" in eps
and eps[
"suggest"]
in g:
1130 if "run" in eps
and eps[
"run"]
in g:
1131 self.
run_fn = g[eps[
"run"]]
1132 if "eval" in eps
and eps[
"eval"]
in g:
1140 if x
is None or isinstance(x, (int, float, bool)):
1150 if any(ch
in s
for ch
in [
".",
"e",
"E"]):
1157 return json.loads(s)
1162 df = pd.read_sql_query(
1164 SELECT trial_params.param_name, trial_params.param_value
1166 INNER JOIN trials ON trial_params.trial_id = trials.trial_id
1167 WHERE trials.state="COMPLETE"
1168 AND trial_params.trial_id = ?;
1171 params=(int(trial_id),),
1173 return {r[
"param_name"]: self.
_coerce_param_value(r[
"param_value"])
for _, r
in df.iterrows()}
1177 Rebuild structured params from flat optuna params using archived suggest_fn.
1178 Handles categorical params stored
as category indices.
1184 raise RuntimeError(
"No suggest entrypoint found in isicell_context_entrypoints.")
1188 bounds = LHSIterator.getBoundaries(self.
suggest_fn)
1189 for p, info
in bounds.items():
1190 if info.get(
"type") ==
"categorical" and p
in flat:
1192 flat[p] = info[
"choices"][int(flat[p])]
1198 fake_trial.params = flat
1199 fake_trial.modeSuggest =
True
1204 fitness_threshold: float = 1500,
1208 figsize_width: float = 11,
1209 row_height: float = 0.42,
1210 min_fig_height: float = 3.0,
1211 label_fontsize: int = 8,
1212 min_cat_label_width: float = 0.08,
1213 max_cat_label_len: int = 14,
1216 Plot normalized distributions of selected trials' parameters:
1217 - numeric params: violin plots on x in [0, 1]
1218 - categorical params: stacked horizontal bars of frequencies on x
in [0, 1]
1221 - keep completed trials
with fitness(objective) < fitness_threshold
1222 - optional filter by study_name
and/
or trial_ids
1226 - Requires self.
suggest_fn (loaded
from isicell context) to infer param bounds/types.
1227 - Assumes categorical params
in trial_params are stored
as category indices (Optuna/LHSIterator path).
1228 - Uses seaborn + matplotlib (imported lazily).
1230 import matplotlib.pyplot
as plt
1231 import seaborn
as sns
1237 raise RuntimeError(
"No suggest entrypoint loaded. Cannot infer parameter bounds/types.")
1240 study_names = self.
_normalize_ids(study_name, cast=str, name=
"study_name")
1241 trial_ids_norm = self.
_normalize_ids(trial_ids, cast=int, name=
"trial_ids")
1245 all_params = self.
params()
1247 if all_fitness.empty
or all_params.empty:
1248 raise RuntimeError(
"No completed trials found in DB.")
1251 if objective
not in all_fitness.columns:
1252 raise KeyError(f
"Objective column {objective!r} not found in fitness_wide(). Available: {list(all_fitness.columns)}")
1255 fit = all_fitness.loc[all_fitness[objective] < fitness_threshold, [
"study_name",
"trial_id", objective]].copy()
1257 if study_names
is not None:
1258 fit = fit[fit[
"study_name"].isin(study_names)]
1260 if trial_ids_norm
is not None:
1261 fit = fit[fit[
"trial_id"].isin(trial_ids_norm)]
1264 raise RuntimeError(
"No trial matches the selected filters / fitness threshold.")
1267 selected_pairs = set(map(tuple, fit[[
"study_name",
"trial_id"]].to_numpy()))
1270 idx_study = all_params.index.get_level_values(
"study_name")
1271 idx_trial = all_params.index.get_level_values(
"trial_id")
1272 mask = [(sn, tid)
in selected_pairs
for sn, tid
in zip(idx_study, idx_trial)]
1273 tmp_params = all_params.loc[mask].copy()
1275 if tmp_params.empty:
1276 raise RuntimeError(
"No parameters found for selected trials.")
1279 bounds = LHSIterator.getBoundaries(self.
suggest_fn)
1282 tmp_num_all = tmp_params.copy()
1283 for c
in tmp_num_all.columns:
1284 tmp_num_all[c] = pd.to_numeric(tmp_num_all[c], errors=
"coerce")
1287 num_params = [k
for k, v
in bounds.items()
if v.get(
"type")
in (
"int",
"float")
and k
in tmp_params.columns]
1288 cat_params = [k
for k, v
in bounds.items()
if v.get(
"type") ==
"categorical" and k
in tmp_params.columns]
1292 for p
in num_params:
1294 s = pd.to_numeric(tmp_params[p], errors=
"coerce")
1296 low = float(info[
"low"])
1297 high = float(info[
"high"])
1302 if info.get(
"log",
False):
1305 vals = (np.log(s) - np.log(low)) / (np.log(high) - np.log(low))
1307 vals = (s - low) / (high - low)
1309 vals = vals.clip(0, 1)
1311 label = f
"{p} [{low:.3f},{high:.3f}]"
1312 for v
in vals.dropna().to_numpy():
1313 num_rows.append({
"param": label,
"value": float(v),
"kind":
"numeric"})
1315 num_df = pd.DataFrame(num_rows)
1319 cat_segment_labels = {}
1321 for p
in cat_params:
1323 choices = info[
"choices"]
1326 def _choice_repr(x):
1328 return repr(choices[int(x)])
1333 s = tmp_params[p].map(_choice_repr)
1334 choice_labels = [repr(c)
for c
in choices]
1336 vc = s.value_counts(normalize=
True)
1337 freqs = [float(vc.get(cl, 0.0))
for cl
in choice_labels]
1341 param_label = f
"{p} ({len(choice_labels)} choices)"
1342 cat_segment_labels[param_label] = []
1344 for cl, f
in zip(choice_labels, freqs):
1348 "param": param_label,
1353 "kind":
"categorical",
1355 cat_segment_labels[param_label].append((x0 + f / 2, cl, f))
1358 cat_df = pd.DataFrame(cat_rows)
1362 if not num_df.empty:
1363 param_order += list(dict.fromkeys(num_df[
"param"].tolist()))
1364 if not cat_df.empty:
1365 param_order += [p
for p
in dict.fromkeys(cat_df[
"param"].tolist())
if p
not in param_order]
1367 if len(param_order) == 0:
1368 raise RuntimeError(
"No plottable parameter found (check bounds / selected trials).")
1371 ypos = {p: i
for i, p
in enumerate(param_order)}
1374 fig_h =
max(min_fig_height, row_height * len(param_order) + 1.5)
1375 fig = plt.figure(figsize=(figsize_width, fig_h))
1379 if not num_df.empty:
1380 num_df_plot = num_df.copy()
1381 num_df_plot[
"y"] = num_df_plot[
"param"].map(ypos)
1396 if not cat_df.empty:
1397 for _, r
in cat_df.iterrows():
1398 y = ypos[r[
"param"]]
1408 for p, segs
in cat_segment_labels.items():
1410 for xc, cl, f
in segs:
1411 if f >= min_cat_label_width:
1413 if len(txt) > max_cat_label_len:
1414 txt = txt[: max_cat_label_len - 3] +
"..."
1415 ax.text(xc, y, txt, ha=
"center", va=
"center", fontsize=label_fontsize)
1419 ax.set_xlabel(
"Normalized value / category frequency")
1423 for i
in range(len(param_order)):
1424 ax.axhline(i + 0.5, linewidth=0.3, alpha=0.3)
def set_eval_entrypoint(self, str name)
def set_study(self, study)
def _require_custom_db(self)
dict[str, Any] _encode_context_value(self, Any value)
str _extract_sqlite_path_from_study(self, study)
def _flush_context_if_possible(self)
def trial_add_timeseries(self, trial, str key, Any values, list[str] axes, Optional[list[str]] columns=None, str value_col="value", str layout="aligned")
def set_suggest_entrypoint(self, str name)
def trial_add_data(trial, str key, Any value)
def _set_entrypoint(self, str key, str name)
def check_minimal(self, bool require_run_function=True)
def set_run_entrypoint(self, str name)
def context_add_many(self, dict[str, Any] mapping)
def __init__(self, study=None)
def _flush_entrypoints_if_possible(self)
def study_add_data(self, str key, Any value)
def study_add_axis(self, str name, Any values)
def context_add(self, str name, Any value)
def _ensure_isicell_tables(self)
dict[str, Any] study_attrs(self, str study_name)
pd.DataFrame fitness_wide(self, Optional[str] study_name=None)
pd.DataFrame get_trial_timeseries(self, str key, study_name=None, trial_ids=None)
pd.DataFrame params(self, Optional[str] study_name=None)
def build_params(self, int trial_id)
def plot_param_distributions(self, float fitness_threshold=1500, int objective=0, study_name=None, trial_ids=None, float figsize_width=11, float row_height=0.42, float min_fig_height=3.0, int label_fontsize=8, float min_cat_label_width=0.08, int max_cat_label_len=14)
dict[str, str] _read_isicell_entrypoints(self)
dict[str, Any] get_trial_params_flat(self, int trial_id)
def _normalize_ids(self, x, cast=int, name="ids")
Any _coerce_param_value(Any x)
bool _has_isicell_tables(self)
str trial_to_study(self, int trial_id)
pd.DataFrame trial_attrs(self, str study_name, Optional[str] key=None)
pd.DataFrame get_trial_data(self, str key, study_name=None, trial_ids=None, raw=False, sep='.')
def _sql_in_clause(self, values)
Any _decode_json_runtime_value(Any v)
def __init__(self, str db_path, bool readonly=True)
Any context_value(self, str name)
pd.DataFrame fitness_long(self, Optional[str] study_name=None)
list[tuple[str, str, str]] _read_isicell_context_rows(self)
dict[str, Any] load_context(self)
auto get(const nlohmann::detail::iteration_proxy_value< IteratorType > &i) -> decltype(i.key())
str variable_to_pickle64b(Any value)
def _capture_external_caller_namespace()
def _capture_external_caller_frame()
Any pickle64b_to_variable(str payload)
str _get_function_source(fn)
double max(double a, double b)
Computes the maximum of two numbers.