53 lines
1.1 KiB
Python
53 lines
1.1 KiB
Python
import gymnasium as gym
|
|
import numpy as np
|
|
from stable_baselines3 import SAC
|
|
from stable_baselines3.common.env_util import make_vec_env
|
|
from stable_baselines3.common.monitor import Monitor
|
|
from stable_baselines3.common.vec_env import VecMonitor
|
|
|
|
from car_racing_env import CarRacing
|
|
|
|
# Seed shared by environment creation and the SAC agent.
SEED = 5


def make_env():
    """Create one training instance of the custom CarRacing environment.

    The environment is built without rendering and reset once with ``SEED``.
    It is returned unwrapped on purpose: ``make_vec_env`` already wraps every
    env it builds in a ``Monitor``, so wrapping it here as well would record
    each episode twice.

    Returns:
        A freshly reset ``CarRacing`` instance.
    """
    env = CarRacing(render_mode=None)
    # Seed the environment's RNG once up front.
    env.reset(seed=SEED)
    return env
|
|
|
|
|
|
# Build a single-env vectorized environment for training. make_vec_env
# already wraps each env in a Monitor (episode return/length bookkeeping), so
# no additional VecMonitor is layered on top; doing so would only duplicate
# those statistics.
venv = make_vec_env(make_env, n_envs=1)

model = SAC(
    "MlpPolicy",
    venv,
    seed=SEED,
    learning_rate=3e-4,
    buffer_size=300_000,
    batch_size=256,
    tau=0.01,
    gamma=0.99,
    train_freq=(1, "step"),
    gradient_steps=1,
    ent_coef="auto",
    # NOTE(review): -3 presumably equals -dim(action_space) for a 3-D action
    # space (e.g. steer/gas/brake), which would match "auto" - confirm against
    # car_racing_env.CarRacing before changing.
    target_entropy=-3,
    verbose=1,
    device="auto",
)

model.learn(total_timesteps=500_000)
model.save("sac_carracing_features")

# Training is done; the evaluation below only calls model.predict, which
# does not use the training environment, so release it now.
venv.close()
|
|
|
|
# Test (with rendering): run the trained policy on a rendering environment
# until the user stops it with Ctrl+C, resetting after every episode.
test_env = CarRacing(render_mode="human")
try:
    obs, _ = test_env.reset()
    while True:
        # deterministic=True: ask the policy for its deterministic action
        # rather than a sampled one.
        action, _ = model.predict(obs, deterministic=True)
        obs, _, terminated, truncated, _ = test_env.step(action)
        # NOTE(review): with render_mode="human", Gymnasium environments
        # usually render inside step(); if car_racing_env.CarRacing does so
        # too, this explicit call is redundant - confirm before removing.
        test_env.render()
        if terminated or truncated:
            obs, _ = test_env.reset()
except KeyboardInterrupt:
    # Ctrl+C is the only way out of the endless evaluation loop; treat it as
    # a normal exit instead of a crash.
    pass
finally:
    # Always release the environment's resources (e.g. its render window).
    test_env.close()
|