94 lines
2.4 KiB
Python
94 lines
2.4 KiB
Python
from __future__ import annotations
|
|
import atexit, json, os, time, math
|
|
from typing import Dict, List, Callable, Iterable, Any, Optional
|
|
|
|
import matplotlib
|
|
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Population ids whose atexit flush has already run. register_atexit checks and
# updates this set so each id's exit-time save is attempted at most once, even
# if the same id was registered several times. Direct save_series calls do not
# touch it.
_FLUSHED: set[str] = set()
|
|
|
|
|
|
def ensure_dir() -> None:
    """Create the ``stats`` output directory relative to the CWD; no-op if it exists."""
    target = "stats"
    os.makedirs(target, exist_ok=True)
|
|
|
|
|
|
def save_series(population_id: str, rows: Iterable[Dict[str, Any]]) -> None:
    """Write a population's rows to ``stats/<id>.jsonl`` and regenerate its plots.

    The rows are first written to ``<path>.tmp`` and then moved over the real
    file with ``os.replace``, so an existing ``.jsonl`` is only replaced by a
    fully written one. Rows without a ``"ts"`` key get one prepended holding
    the current Unix time in whole seconds. The same value is used for every
    row in this call. After a successful write the PNG plots are rebuilt from
    the file via ``_plot_from_jsonl``.

    Args:
        population_id: Identifier used as the output file stem (``str()``-ed).
        rows: Iterable of JSON-serialisable mappings; consumed exactly once.
            If it yields nothing, no file is written or replaced (the
            ``stats`` directory is still created).

    Raises:
        TypeError / ValueError: from ``json.dumps`` when a row cannot be
            serialised. The temporary file is removed and any previous
            ``.jsonl`` is left untouched.
        OSError: when the directory or files cannot be created or written.
    """
    ensure_dir()

    pid = str(population_id)
    rows = list(rows)
    if not rows:
        return

    jsonl_path = os.path.join("stats", f"{pid}.jsonl")
    tmp_path = jsonl_path + ".tmp"
    # One timestamp per batch so rows saved together are stamped consistently.
    now = int(time.time())

    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            for r in rows:
                # Put "ts" first when missing; dict unpacking (unlike `dict | r`)
                # also works for non-dict mappings.
                rr = r if "ts" in r else {"ts": now, **r}
                f.write(json.dumps(rr, separators=(",", ":"), ensure_ascii=False) + "\n")
        os.replace(tmp_path, jsonl_path)
    except BaseException:
        # Don't leave a stale .tmp file behind after a serialisation/I/O failure.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    _plot_from_jsonl(jsonl_path, pid)
|
|
|
|
|
|
def register_atexit(population_id: str, rows_provider: Callable[[], Iterable[Dict[str, Any]]]) -> None:
    """Schedule a one-time ``save_series`` of this population at interpreter exit.

    ``rows_provider`` is called only when the exit handler runs, so it sees the
    final data rather than a snapshot taken at registration time.

    Args:
        population_id: Population identifier (``str()``-ed) used as the key in
            ``_FLUSHED`` and passed on to ``save_series``.
        rows_provider: Zero-argument callable returning the rows to save.

    Every registration adds another atexit handler. The shared ``_FLUSHED`` set
    ensures only the first handler to run for a given id performs the save. The
    id is marked as flushed even when the save raises, so it is never retried.
    The exception itself still propagates out of the handler.
    """
    key = str(population_id)

    def _flush_once() -> None:
        if key not in _FLUSHED:
            try:
                save_series(key, rows_provider())
            finally:
                # Mark as done regardless of outcome: no retry on failure.
                _FLUSHED.add(key)

    atexit.register(_flush_once)
|
|
|
|
|
|
def _plot_from_jsonl(src_path: str, pid: str) -> None:
    """Render one PNG per numeric metric found in a JSONL stats file.

    Each non-blank line of ``src_path`` is parsed as a JSON object. The x value
    is ``int(obj.get("gen", 0))``. Every other key except ``"gen"`` and ``"ts"``
    whose value converts with ``float()`` becomes a point for that metric.
    Values that do not convert are skipped. Each metric keeps its own list of
    generations, so a metric absent from (or non-numeric in) some rows is still
    plotted against the generations it actually occurred in.

    Output goes to ``stats/<pid>_<metric>.png`` at 120 dpi.

    Args:
        src_path: Path of the JSONL file to read.
        pid: Population id used as the PNG filename prefix.

    Raises:
        json.JSONDecodeError: on a malformed line.
        ValueError / TypeError: if a row's ``"gen"`` is not int-convertible.
        OSError: if reading the source or writing a PNG fails.
    """
    # metric -> (generations, values), appended pairwise so x/y stay aligned.
    series: dict[str, tuple[list[int], list[float]]] = {}

    with open(src_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            g = int(obj.get("gen", 0))
            for k, v in obj.items():
                if k in ("gen", "ts"):
                    continue
                try:
                    y = float(v)
                except (TypeError, ValueError, OverflowError):
                    continue  # non-numeric field (label, list, null, huge int, ...)
                xs, ys = series.setdefault(k, ([], []))
                xs.append(g)
                ys.append(y)

    for metric, (xs, ys) in series.items():
        fig, ax = plt.subplots()
        try:
            ax.set_title(f"{metric} over generations")
            ax.set_xlabel("generation")
            ax.set_ylabel(metric)
            ax.plot(xs, ys, marker="o")
            out_png = os.path.join("stats", f"{pid}_{metric}.png")
            fig.tight_layout()
            fig.savefig(out_png, dpi=120)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
|