Files
neuroevolution/sac/sac_main.py
2025-12-13 14:12:35 +01:00

53 lines
1.1 KiB
Python

import gymnasium as gym
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecMonitor
from car_racing_env import CarRacing
# Random seed used both for the initial reset of the training env and for the SAC agent.
SEED: int = 5
def make_env():
    """Create a single training CarRacing environment without on-screen rendering.

    The environment is reset once with the module-level ``SEED`` before being
    returned. It is returned unwrapped on purpose: ``make_vec_env`` wraps
    whatever this factory returns in a stable_baselines3 ``Monitor`` itself,
    so adding a ``Monitor`` here only stacks a redundant second one.

    Returns:
        The freshly created and reset ``CarRacing`` environment.
    """
    env = CarRacing(render_mode=None)
    env.reset(seed=SEED)
    return env
# Vectorised training env with a single worker. make_vec_env already wraps
# each env returned by make_env in a stable_baselines3 Monitor, which is what
# feeds episode reward/length statistics to the SB3 logger. An extra
# VecMonitor on top is redundant (SB3 warns that it overwrites the Monitor
# statistics), so none is added here.
venv = make_vec_env(make_env, n_envs=1)

# Soft Actor-Critic agent with an MLP policy. SEED is passed so SB3 seeds its
# random number generators and the vectorised env.
model = SAC(
    "MlpPolicy",
    venv,
    seed=SEED,
    learning_rate=3e-4,
    buffer_size=300_000,
    batch_size=256,
    tau=0.01,
    gamma=0.99,
    # One gradient update after every single environment step.
    train_freq=(1, "step"),
    gradient_steps=1,
    # The entropy coefficient is learned rather than fixed.
    ent_coef="auto",
    # NOTE(review): -3 matches SB3's "auto" target entropy (-dim(action space))
    # only if this CarRacing's action space is 3-dimensional; confirm.
    target_entropy=-3,
    verbose=1,
    device="auto",
)

model.learn(total_timesteps=500_000)
# SB3 appends ".zip" when the path has no extension.
model.save("sac_carracing_features")
# Evaluation (with rendering): run the trained policy greedily in an endless
# loop, starting a new episode whenever one terminates or is truncated.
# Stop with Ctrl+C; the environment is closed on the way out.
test_env = CarRacing(render_mode="human")
try:
    obs, _ = test_env.reset()
    while True:
        action, _ = model.predict(obs, deterministic=True)
        obs, _reward, terminated, truncated, _ = test_env.step(action)
        # NOTE(review): gymnasium's built-in CarRacing already renders inside
        # step() when render_mode="human"; if this custom env does the same,
        # this explicit call draws every frame twice. Confirm.
        test_env.render()
        if terminated or truncated:
            obs, _ = test_env.reset()
except KeyboardInterrupt:
    # The loop has no natural end; Ctrl+C is the intended way to leave it.
    pass
finally:
    # Release the render window and any other env resources.
    test_env.close()