initial data commit.
This commit is contained in:
52
sac_main.py
Normal file
52
sac_main.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import VecMonitor
|
||||
|
||||
from car_racing_env import CarRacing
|
||||
|
||||
SEED = 5
|
||||
|
||||
|
||||
def make_env():
|
||||
env = CarRacing(render_mode=None)
|
||||
env.reset(seed=SEED)
|
||||
return Monitor(env)
|
||||
|
||||
|
||||
venv = make_vec_env(make_env, n_envs=1)
|
||||
venv = VecMonitor(venv)
|
||||
|
||||
model = SAC(
|
||||
"MlpPolicy",
|
||||
venv,
|
||||
seed=SEED,
|
||||
learning_rate=3e-4,
|
||||
buffer_size=300_000,
|
||||
batch_size=256,
|
||||
tau=0.01,
|
||||
gamma=0.99,
|
||||
train_freq=(1, "step"),
|
||||
gradient_steps=1,
|
||||
ent_coef="auto",
|
||||
target_entropy=-3,
|
||||
verbose=1,
|
||||
device="auto",
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=500_000)
|
||||
model.save("sac_carracing_features")
|
||||
|
||||
# Testen (mit Rendern)
|
||||
test_env = CarRacing(render_mode="human")
|
||||
obs, _ = test_env.reset()
|
||||
done = False
|
||||
trunc = False
|
||||
while True:
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
obs, r, done, trunc, _ = test_env.step(action)
|
||||
test_env.render()
|
||||
if done or trunc:
|
||||
obs, _ = test_env.reset()
|
||||
Reference in New Issue
Block a user