Files
neuroevolution/sac/td3_main.py
2025-12-13 14:12:35 +01:00

62 lines
1.9 KiB
Python

import os
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
from stable_baselines3.common.env_util import make_vec_env
import torch as th
from car_racing_env import CarRacing # <-- Pfad zu deiner Datei
if __name__ == "__main__":
    # Train a TD3 agent on the CarRacing environment with periodic evaluation
    # and early stopping, save the final model, then replay one rendered
    # episode with the trained policy.

    # Optional: reproducibility (seeds NumPy's and PyTorch's global RNGs).
    np.random.seed(0)
    th.manual_seed(0)

    # All output directories are derived from the run name.
    run_name = "td3_run_2"  # or datetime.now().strftime("%Y%m%d_%H%M")
    tensorboard_log = f"./tb_{run_name}/"
    best_model_path = f"./{run_name}_best/"
    eval_log_path = f"./{run_name}_eval/"
    model_save_path = f"./{run_name}_models/"
    os.makedirs(model_save_path, exist_ok=True)

    train_env = Monitor(CarRacing(seed_value=0, render_mode=None))
    eval_env = None
    try:
        model = TD3(
            policy="MlpPolicy",
            env=train_env,
            verbose=1,
            tensorboard_log=tensorboard_log,
            # Number of environment steps collected before updates begin.
            learning_starts=20_000,
        )

        # Separate, differently seeded environment used only for evaluation.
        eval_env = Monitor(CarRacing(seed_value=1, render_mode=None))

        # Stop training once 20 consecutive evaluations yield no new best
        # mean reward, but never before 5 evaluations have been run.
        stop_cb = StopTrainingOnNoModelImprovement(
            max_no_improvement_evals=20, min_evals=5, verbose=1
        )
        eval_cb = EvalCallback(
            eval_env,
            best_model_save_path=best_model_path,
            log_path=eval_log_path,
            eval_freq=5_000,
            deterministic=True,
            render=False,
            callback_after_eval=stop_cb,
        )

        model.learn(total_timesteps=400_000, callback=eval_cb, progress_bar=True)

        # model_save_path already ends with a separator; os.path.join avoids
        # producing a doubled "//" in the resulting path.
        model.save(os.path.join(model_save_path, "td3_carracing_features"))
    finally:
        # Release the environments even if training fails or is interrupted.
        if eval_env is not None:
            eval_env.close()
        train_env.close()

    # Short test run with rendering (optional): play one episode using the
    # final trained policy (not the best checkpoint saved by EvalCallback).
    test_env = CarRacing(seed_value=0, render_mode="human")
    try:
        obs, info = test_env.reset()
        done = False
        trunc = False
        while not (done or trunc):
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, trunc, info = test_env.step(action)
            test_env.render()
    finally:
        # Always close the window/resources, even if the episode loop raises.
        test_env.close()