# MapMorph2/mapmorph_agents/sdk_orchestrator.py
# (commit 420d49f "finish", author CarlosHenr1que)
"""SDK-based orchestrator using OpenAI Agents SDK handoffs."""
import asyncio
import logging
import time
from typing import Optional, Dict, Any
from dataclasses import dataclass, field
from agents import Agent, Runner, handoff
from models.modification import MapState, Modification
from core.modification_applicator import ModificationApplicator
from core.guardrails import GuardrailsManager, ErrorResponse
from styles import load_style, load_style_from_file, save_style_to_file
from constants.llm_config import LLM_MODEL
from mapmorph_agents.sdk_agents import create_color_agent, create_language_agent
# Configure logging
logger = logging.getLogger(__name__)
@dataclass
class MapContext:
    """Context passed to all agents and tools.

    This context carries lightweight metadata about the map state.
    The full style JSON is NOT included to avoid token bloat - tools access it
    via get_current_context_style().

    IMPORTANT: This context is serialized by the OpenAI Agents SDK, so we
    exclude the large current_style to prevent hitting token limits. The full
    style is accessible via the global _current_context variable.
    """
    # Lightweight metadata only - NOT the full style JSON
    base_theme: str = "light"  # Base map theme name
    base_language: str = "en"  # Base label language code
    modification_count: int = 0  # Number of existing modifications
    modifications: list = field(default_factory=list)  # Capture all modifications created
    # Store reference to full map_state separately (not serialized)
    # repr=False/compare=False keep the heavy object out of dataclass repr/eq
    _map_state: Optional[MapState] = field(default=None, repr=False, compare=False)
# Module-level holder for the active context so tools can reach it on demand.
_current_context: Optional[MapContext] = None


def get_current_context_style() -> Optional[Dict[str, Any]]:
    """Return the current map style from the active context, or None.

    Tools call this instead of receiving the full style JSON inside the
    serialized agent context, which keeps thousands of tokens out of every
    API call.
    """
    ctx = _current_context
    if not ctx:
        return None
    state = ctx._map_state
    if not state:
        return None
    return state.current_style
class SDKMapMorphOrchestrator:
    """SDK-based orchestrator using handoffs for agent delegation.

    This orchestrator uses the OpenAI Agents SDK's handoff mechanism to route
    user requests to specialized agents (ColorAgent, LanguageAgent) instead of
    manual LLM-based routing.
    """

    def __init__(self):
        """Initialize the SDK orchestrator with specialist agents and guardrails."""
        # Initialize guardrails manager
        self.guardrails = GuardrailsManager()
        # Create specialist agents
        self.color_agent = create_color_agent()
        self.language_agent = create_language_agent()
        logger.info("SDK orchestrator initialized with guardrails")
        # Create triage agent with handoffs.
        # NOTE: the instructions string below is the triage agent's system
        # prompt - it is runtime behavior, not documentation.
        self.triage_agent = Agent(
            name="MapMorphOrchestrator",
            model=LLM_MODEL,  # Use configured model
            instructions="""You are a routing coordinator. Your ONLY job is to hand off to specialist agents.
**CRITICAL: DO NOT respond to requests yourself. IMMEDIATELY hand off to the appropriate agent.**
**When to hand off:**
1. **Hand off to ColorAgent** if request mentions:
- Colors (blue, green, red, darker, lighter, etc.)
- Map layers (water, land, buildings, roads, earth, etc.)
- Brand names (Nike, Spotify, Google, etc.)
- Visual styling
2. **Hand off to LanguageAgent** if request mentions:
- Multiple languages (English and Spanish, bilingual, etc.)
- Language labels
- Showing two or more languages
**Your responses should ONLY be:**
- Immediate handoff to ColorAgent for color requests
- Immediate handoff to LanguageAgent for language requests
- If truly unclear, ask ONE clarifying question, then hand off
**DO NOT:**
- Explain what the agent will do
- Provide modification details
- Talk about how modifications work
- DO NOT respond with your own message - HAND OFF IMMEDIATELY!
**Remember:** The specialist agents will handle everything. Your only job is routing.
""",
            # Each handoff exposes a specialist agent as a routing target;
            # the override text is what the triage model sees when choosing.
            handoffs=[
                handoff(
                    agent=self.color_agent,
                    tool_description_override="Handle color modification requests for map layers (water, land, buildings, roads, etc.)"
                ),
                handoff(
                    agent=self.language_agent,
                    tool_description_override="Handle dual-language label requests (bilingual maps)"
                )
            ]
        )
async def process_request_async(
self,
user_message: str,
map_state: MapState
) -> tuple[str, MapState]:
"""Process user request asynchronously through SDK agents with guardrails.
Args:
user_message: User's natural language request
map_state: Current map state with modifications
Returns:
tuple: (response_message, updated_map_state)
"""
# Pre-request guardrails check
error_response = self.guardrails.pre_request_check(user_message)
if error_response:
logger.warning(f"Pre-request check failed: {error_response.error}")
error_message = f"⚠️ {error_response.error}"
if error_response.suggestions:
error_message += "\n\nSuggestions:\n" + "\n".join(
f"• {s}" for s in error_response.suggestions
)
return error_message, map_state
# Create lightweight context (excludes large current_style to save tokens)
# The full map_state is stored globally for tools to access on-demand
context = MapContext(
base_theme=map_state.base_theme,
base_language=map_state.base_language,
modification_count=len(map_state.modifications),
modifications=[],
_map_state=map_state # Store reference but won't be serialized
)
# Store context globally so tools can access it
global _current_context
_current_context = context
# Define hooks to capture tool calls
from agents.run import RunHooks
captured_modifications = []
class ModificationCapturingHooks(RunHooks):
"""Hooks to capture modification tool calls."""
async def on_handoff(self, context, from_agent, to_agent):
"""Log when handoffs occur."""
print(f"🔀 HANDOFF: {from_agent.name}{to_agent.name}")
logger.info(f"Handoff from {from_agent.name} to {to_agent.name}")
async def on_agent_start(self, context, agent):
"""Log when agents start."""
print(f"🎭 AGENT START: {agent.name}")
logger.info(f"Agent {agent.name} started")
async def on_tool_start(self, context, agent, tool):
"""Log when tools are called."""
print(f"🔧 TOOL CALLED: {tool.name}")
logger.info(f"Tool called: {tool.name}")
async def on_tool_end(self, context, agent, tool, result: str):
"""Capture tool results that are modifications."""
print(f"📦 TOOL RESULT: {tool.name} returned result")
logger.info(f"Tool {tool.name} completed")
# Check if this is a modification-creating tool
modification_tools = [
'create_color_modification',
'create_dual_language_modification',
'create_multi_language_modification'
]
if tool.name in modification_tools:
# Parse the result string as it's returned as a string representation
import ast
try:
result_dict = ast.literal_eval(result)
print(f"✅ MODIFICATION CAPTURED from {tool.name}!")
print(f" Modification: {result_dict}")
logger.info(f"✓ Captured modification from {tool.name}")
captured_modifications.append(result_dict)
except:
# Try treating it as a dict if it's already parsed
if isinstance(result, dict):
print(f"✅ MODIFICATION CAPTURED from {tool.name}!")
logger.info(f"✓ Captured modification from {tool.name}")
captured_modifications.append(result)
else:
logger.error(f"Failed to parse result from {tool.name}: {result}")
else:
logger.debug(f"Tool {tool.name} result not captured (not a modification tool)")
hooks = ModificationCapturingHooks()
# Retry configuration for rate limits
max_retries = 3
base_delay = 2 # Start with 2 seconds
max_delay = 60 # Cap at 60 seconds
try:
print(f"\n{'='*70}")
print(f"🚀 OPENAI REQUEST STARTING")
print(f"{'='*70}")
print(f"User Message: {user_message}")
print(f"Starting Agent: {self.triage_agent.name}")
print(f"{'='*70}\n")
logger.info(f"Processing request: {user_message[:100]}...")
# Run agent with SDK with hooks to capture modifications
# Add retry logic for rate limits
result = None
last_error = None
for attempt in range(max_retries):
try:
result = await Runner.run(
starting_agent=self.triage_agent,
input=user_message,
context=context,
hooks=hooks, # Pass hooks to capture tool results
)
break # Success, exit retry loop
except Exception as e:
last_error = e
error_str = str(e).lower()
# Check if it's a rate limit error
is_rate_limit = (
"rate_limit" in error_str or
"429" in error_str or
"rate limit" in error_str
)
if is_rate_limit and attempt < max_retries - 1:
# Calculate exponential backoff delay
delay = min(base_delay * (2 ** attempt), max_delay)
print(f"\n⚠️ Rate limit hit (attempt {attempt + 1}/{max_retries})")
print(f" Waiting {delay} seconds before retry...")
logger.warning(f"Rate limit error, retrying in {delay}s (attempt {attempt + 1}/{max_retries})")
await asyncio.sleep(delay)
else:
# Not a rate limit error, or out of retries
raise
# Check if we exhausted retries without success
if result is None:
if last_error:
raise last_error
else:
raise RuntimeError("Failed to get result from OpenAI API after retries")
print(f"\n{'='*70}")
print(f"📥 OPENAI RESPONSE RECEIVED")
print(f"{'='*70}")
print(f"Result Type: {type(result)}")
print(f"Has Final Output: {result.final_output is not None}")
if result.final_output:
print(f"Final Output Length: {len(str(result.final_output))} characters")
print(f"Final Output Preview: {str(result.final_output)[:500]}")
else:
print(f"Final Output: None")
# Check new_items for tool calls
if hasattr(result, 'new_items') and result.new_items:
print(f"\n🔍 NEW ITEMS DEBUG (Conversation History):")
print(f"Number of new items: {len(result.new_items)}")
for idx, item in enumerate(result.new_items):
print(f"\n Item {idx + 1}: {type(item).__name__}")
if hasattr(item, 'type'):
print(f" Type: {item.type}")
# For ToolCallItem - print all attributes
if item.type == 'tool_call_item':
print(f" 🔧 TOOL CALL DETECTED!")
for attr_name in dir(item):
if not attr_name.startswith('_'):
try:
attr_value = getattr(item, attr_name)
if not callable(attr_value):
print(f" {attr_name}: {attr_value}")
except:
pass
# For ToolCallOutputItem - print all attributes
if item.type == 'tool_call_output_item':
print(f" 📦 TOOL OUTPUT DETECTED!")
for attr_name in dir(item):
if not attr_name.startswith('_'):
try:
attr_value = getattr(item, attr_name)
if not callable(attr_value) and attr_name not in ['model_config', 'model_fields', 'model_fields_set']:
print(f" {attr_name}: {str(attr_value)[:200]}")
except:
pass
# Try to access more result details if available
if hasattr(result, '__dict__'):
print(f"\nResult Attributes:")
for key, value in result.__dict__.items():
if key not in ['final_output', 'raw_responses']:
print(f" - {key}: {type(value).__name__}")
print(f"{'='*70}\n")
# Extract response
response_message = str(result.final_output) if result.final_output else "Request processed."
print(f"\n{'='*70}")
print(f"MODIFICATION SUMMARY")
print(f"{'='*70}")
print(f"Modifications captured: {len(captured_modifications)}")
# Apply captured modifications to map state
if captured_modifications:
print(f"✅ Applying {len(captured_modifications)} modification(s) to map...")
logger.info(f"Applying {len(captured_modifications)} modifications")
# Add modifications to map state
for i, mod_dict in enumerate(captured_modifications):
print(f" [{i+1}] Type: {mod_dict.get('type')}")
map_state.add_modification(mod_dict)
# Load existing style.json if it exists, otherwise load fresh
print(f"🎨 Re-rendering map with modifications...")
existing_style = load_style_from_file()
if existing_style:
# Apply modifications to existing style.json
modified_style = ModificationApplicator.apply_all(
existing_style,
map_state.modifications
)
else:
# No existing style, load fresh and apply modifications
fresh_style = load_style(map_state.base_theme, map_state.base_language)
modified_style = ModificationApplicator.apply_all(
fresh_style,
map_state.modifications
)
# Save modified style to style.json
save_style_to_file(modified_style)
map_state.current_style = modified_style
print(f"✅ Map updated! Total modifications: {len(map_state.modifications)}")
logger.info(f"Map state updated with {len(map_state.modifications)} total modifications")
else:
print(f"⚠️ NO MODIFICATIONS CAPTURED!")
print(f" This means the agent explained but didn't call modification tools.")
print(f" The map will NOT change.")
print(f"{'='*70}\n")
# Post-request guardrails check
post_error = self.guardrails.post_request_check(
response_message,
captured_modifications
)
if post_error:
logger.error(f"Post-request check failed: {post_error.error}")
error_message = f"⚠️ Modification validation failed: {post_error.error}"
if post_error.suggestions:
error_message += "\n\nSuggestions:\n" + "\n".join(
f"• {s}" for s in post_error.suggestions
)
return error_message, map_state
logger.info("Request processed successfully")
# Return updated map state
return response_message, map_state
except Exception as e:
print(f"\n{'='*70}")
print(f"❌ OPENAI ERROR")
print(f"{'='*70}")
print(f"Error Type: {type(e).__name__}")
print(f"Error Message: {str(e)}")
print(f"{'='*70}\n")
error_msg = f"Error processing request: {str(e)}"
logger.error(f"SDK Orchestrator error: {error_msg}", exc_info=True)
# Check if it's a rate limit error and provide helpful message
error_str = str(e).lower()
is_rate_limit = (
"rate_limit" in error_str or
"429" in error_str or
"rate limit" in error_str
)
if is_rate_limit:
print(f"⚠️ RATE LIMIT DETECTED")
print(f"This is not a code issue - your OpenAI account has hit its rate limit.")
print(f"The AI never processed your request.\n")
# Extract rate limit details if available
rate_limit_info = ""
if "tpm" in error_str or "tokens per min" in error_str:
rate_limit_info = "\n\n**Note:** You're hitting the TPM (tokens per minute) limit. "
rate_limit_info += "The code has been optimized to reduce token usage, but you may still need to:\n"
rate_limit_info += "- Wait longer between requests\n"
rate_limit_info += "- Add a payment method to increase limits significantly\n"
return (
"⚠️ **OpenAI API Rate Limit Exceeded**\n\n"
"Your OpenAI account has hit its rate limit. This means the AI agents "
"cannot process requests right now.\n\n"
"**Solutions:**\n"
"1. Wait 20-60 seconds and try again (the system will auto-retry)\n"
"2. Add a payment method to your OpenAI account at "
"https://platform.openai.com/account/billing to increase limits\n"
"3. Check your current limits at https://platform.openai.com/account/rate-limits\n"
f"{rate_limit_info}\n"
f"_Error details: {str(e)[:200]}_"
), map_state
return f"⚠️ {error_msg}", map_state
def process_request(
self,
user_message: str,
map_state: MapState
) -> tuple[str, MapState]:
"""Synchronous wrapper for async process_request.
Args:
user_message: User's natural language request
map_state: Current map state
Returns:
tuple: (response_message, updated_map_state)
"""
# Get or create event loop
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Run async function
return loop.run_until_complete(
self.process_request_async(user_message, map_state)
)
def get_guardrails_stats(self) -> Dict[str, int]:
"""Get guardrails usage statistics.
Returns:
Dictionary with request and error counts
"""
return self.guardrails.get_statistics()