last working state.
sac/td3_main.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import os

import numpy as np
import torch as th
from stable_baselines3 import TD3
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
from stable_baselines3.common.monitor import Monitor

from car_racing_env import CarRacing  # <-- path to your env file
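
# CarRacing is assumed to be a custom Gymnasium-style env: reset() returns
# (obs, info), step() returns (obs, reward, terminated, truncated, info), and
# the observation is a flat feature vector, which is why MlpPolicy fits here.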

if __name__ == "__main__":
    # Optional: seed NumPy and PyTorch for reproducibility
    np.random.seed(0)
    th.manual_seed(0)

    run_name = "td3_run_2"  # or e.g. datetime.now().strftime("%Y%m%d_%H%M")

    tensorboard_log = f"./tb_{run_name}/"
    best_model_path = f"./{run_name}_best/"
    eval_log_path = f"./{run_name}_eval/"
    model_save_path = f"./{run_name}_models/"

    os.makedirs(model_save_path, exist_ok=True)

    # Monitor wraps the env so episode rewards and lengths are logged
    train_env = Monitor(CarRacing(seed_value=0, render_mode=None))
    model = TD3(
        policy="MlpPolicy",
        env=train_env,
        verbose=1,
        tensorboard_log=tensorboard_log,
        learning_starts=20_000,  # random actions, no gradient updates, for the first 20k steps
    )
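    # All other TD3 hyperparameters (learning rate, buffer size, batch size,
    # policy_delay, target_policy_noise, ...) are left at the
    # stable-baselines3 defaults. With action_noise left at None, exploration
    # comes only from the initial random steps; consider passing a
    # NormalActionNoise from stable_baselines3.common.noise if learning stalls.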

    eval_env = Monitor(CarRacing(seed_value=1, render_mode=None))
    # Stop early if the best eval reward has not improved for 20 consecutive
    # evaluations (but run at least 5 evaluations first)
    stop_cb = StopTrainingOnNoModelImprovement(
        max_no_improvement_evals=20, min_evals=5, verbose=1
    )
    eval_cb = EvalCallback(
        eval_env,
        best_model_save_path=best_model_path,
        log_path=eval_log_path,
        eval_freq=5_000,  # evaluate every 5,000 environment steps
        deterministic=True,
        render=False,
        callback_after_eval=stop_cb,
    )
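    # EvalCallback tracks the mean eval reward and keeps the best-scoring
    # checkpoint as f"{best_model_path}best_model.zip".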

    model.learn(total_timesteps=400_000, callback=eval_cb, progress_bar=True)
    model.save(f"{model_save_path}/td3_carracing_features")
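    # Note: this saves the *final* policy; the best checkpoint found during
    # training can be restored with TD3.load(f"{best_model_path}best_model").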

    # Short test run with rendering (optional)
    test_env = CarRacing(seed_value=0, render_mode="human")
    obs, info = test_env.reset()
    done = False
    trunc = False
    while not (done or trunc):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, trunc, info = test_env.step(action)
        test_env.render()
    test_env.close()