87 lines
2.4 KiB
Python
87 lines
2.4 KiB
Python
import numpy as np
|
|
import pygame
|
|
|
|
from mathema.genotype.neo4j.genotype import load_genotype_snapshot, neo4j
|
|
from viz_replay import build_policy_from_snapshot
|
|
from mathema.envs.openai_car_racing import CarRacing
|
|
|
|
|
|
async def _best_agent_in_population(population_id: str) -> str:
    """Return the id of the highest-fitness agent in a population.

    Queries Neo4j for ``agent`` nodes whose ``population_id`` matches, converts
    their ``fitness`` property with ``toFloat`` and picks the largest value.

    Args:
        population_id: Population identifier; coerced to ``str`` for the query.

    Returns:
        The ``id`` of the best agent, as a string.

    Raises:
        RuntimeError: If no agent in the population has a usable numeric fitness.
    """
    # Cypher sorts null (and NaN) before every number in a DESC ordering, so a
    # fitness that toFloat() cannot convert (-> null) or a NaN fitness would be
    # picked as "best". Filter on the *converted* value; `f = f` is false only
    # for NaN.
    rows = await neo4j.read_all(
        """
        MATCH (a:agent {population_id: $pid})
        WHERE a.fitness IS NOT NULL
        WITH a, toFloat(a.fitness) AS f
        WHERE f IS NOT NULL AND f = f
        RETURN a.id AS id, f
        ORDER BY f DESC
        LIMIT 1
        """,
        pid=str(population_id),
    )

    if not rows:
        raise RuntimeError(f"no agents found with fitness in '{population_id}'")

    # presumably read_all returns mapping-like records keyed by RETURN alias — confirm
    return str(rows[0]["id"])
|
|
|
|
|
|
def _post_process_action(y: np.ndarray) -> np.ndarray:
|
|
y0 = float(y[0]) if y.size >= 1 else 0.0
|
|
y1 = float(y[1]) if y.size >= 2 else 0.0
|
|
y2 = float(y[2]) if y.size >= 3 else 0.0
|
|
|
|
steer = max(-1.0, min(1.0, y0))
|
|
gas = max(0.0, min(1.0, 0.5 * (y1 + 1.0)))
|
|
brake = max(0.0, min(1.0, 0.5 * (y2 + 1.0)))
|
|
|
|
return np.array([steer, gas, brake], dtype=np.float32)
|
|
|
|
|
|
def _fit_to_input_width(feats: np.ndarray, width: int) -> np.ndarray:
    """Zero-pad or truncate a 1-D feature vector to exactly ``width`` entries."""
    if feats.shape[0] >= width:
        return feats[:width]
    padded = np.zeros((width,), dtype=np.float32)
    padded[:feats.shape[0]] = feats
    return padded


async def replay_best(population_id: str, seed: int = 5, lookahead: int = 10):
    """Load the best agent of a population and drive it in a rendered CarRacing env.

    Runs until the pygame display is closed or no longer has a surface.

    Args:
        population_id: Population whose top-fitness agent is replayed.
        seed: Passed to ``CarRacing`` as ``seed_value``.
        lookahead: Passed to ``env.get_feature_vector`` each step.
    """
    aid = await _best_agent_in_population(population_id)
    snap = await load_genotype_snapshot(aid)
    policy, n_inputs = build_policy_from_snapshot(snap)

    env = CarRacing(seed_value=seed, render_mode="human")
    # Start the try immediately after construction so the env/window is closed
    # even if reset() or the warm-up step raises.
    try:
        _obs, _info = env.reset()
        policy.reset_state()

        # One zero-action step before the control loop starts.
        env.step(np.array([0.0, 0.0, 0.0], dtype=np.float32))

        frame = 0
        while True:
            feats = np.array(env.get_feature_vector(lookahead), dtype=np.float32)
            # The policy expects exactly n_inputs features.
            x = _fit_to_input_width(feats, n_inputs)

            act = _post_process_action(policy.step(x))
            _, _reward, terminated, truncated, _ = env.step(act)

            # Explicit render only on every other frame.
            if frame % 2 == 0:
                env.render()
            frame += 1

            # Stop once the user has closed the window.
            if not pygame.display.get_init() or pygame.display.get_surface() is None:
                break

            # Rather than resetting the episode, clear the env's progress
            # bookkeeping so the replay keeps driving. NOTE(review): assumes these
            # attributes drive CarRacing's termination/truncation — confirm.
            if terminated:
                env.tile_visited_count = 0
                env.prev_reward = env.reward
            elif truncated:
                env._no_progress_steps = 0
    finally:
        env.close()
|
|
|
|
|
|
if __name__ == "__main__":
    import asyncio

    # Script entry point: replay the top agent of "car_pop" on track seed 1.
    asyncio.run(
        replay_best(
            population_id="car_pop",
            seed=1,
            lookahead=10,
        )
    )
|