Files
neuroevolution/mathema/replay.py
2025-12-13 14:12:35 +01:00

87 lines
2.4 KiB
Python

import numpy as np
import pygame
from mathema.genotype.neo4j.genotype import load_genotype_snapshot, neo4j
from viz_replay import build_policy_from_snapshot
from mathema.envs.openai_car_racing import CarRacing
async def _best_agent_in_population(population_id: str) -> str:
    """Return the id of the highest-fitness agent in a population.

    Only agents whose ``fitness`` property is set are considered; they are
    ordered by fitness (as a float, descending) and the first one is taken.

    Args:
        population_id: Value matched against the agents' ``population_id``
            property (converted with ``str`` before being sent).

    Returns:
        The selected agent's id, as a string.

    Raises:
        RuntimeError: If the query yields no rows.
    """
    # The query body is kept flush-left so the text sent to the database is
    # exactly the same as before.
    best_agent_query = """
MATCH (a:agent {population_id:$pid})
WHERE a.fitness IS NOT NULL
RETURN a.id AS id, toFloat(a.fitness) AS f
ORDER BY f DESC
LIMIT 1
"""
    records = await neo4j.read_all(best_agent_query, pid=str(population_id))
    print(records)
    if records:
        top = records[0]
        return str(top["id"])
    raise RuntimeError(f"no agents found with fitness in '{population_id}'")
def _post_process_action(y: np.ndarray) -> np.ndarray:
y0 = float(y[0]) if y.size >= 1 else 0.0
y1 = float(y[1]) if y.size >= 2 else 0.0
y2 = float(y[2]) if y.size >= 3 else 0.0
steer = max(-1.0, min(1.0, y0))
gas = max(0.0, min(1.0, 0.5 * (y1 + 1.0)))
brake = max(0.0, min(1.0, 0.5 * (y2 + 1.0)))
return np.array([steer, gas, brake], dtype=np.float32)
async def replay_best(population_id: str, seed: int = 5, lookahead: int = 10):
    """Replay the best agent of a population in a rendered CarRacing window.

    Looks up the highest-fitness agent, loads its genotype snapshot, builds a
    policy from it and drives the environment with that policy until the
    pygame display is gone. Termination and truncation do not end the replay;
    the corresponding environment counters are cleared and driving continues.

    Args:
        population_id: Population whose best agent is replayed.
        seed: Passed to ``CarRacing`` as ``seed_value``.
        lookahead: Passed to ``env.get_feature_vector`` on every step.

    Raises:
        RuntimeError: If the population has no agent with a fitness value.
    """
    agent_id = await _best_agent_in_population(population_id)
    snapshot = await load_genotype_snapshot(agent_id)
    # The second value is used below as the length of the policy input vector.
    policy, n_inputs = build_policy_from_snapshot(snapshot)

    env = CarRacing(seed_value=seed, render_mode="human")
    frame = 0
    try:
        # Setup runs inside the try so the environment is closed even when
        # reset or the warm-up step fails.
        env.reset()
        policy.reset_state()
        # One zero-action step before the loop starts reading features.
        env.step(np.array([0.0, 0.0, 0.0], dtype=np.float32))

        while True:
            feats = np.array(env.get_feature_vector(lookahead), dtype=np.float32)
            n_feats = feats.shape[0]
            if n_feats < n_inputs:
                # Too few features: zero-pad on the right.
                x = np.zeros((n_inputs,), dtype=np.float32)
                x[:n_feats] = feats
            else:
                # Enough (or too many) features: keep the leading ones.
                x = feats[:n_inputs]

            action = _post_process_action(policy.step(x))
            _, _, terminated, truncated, _ = env.step(action)

            # Explicit render call on every second frame only.
            if frame % 2 == 0:
                env.render()
            frame += 1

            # Stop once pygame no longer has an initialised display or a
            # display surface.
            if not pygame.display.get_init() or pygame.display.get_surface() is None:
                break

            if terminated:
                # Keep replaying instead of resetting: clear the visited-tile
                # count and re-sync the previous reward.
                # NOTE(review): relies on CarRacing internals - confirm.
                env.tile_visited_count = 0
                env.prev_reward = env.reward
            elif truncated:
                # NOTE(review): presumably clears a no-progress timeout - confirm.
                env._no_progress_steps = 0
    finally:
        env.close()
if __name__ == "__main__":
    # Script entry point: replay the best agent of the "car_pop" population.
    import asyncio

    replay_coro = replay_best(population_id="car_pop", seed=1, lookahead=10)
    asyncio.run(replay_coro)