"""
|
||
Replay utility for visualizing the best evolved CarRacing agent.
|
||
|
||
This module loads the best-performing agent from a given population stored
|
||
in Neo4j, reconstructs its policy from a genotype snapshot, and replays the
|
||
agent in a human-rendered CarRacing environment using pygame.
|
||
|
||
High-level workflow:
|
||
1. Query Neo4j for the agent with the highest recorded fitness in a population.
|
||
2. Load the agent’s genotype snapshot.
|
||
3. Build an executable policy from the snapshot.
|
||
4. Run the policy in the CarRacing environment, step by step.
|
||
5. Render the environment in real time and automatically handle episode resets.
|
||
"""
|
||
|
||
import numpy as np
import pygame

from mathema.genotype.neo4j.genotype import load_genotype_snapshot, neo4j
from viz_replay import build_policy_from_snapshot
from mathema.envs.openai_car_racing import CarRacing


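# Assumed graph schema, inferred from the query below: the evolution run
# stores (:agent {id, population_id, fitness}) nodes in Neo4j.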
async def _best_agent_in_population(population_id: str) -> str:
    """Return the id of the highest-fitness agent in the given population."""
    rows = await neo4j.read_all("""
        MATCH (a:agent {population_id:$pid})
        WHERE a.fitness IS NOT NULL
        RETURN a.id AS id, toFloat(a.fitness) AS f
        ORDER BY f DESC
        LIMIT 1
    """, pid=str(population_id))
    print(rows)  # debug output: the raw winning row
    if not rows:
        raise RuntimeError(f"no agents found with fitness in '{population_id}'")
    return str(rows[0]["id"])


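# Example of the action mapping below (assuming the policy emits roughly
# tanh-range values in [-1, 1]):
#   y = [0.2, 0.0, -1.0]  ->  steer = 0.2, gas = 0.5, brake = 0.0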
def _post_process_action(y: np.ndarray) -> np.ndarray:
    """Map raw policy outputs onto the CarRacing action space.

    Steering is clamped to [-1, 1]; gas and brake are affinely rescaled
    from [-1, 1] to [0, 1] and clamped. Missing components are treated as 0.0.
    """
    y0 = float(y[0]) if y.size >= 1 else 0.0
    y1 = float(y[1]) if y.size >= 2 else 0.0
    y2 = float(y[2]) if y.size >= 3 else 0.0

    steer = max(-1.0, min(1.0, y0))
    gas = max(0.0, min(1.0, 0.5 * (y1 + 1.0)))
    brake = max(0.0, min(1.0, 0.5 * (y2 + 1.0)))

    return np.array([steer, gas, brake], dtype=np.float32)


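# The control loop below zero-pads or truncates the env's feature vector to
# the policy's input width I. With a hypothetical I = 4:
#   feats = [0.1, 0.2]                 ->  x = [0.1, 0.2, 0.0, 0.0]
#   feats = [0.1, 0.2, 0.3, 0.4, 0.5]  ->  x = [0.1, 0.2, 0.3, 0.4]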
async def replay_best(population_id: str, seed: int = 5, lookahead: int = 10):
    """Replay a population's best agent in a human-rendered CarRacing env."""
    aid = await _best_agent_in_population(population_id)
    snap = await load_genotype_snapshot(aid)
    policy, I = build_policy_from_snapshot(snap)  # I = the policy's input width

    env = CarRacing(seed_value=seed, render_mode="human")
    _, _ = env.reset()
    policy.reset_state()

    # Warm-up: take one no-op step so the env has a valid post-reset state.
    _ = env.step(np.array([0.0, 0.0, 0.0], dtype=np.float32))

    frame = 0

    try:
        while True:
            feats = np.array(env.get_feature_vector(lookahead), dtype=np.float32)

            # Zero-pad or truncate the features to the policy's input width.
            if feats.shape[0] < I:
                x = np.zeros((I,), dtype=np.float32)
                x[:feats.shape[0]] = feats
            else:
                x = feats[:I]

            y = policy.step(x)
            act = _post_process_action(y)
            _, r, terminated, truncated, _ = env.step(act)

            # Render every other frame to keep the replay responsive.
            if frame % 2 == 0:
                env.render()

            frame += 1

            # Stop once the pygame window has been closed.
            if not pygame.display.get_init() or pygame.display.get_surface() is None:
                break

            if terminated:
                # Soft reset: clear lap progress so the replay keeps running
                # on the same track instead of calling env.reset().
                env.tile_visited_count = 0
                env.prev_reward = env.reward

                continue

            if truncated:
                # Clear the no-progress counter so truncation does not end the replay.
                env._no_progress_steps = 0
                continue
    finally:
        env.close()


if __name__ == "__main__":
    import asyncio

    asyncio.run(replay_best(population_id="car_pop", seed=1, lookahead=10))
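# Usage note: asyncio.run starts a fresh event loop and raises if one is
# already running (e.g. in a Jupyter notebook); there, await the coroutine
# directly instead:
#
#     await replay_best(population_id="car_pop", seed=1, lookahead=10)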