# Listing metadata from the code viewer: 62 lines, 1.9 KiB, Python.
import os
|
|
import numpy as np
|
|
from stable_baselines3 import TD3
|
|
from stable_baselines3.common.monitor import Monitor
|
|
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
|
|
from stable_baselines3.common.env_util import make_vec_env
|
|
import torch as th
|
|
|
|
from car_racing_env import CarRacing # <-- Pfad zu deiner Datei
|
|
|
|
|
|
def main():
    """Train a TD3 agent on ``CarRacing``, save it, then replay one episode.

    Steps, in order:

    1. Seed NumPy and torch, and pass the same seed to the TD3 model.
    2. Create the per-run output directory for saved models.
    3. Train TD3 with an ``EvalCallback`` that evaluates every 5,000 steps,
       keeps the best model, and stops training early after 20 evaluations
       without improvement (not before 5 evaluations have run).
    4. Save the final model under ``<run_name>_models/``.
    5. Run one deterministic episode with ``render_mode="human"``.

    Every environment is closed in a ``finally`` clause so that an exception
    or Ctrl-C during training or playback does not leak it.
    """
    # Optional: reproducibility.
    seed = 0
    np.random.seed(seed)
    th.manual_seed(seed)

    run_name = "td3_run_2"  # or datetime.now().strftime("%Y%m%d_%H%M")

    tensorboard_log = f"./tb_{run_name}/"
    best_model_path = f"./{run_name}_best/"
    eval_log_path = f"./{run_name}_eval/"
    model_save_path = f"./{run_name}_models/"

    os.makedirs(model_save_path, exist_ok=True)

    train_env = Monitor(CarRacing(seed_value=0, render_mode=None))
    try:
        model = TD3(
            policy="MlpPolicy",
            env=train_env,
            verbose=1,
            tensorboard_log=tensorboard_log,
            learning_starts=20_000,
            # Seeding only the global NumPy/torch RNGs is not sufficient for
            # SB3: its reproducibility guide requires the seed to be passed
            # at model creation.
            seed=seed,
        )

        eval_env = Monitor(CarRacing(seed_value=1, render_mode=None))
        try:
            stop_cb = StopTrainingOnNoModelImprovement(
                max_no_improvement_evals=20, min_evals=5, verbose=1
            )
            eval_cb = EvalCallback(
                eval_env,
                best_model_save_path=best_model_path,
                log_path=eval_log_path,
                eval_freq=5_000,
                deterministic=True,
                render=False,
                callback_after_eval=stop_cb,
            )

            model.learn(total_timesteps=400_000, callback=eval_cb, progress_bar=True)
        finally:
            eval_env.close()

        # os.path.join avoids the "//" produced by appending "/name" to a
        # directory string that already ends with a slash.
        model.save(os.path.join(model_save_path, "td3_carracing_features"))
    finally:
        train_env.close()

    # Short test run with rendering (optional).
    test_env = CarRacing(seed_value=0, render_mode="human")
    try:
        obs, info = test_env.reset()
        done = False
        trunc = False
        while not (done or trunc):
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, trunc, info = test_env.step(action)
            test_env.render()
    finally:
        # Close the env even if the episode is interrupted.
        test_env.close()


if __name__ == "__main__":
    main()