# NurseSim-Triage-Demo / train_semantic_agent.py
# Author: NurseCitizenDeveloper — commit 3d38e23 ("chore: Update workspace before push")
"""
Train a PPO Agent on the Semantic Triage Environment.
This script trains an agent that learns from NurseEmbed-encoded clinical observations.
"""
import os
import numpy as np
from datetime import datetime
# Stable Baselines 3
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
from stable_baselines3.common.vec_env import DummyVecEnv
# Our environment
from nursesim_rl import TriageEnv, NurseEmbedWrapper
class PrintProgressCallback(BaseCallback):
    """Callback that periodically prints mean episode reward and length."""

    def __init__(self, print_freq: int = 1000, verbose: int = 0):
        super().__init__(verbose)
        self.print_freq = print_freq

    def _on_step(self) -> bool:
        # Only report every `print_freq` environment steps.
        if self.n_calls % self.print_freq != 0:
            return True
        episode_infos = self.model.ep_info_buffer
        if len(episode_infos) > 0:
            avg_reward = np.mean([info['r'] for info in episode_infos])
            avg_length = np.mean([info['l'] for info in episode_infos])
            print(f"Step {self.n_calls}: Mean Reward = {avg_reward:.2f}, Mean Length = {avg_length:.1f}")
        return True
def make_semantic_env():
    """Build a TriageEnv and wrap it so observations are NurseEmbed-encoded."""
    return NurseEmbedWrapper(
        TriageEnv(max_steps=50, max_patients=20),
        use_vitals=True,
    )
def train_semantic_agent(
    total_timesteps: int = 50000,
    save_path: str = "models/semantic_ppo",
    log_dir: str = "logs/semantic_ppo",
    n_final_eval_episodes: int = 10,
):
    """Train a PPO agent on the semantic triage environment.

    Args:
        total_timesteps: Total environment steps to train for.
        save_path: Directory where best/final model checkpoints are written.
        log_dir: Directory for tensorboard and evaluation logs.
        n_final_eval_episodes: Episode count for the post-training evaluation.

    Returns:
        The trained stable-baselines3 PPO model.
    """
    print("=" * 60)
    print("SEMANTIC RL TRAINING")
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("=" * 60)

    # Create output directories up front so saving cannot fail later.
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Vectorized train / eval environments (one env each).
    print("\n[1] Creating Semantic Environment...")
    env = DummyVecEnv([make_semantic_env])
    print(f" Observation space: {env.observation_space}")
    print(f" Action space: {env.action_space}")
    eval_env = DummyVecEnv([make_semantic_env])

    print("\n[2] Initializing PPO Agent...")
    model = _build_ppo(env, log_dir)
    print(f" Policy architecture: {model.policy}")

    # Progress printing + periodic evaluation (EvalCallback saves the best model).
    progress_callback = PrintProgressCallback(print_freq=2000)
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=save_path,
        log_path=log_dir,
        eval_freq=5000,
        n_eval_episodes=5,
        deterministic=True,
        verbose=0,
    )

    print(f"\n[3] Training for {total_timesteps:,} timesteps...")
    print("-" * 60)
    model.learn(
        total_timesteps=total_timesteps,
        callback=[progress_callback, eval_callback],
        progress_bar=True,
    )

    # Save the final (not necessarily best) model alongside EvalCallback's best.
    final_path = os.path.join(save_path, "semantic_ppo_final")
    model.save(final_path)
    print(f"\n[4] Model saved to: {final_path}")

    print(f"\n[5] Final Evaluation ({n_final_eval_episodes} episodes)...")
    rewards = _run_final_evaluation(model, eval_env, n_final_eval_episodes)
    print(f" Mean Reward: {np.mean(rewards):.2f} +/- {np.std(rewards):.2f}")
    print(f" Best Episode: {np.max(rewards):.2f}")
    print(f" Worst Episode: {np.min(rewards):.2f}")

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print("=" * 60)
    return model


def _build_ppo(env, log_dir: str) -> PPO:
    """Construct the PPO learner with this project's standard hyperparameters."""
    return PPO(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.01,
        verbose=0,
        tensorboard_log=log_dir,
    )


def _run_final_evaluation(model, eval_env, n_episodes: int) -> list:
    """Roll out `n_episodes` deterministic episodes; return total reward per episode."""
    rewards = []
    for _ in range(n_episodes):
        obs = eval_env.reset()
        episode_reward = 0.0
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, step_rewards, dones, _ = eval_env.step(action)
            episode_reward += float(step_rewards[0])
            # VecEnv.step returns array-valued flags; index env 0 explicitly
            # instead of relying on the truthiness of a 1-element numpy array.
            done = bool(dones[0])
        rewards.append(episode_reward)
    return rewards
if __name__ == "__main__":
    # Shorter run than the 50k-step default so the demo finishes quickly.
    train_semantic_agent(total_timesteps=20000)