last workig state.
This commit is contained in:
93
mathema/utils/stats.py
Normal file
93
mathema/utils/stats.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
import atexit, json, os, time, math
|
||||
from typing import Dict, List, Callable, Iterable, Any, Optional
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
_FLUSHED: set[str] = set()
|
||||
|
||||
|
||||
def ensure_dir() -> None:
|
||||
os.makedirs("stats", exist_ok=True)
|
||||
|
||||
|
||||
def save_series(population_id: str, rows: Iterable[Dict[str, Any]]) -> None:
|
||||
ensure_dir()
|
||||
pid = str(population_id)
|
||||
rows = list(rows)
|
||||
if not rows:
|
||||
return
|
||||
|
||||
jsonl_path = os.path.join("stats", f"{pid}.jsonl")
|
||||
tmp_path = jsonl_path + ".tmp"
|
||||
|
||||
with open(tmp_path, "w", encoding="utf-8") as f:
|
||||
for r in rows:
|
||||
|
||||
if "ts" not in r:
|
||||
rr = {"ts": int(time.time())} | r
|
||||
else:
|
||||
rr = r
|
||||
f.write(json.dumps(rr, separators=(",", ":"), ensure_ascii=False) + "\n")
|
||||
os.replace(tmp_path, jsonl_path)
|
||||
|
||||
_plot_from_jsonl(jsonl_path, pid)
|
||||
|
||||
|
||||
def register_atexit(population_id: str, rows_provider: Callable[[], Iterable[Dict[str, Any]]]) -> None:
|
||||
pid = str(population_id)
|
||||
|
||||
def _flush_once() -> None:
|
||||
if pid in _FLUSHED:
|
||||
return
|
||||
try:
|
||||
save_series(pid, rows_provider())
|
||||
finally:
|
||||
_FLUSHED.add(pid)
|
||||
|
||||
atexit.register(_flush_once)
|
||||
|
||||
|
||||
def _plot_from_jsonl(src_path: str, pid: str) -> None:
|
||||
gens: List[int] = []
|
||||
series: dict[str, list[float]] = {}
|
||||
|
||||
with open(src_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
obj = json.loads(line)
|
||||
g = int(obj.get("gen", 0))
|
||||
gens.append(g)
|
||||
for k, v in obj.items():
|
||||
if k in ("gen", "ts"):
|
||||
continue
|
||||
try:
|
||||
series.setdefault(k, []).append(float(v))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not gens or not series:
|
||||
return
|
||||
|
||||
for metric, ys in series.items():
|
||||
if len(ys) != len(gens):
|
||||
L = min(len(ys), len(gens))
|
||||
ys = ys[:L]
|
||||
x = gens[:L]
|
||||
else:
|
||||
x = gens
|
||||
|
||||
plt.figure()
|
||||
plt.title(f"{metric} over generations")
|
||||
plt.xlabel("generation")
|
||||
plt.ylabel(metric)
|
||||
plt.plot(x, ys, marker="o")
|
||||
out_png = os.path.join("stats", f"{pid}_{metric}.png")
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_png, dpi=120)
|
||||
plt.close()
|
||||
Reference in New Issue
Block a user