| | """ |
| | Train a PPO Agent on the Semantic Triage Environment. |
| | |
| | This script trains an agent that learns from NurseEmbed-encoded clinical observations. |
| | """ |
| |
|
| | import os |
| | import numpy as np |
| | from datetime import datetime |
| |
|
| | |
| | from stable_baselines3 import PPO |
| | from stable_baselines3.common.env_util import make_vec_env |
| | from stable_baselines3.common.callbacks import EvalCallback, BaseCallback |
| | from stable_baselines3.common.vec_env import DummyVecEnv |
| |
|
| | |
| | from nursesim_rl import TriageEnv, NurseEmbedWrapper |
| |
|
| |
|
class PrintProgressCallback(BaseCallback):
    """Callback that periodically prints mean episode reward and length."""

    def __init__(self, print_freq: int = 1000, verbose: int = 0):
        super().__init__(verbose)
        self.print_freq = print_freq

    def _on_step(self) -> bool:
        # Only report every `print_freq` environment steps.
        if self.n_calls % self.print_freq != 0:
            return True
        buffer = self.model.ep_info_buffer
        if buffer:
            mean_reward = np.mean([info['r'] for info in buffer])
            mean_length = np.mean([info['l'] for info in buffer])
            print(f"Step {self.n_calls}: Mean Reward = {mean_reward:.2f}, Mean Length = {mean_length:.1f}")
        return True
| |
|
| |
|
def make_semantic_env():
    """Build a TriageEnv wrapped with NurseEmbed semantic observations."""
    return NurseEmbedWrapper(
        TriageEnv(max_steps=50, max_patients=20),
        use_vitals=True,
    )
| |
|
| |
|
def _evaluate_policy(model, vec_env, n_episodes: int):
    """Roll out `n_episodes` deterministic episodes on a single-env VecEnv.

    Returns a list of per-episode cumulative rewards.
    """
    episode_rewards = []
    for _ in range(n_episodes):
        obs = vec_env.reset()
        episode_reward = 0.0
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            # VecEnv returns batched arrays; index 0 selects our only env.
            obs, step_rewards, dones, _infos = vec_env.step(action)
            episode_reward += step_rewards[0]
            # Extract an explicit scalar instead of relying on the truthiness
            # of a length-1 NumPy array for loop control.
            done = bool(dones[0])
        episode_rewards.append(episode_reward)
    return episode_rewards


def train_semantic_agent(
    total_timesteps: int = 50000,
    save_path: str = "models/semantic_ppo",
    log_dir: str = "logs/semantic_ppo",
):
    """Train a PPO agent on the semantic triage environment.

    Parameters
    ----------
    total_timesteps : int
        Number of environment steps to train for.
    save_path : str
        Directory where the best and final model checkpoints are written.
    log_dir : str
        Directory for TensorBoard and evaluation logs.

    Returns
    -------
    The trained PPO model.
    """

    print("=" * 60)
    print("SEMANTIC RL TRAINING")
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("=" * 60)

    # Make sure output directories exist before SB3 tries to write to them.
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    print("\n[1] Creating Semantic Environment...")
    env = DummyVecEnv([make_semantic_env])

    print(f" Observation space: {env.observation_space}")
    print(f" Action space: {env.action_space}")

    # Separate env for periodic evaluation so it does not disturb training state.
    eval_env = DummyVecEnv([make_semantic_env])

    print("\n[2] Initializing PPO Agent...")
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.01,  # mild entropy bonus to encourage exploration
        verbose=0,
        tensorboard_log=log_dir,
    )

    print(f" Policy architecture: {model.policy}")

    progress_callback = PrintProgressCallback(print_freq=2000)
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=save_path,
        log_path=log_dir,
        eval_freq=5000,
        n_eval_episodes=5,
        deterministic=True,
        verbose=0,
    )

    print(f"\n[3] Training for {total_timesteps:,} timesteps...")
    print("-" * 60)

    model.learn(
        total_timesteps=total_timesteps,
        callback=[progress_callback, eval_callback],
        progress_bar=True,
    )

    final_path = os.path.join(save_path, "semantic_ppo_final")
    model.save(final_path)
    print(f"\n[4] Model saved to: {final_path}")

    print("\n[5] Final Evaluation (10 episodes)...")
    rewards = _evaluate_policy(model, eval_env, n_episodes=10)

    print(f" Mean Reward: {np.mean(rewards):.2f} +/- {np.std(rewards):.2f}")
    print(f" Best Episode: {np.max(rewards):.2f}")
    print(f" Worst Episode: {np.min(rewards):.2f}")

    # The eval env is no longer needed; release its resources. The training
    # env is left open because the returned model still references it.
    eval_env.close()

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print("=" * 60)

    return model
| |
|
| |
|
| | if __name__ == "__main__": |
| | |
| | train_semantic_agent(total_timesteps=20000) |
| |
|