neuroevolution/mathema/replay.py

"""
Replay utility for visualizing the best evolved CarRacing agent.
This module loads the best-performing agent from a given population stored
in Neo4j, reconstructs its policy from a genotype snapshot, and replays the
agent in a human-rendered CarRacing environment using pygame.
High-level workflow:
1. Query Neo4j for the agent with the highest recorded fitness in a population.
2. Load the agents genotype snapshot.
3. Build an executable policy from the snapshot.
4. Run the policy in the CarRacing environment, step by step.
5. Render the environment in real time and automatically handle episode resets.
"""
import asyncio

import numpy as np
import pygame

from mathema.envs.openai_car_racing import CarRacing
from mathema.genotype.neo4j.genotype import load_genotype_snapshot, neo4j
from viz_replay import build_policy_from_snapshot
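
# Runtime assumptions (inferred from the imports above): a reachable Neo4j
# instance behind the shared `neo4j` client, the `mathema` and `viz_replay`
# modules on PYTHONPATH, and a display available for pygame's "human" render mode.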


async def _best_agent_in_population(population_id: str) -> str:
    """Return the id of the highest-fitness agent in the given population."""
    rows = await neo4j.read_all(
        """
        MATCH (a:agent {population_id: $pid})
        WHERE a.fitness IS NOT NULL
        RETURN a.id AS id, toFloat(a.fitness) AS f
        ORDER BY f DESC
        LIMIT 1
        """,
        pid=str(population_id),
    )
    if not rows:
        raise RuntimeError(f"no agents found with fitness in '{population_id}'")
    return str(rows[0]["id"])
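
# To sanity-check the data by hand, an equivalent query can be run directly in
# the Neo4j browser (assuming the same node label and properties used above):
#
#   MATCH (a:agent {population_id: "car_pop"})
#   WHERE a.fitness IS NOT NULL
#   RETURN a.id, toFloat(a.fitness) AS f ORDER BY f DESC LIMIT 5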


def _post_process_action(y: np.ndarray) -> np.ndarray:
    """Map a raw policy output vector onto the CarRacing action space."""
    y0 = float(y[0]) if y.size >= 1 else 0.0
    y1 = float(y[1]) if y.size >= 2 else 0.0
    y2 = float(y[2]) if y.size >= 3 else 0.0
    steer = max(-1.0, min(1.0, y0))               # steering in [-1, 1]
    gas = max(0.0, min(1.0, 0.5 * (y1 + 1.0)))    # gas remapped to [0, 1]
    brake = max(0.0, min(1.0, 0.5 * (y2 + 1.0)))  # brake remapped to [0, 1]
    return np.array([steer, gas, brake], dtype=np.float32)
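
# Worked example: a raw output y = [0.2, 0.0, -1.0] maps to steer = 0.2,
# gas = 0.5 * (0.0 + 1.0) = 0.5, and brake = 0.5 * (-1.0 + 1.0) = 0.0, i.e.
# act = [0.2, 0.5, 0.0]. The affine remap assumes the policy emits roughly
# tanh-scaled outputs in [-1, 1]; components missing from y default to 0.0
# before the remap.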


async def replay_best(population_id: str, seed: int = 5, lookahead: int = 10):
    """Replay a population's best agent in a human-rendered CarRacing env."""
    aid = await _best_agent_in_population(population_id)
    snap = await load_genotype_snapshot(aid)
    policy, I = build_policy_from_snapshot(snap)

    env = CarRacing(seed_value=seed, render_mode="human")
    _, _ = env.reset()
    policy.reset_state()
    # Prime the environment with one no-op step before the first feature read.
    _ = env.step(np.array([0.0, 0.0, 0.0], dtype=np.float32))
    frame = 0
    try:
        while True:
            # Pad with zeros or truncate so the feature vector matches the
            # policy's input width I.
            feats = np.array(env.get_feature_vector(lookahead), dtype=np.float32)
            if feats.shape[0] < I:
                x = np.zeros((I,), dtype=np.float32)
                x[: feats.shape[0]] = feats
            else:
                x = feats[:I]

            y = policy.step(x)
            act = _post_process_action(y)
            _, _, terminated, truncated, _ = env.step(act)

            # Render every other frame to keep the replay responsive.
            if frame % 2 == 0:
                env.render()
            frame += 1

            # Stop once the pygame window has been closed by the user.
            if not pygame.display.get_init() or pygame.display.get_surface() is None:
                break

            # Soft-reset bookkeeping (CarRacing internals) so the replay keeps
            # running across episode boundaries instead of stopping.
            if terminated:
                env.tile_visited_count = 0
                env.prev_reward = env.reward
                continue
            if truncated:
                env._no_progress_steps = 0
                continue
    finally:
        env.close()


if __name__ == "__main__":
    asyncio.run(replay_best(population_id="car_pop", seed=1, lookahead=10))
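
# Example invocation from a shell (module path assumed from the file location
# shown above; adjust to your package layout):
#
#   python neuroevolution/mathema/replay.py
#
# Or, from another coroutine:
#
#   await replay_best(population_id="car_pop", seed=1, lookahead=10)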