Spaces:
Sleeping
Sleeping
| """SDK-based orchestrator using OpenAI Agents SDK handoffs.""" | |
| import asyncio | |
| import logging | |
| import time | |
| from typing import Optional, Dict, Any | |
| from dataclasses import dataclass, field | |
| from agents import Agent, Runner, handoff | |
| from models.modification import MapState, Modification | |
| from core.modification_applicator import ModificationApplicator | |
| from core.guardrails import GuardrailsManager, ErrorResponse | |
| from styles import load_style, load_style_from_file, save_style_to_file | |
| from constants.llm_config import LLM_MODEL | |
| from mapmorph_agents.sdk_agents import create_color_agent, create_language_agent | |
| # Configure logging | |
| logger = logging.getLogger(__name__) | |
@dataclass
class MapContext:
    """Lightweight per-request context handed to the agents and their tools.

    Only small metadata fields are declared here. The full style JSON is
    deliberately not one of them; tools read it on demand through
    get_current_context_style(), which goes through the module-level
    ``_current_context`` reference.

    The ``@dataclass`` decorator is required: the fields below rely on
    ``field(...)`` defaults and the orchestrator constructs this class with
    keyword arguments.
    """

    # Lightweight metadata only - NOT the full style JSON
    base_theme: str = "light"
    base_language: str = "en"
    # Number of modifications that already existed when the request started
    modification_count: int = 0
    # Modifications created while handling the request (fresh list per instance)
    modifications: list = field(default_factory=list)
    # Reference to the full map state; hidden from repr() and from equality
    # comparison so the large style payload never leaks into either.
    _map_state: Optional["MapState"] = field(default=None, repr=False, compare=False)
# Module-level holder for the context of the request being processed.
# process_request_async() assigns it before running the agents, and
# get_current_context_style() reads it so tools can reach the full style.
_current_context: Optional[MapContext] = None
def get_current_context_style() -> Optional[Dict[str, Any]]:
    """Return the style dict of the map state attached to the active context.

    Tools call this to read the current map style on demand instead of
    receiving it as part of the context object.

    Returns:
        The ``current_style`` of the active context's map state, or None when
        no context is set or the context carries no map state.
    """
    active = _current_context
    if not active:
        return None

    state = active._map_state
    if not state:
        return None

    return state.current_style
class SDKMapMorphOrchestrator:
    """SDK-based orchestrator using handoffs for agent delegation.

    A triage agent routes each user request to a specialist agent
    (ColorAgent or LanguageAgent) through the OpenAI Agents SDK handoff
    mechanism. Run hooks capture the results of modification-creating tools,
    and the captured modifications are then applied to the map style.
    """

    # Tools whose return value is a modification payload to capture.
    _MODIFICATION_TOOLS = frozenset({
        'create_color_modification',
        'create_dual_language_modification',
        'create_multi_language_modification',
    })

    # Retry policy for rate-limit errors (exponential backoff, in seconds).
    _MAX_RETRIES = 3
    _BASE_DELAY = 2
    _MAX_DELAY = 60

    def __init__(self):
        """Initialize the SDK orchestrator with specialist agents and guardrails."""
        # Initialize guardrails manager
        self.guardrails = GuardrailsManager()
        # Create specialist agents
        self.color_agent = create_color_agent()
        self.language_agent = create_language_agent()
        logger.info("SDK orchestrator initialized with guardrails")
        # Create triage agent with handoffs
        self.triage_agent = Agent(
            name="MapMorphOrchestrator",
            model=LLM_MODEL,  # Use configured model
            instructions="""You are a routing coordinator. Your ONLY job is to hand off to specialist agents.
**CRITICAL: DO NOT respond to requests yourself. IMMEDIATELY hand off to the appropriate agent.**
**When to hand off:**
1. **Hand off to ColorAgent** if request mentions:
- Colors (blue, green, red, darker, lighter, etc.)
- Map layers (water, land, buildings, roads, earth, etc.)
- Brand names (Nike, Spotify, Google, etc.)
- Visual styling
2. **Hand off to LanguageAgent** if request mentions:
- Multiple languages (English and Spanish, bilingual, etc.)
- Language labels
- Showing two or more languages
**Your responses should ONLY be:**
- Immediate handoff to ColorAgent for color requests
- Immediate handoff to LanguageAgent for language requests
- If truly unclear, ask ONE clarifying question, then hand off
**DO NOT:**
- Explain what the agent will do
- Provide modification details
- Talk about how modifications work
- DO NOT respond with your own message - HAND OFF IMMEDIATELY!
**Remember:** The specialist agents will handle everything. Your only job is routing.
""",
            handoffs=[
                handoff(
                    agent=self.color_agent,
                    tool_description_override="Handle color modification requests for map layers (water, land, buildings, roads, etc.)",
                ),
                handoff(
                    agent=self.language_agent,
                    tool_description_override="Handle dual-language label requests (bilingual maps)",
                ),
            ],
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_rate_limit_error(error) -> bool:
        """Return True when the error text looks like an API rate-limit failure."""
        error_str = str(error).lower()
        return (
            "rate_limit" in error_str
            or "429" in error_str
            or "rate limit" in error_str
        )

    @staticmethod
    def _format_guardrail_error(prefix: str, error_response) -> str:
        """Build the user-facing text for a guardrail error and its suggestions."""
        message = f"{prefix}{error_response.error}"
        if error_response.suggestions:
            message += "\n\nSuggestions:\n" + "\n".join(
                f"• {s}" for s in error_response.suggestions
            )
        return message

    @staticmethod
    def _parse_tool_result(result):
        """Return the modification dict carried by a tool result, or None.

        Accepts a dict as-is, or a string holding either a Python literal or
        a JSON document. Anything that does not decode to a dict is rejected
        so later code can safely call ``.get`` on captured modifications.
        """
        import ast
        import json

        if isinstance(result, dict):
            return result
        if not isinstance(result, str):
            return None
        # literal_eval first (str(dict) form), then JSON (true/false/null).
        for parser in (ast.literal_eval, json.loads):
            try:
                parsed = parser(result)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    @staticmethod
    def _print_public_attrs(item, skip=(), limit=None):
        """Print the non-callable public attributes of item (debug output only)."""
        for attr_name in dir(item):
            if attr_name.startswith('_'):
                continue
            try:
                attr_value = getattr(item, attr_name)
                if callable(attr_value) or attr_name in skip:
                    continue
                shown = attr_value if limit is None else str(attr_value)[:limit]
                print(f" {attr_name}: {shown}")
            except Exception:
                # Best-effort debug dump: a failing property or __repr__ must
                # never break request processing.
                continue

    def _build_capture_hooks(self, captured_modifications):
        """Create run hooks that log agent activity and capture modifications.

        Parsed results of modification tools are appended to
        ``captured_modifications`` (mutated in place).
        """
        from agents.run import RunHooks

        modification_tools = self._MODIFICATION_TOOLS
        parse_result = self._parse_tool_result

        class ModificationCapturingHooks(RunHooks):
            """Hooks to capture modification tool calls."""

            async def on_handoff(self, context, from_agent, to_agent):
                """Log when handoffs occur."""
                print(f"🔀 HANDOFF: {from_agent.name} → {to_agent.name}")
                logger.info(f"Handoff from {from_agent.name} to {to_agent.name}")

            async def on_agent_start(self, context, agent):
                """Log when agents start."""
                print(f"🎭 AGENT START: {agent.name}")
                logger.info(f"Agent {agent.name} started")

            async def on_tool_start(self, context, agent, tool):
                """Log when tools are called."""
                print(f"🔧 TOOL CALLED: {tool.name}")
                logger.info(f"Tool called: {tool.name}")

            async def on_tool_end(self, context, agent, tool, result: str):
                """Capture tool results that are modifications."""
                print(f"📦 TOOL RESULT: {tool.name} returned result")
                logger.info(f"Tool {tool.name} completed")
                if tool.name not in modification_tools:
                    logger.debug(f"Tool {tool.name} result not captured (not a modification tool)")
                    return
                modification = parse_result(result)
                if modification is None:
                    logger.error(f"Failed to parse result from {tool.name}: {result}")
                    return
                print(f"✅ MODIFICATION CAPTURED from {tool.name}!")
                print(f" Modification: {modification}")
                logger.info(f"✓ Captured modification from {tool.name}")
                captured_modifications.append(modification)

        return ModificationCapturingHooks()

    async def _run_with_retries(self, user_message, context, hooks, captured_modifications):
        """Run the triage agent, retrying rate-limit errors with backoff.

        Non-rate-limit errors, and a rate-limit error on the final attempt,
        are re-raised to the caller.
        """
        for attempt in range(self._MAX_RETRIES):
            # Drop anything captured by a previous, failed attempt so a retry
            # cannot apply the same modification twice.
            captured_modifications.clear()
            try:
                return await Runner.run(
                    starting_agent=self.triage_agent,
                    input=user_message,
                    context=context,
                    hooks=hooks,  # Pass hooks to capture tool results
                )
            except Exception as e:
                is_last_attempt = attempt >= self._MAX_RETRIES - 1
                if is_last_attempt or not self._is_rate_limit_error(e):
                    raise
                # Exponential backoff, capped at _MAX_DELAY seconds
                delay = min(self._BASE_DELAY * (2 ** attempt), self._MAX_DELAY)
                print(f"\n⚠️ Rate limit hit (attempt {attempt + 1}/{self._MAX_RETRIES})")
                print(f" Waiting {delay} seconds before retry...")
                logger.warning(
                    f"Rate limit error, retrying in {delay}s "
                    f"(attempt {attempt + 1}/{self._MAX_RETRIES})"
                )
                await asyncio.sleep(delay)
        raise RuntimeError("Failed to get result from OpenAI API after retries")

    def _print_result_debug(self, result):
        """Print a diagnostic dump of the run result to stdout."""
        rule = '=' * 70
        print(f"\n{rule}")
        print("📥 OPENAI RESPONSE RECEIVED")
        print(rule)
        print(f"Result Type: {type(result)}")
        print(f"Has Final Output: {result.final_output is not None}")
        if result.final_output:
            print(f"Final Output Length: {len(str(result.final_output))} characters")
            print(f"Final Output Preview: {str(result.final_output)[:500]}")
        else:
            print("Final Output: None")

        # Check new_items for tool calls
        new_items = getattr(result, 'new_items', None)
        if new_items:
            print("\n🔍 NEW ITEMS DEBUG (Conversation History):")
            print(f"Number of new items: {len(new_items)}")
            for idx, item in enumerate(new_items, start=1):
                print(f"\n Item {idx}: {type(item).__name__}")
                if not hasattr(item, 'type'):
                    continue
                print(f" Type: {item.type}")
                if item.type == 'tool_call_item':
                    print(" 🔧 TOOL CALL DETECTED!")
                    self._print_public_attrs(item)
                if item.type == 'tool_call_output_item':
                    print(" 📦 TOOL OUTPUT DETECTED!")
                    self._print_public_attrs(
                        item,
                        skip=('model_config', 'model_fields', 'model_fields_set'),
                        limit=200,
                    )

        # Try to access more result details if available
        if hasattr(result, '__dict__'):
            print("\nResult Attributes:")
            for key, value in result.__dict__.items():
                if key not in ('final_output', 'raw_responses'):
                    print(f" - {key}: {type(value).__name__}")
        print(f"{rule}\n")

    def _apply_modifications(self, map_state, captured_modifications):
        """Record captured modifications on map_state and re-render the style.

        The style saved on disk is used as the base when one exists;
        otherwise a fresh style is loaded for the map's theme and language.
        The result is saved back to disk and stored on ``map_state``.
        """
        print(f"✅ Applying {len(captured_modifications)} modification(s) to map...")
        logger.info(f"Applying {len(captured_modifications)} modifications")
        for number, mod_dict in enumerate(captured_modifications, start=1):
            print(f" [{number}] Type: {mod_dict.get('type')}")
            map_state.add_modification(mod_dict)

        print("🎨 Re-rendering map with modifications...")
        base_style = load_style_from_file() or load_style(
            map_state.base_theme, map_state.base_language
        )
        modified_style = ModificationApplicator.apply_all(
            base_style,
            map_state.modifications,
        )
        save_style_to_file(modified_style)
        map_state.current_style = modified_style
        print(f"✅ Map updated! Total modifications: {len(map_state.modifications)}")
        logger.info(f"Map state updated with {len(map_state.modifications)} total modifications")

    @staticmethod
    def _rate_limit_user_message(error) -> str:
        """Build the user-facing explanation for a rate-limit failure."""
        error_str = str(error).lower()
        rate_limit_info = ""
        if "tpm" in error_str or "tokens per min" in error_str:
            rate_limit_info = "\n\n**Note:** You're hitting the TPM (tokens per minute) limit. "
            rate_limit_info += "The code has been optimized to reduce token usage, but you may still need to:\n"
            rate_limit_info += "- Wait longer between requests\n"
            rate_limit_info += "- Add a payment method to increase limits significantly\n"
        return (
            "⚠️ **OpenAI API Rate Limit Exceeded**\n\n"
            "Your OpenAI account has hit its rate limit. This means the AI agents "
            "cannot process requests right now.\n\n"
            "**Solutions:**\n"
            "1. Wait 20-60 seconds and try again (the system will auto-retry)\n"
            "2. Add a payment method to your OpenAI account at "
            "https://platform.openai.com/account/billing to increase limits\n"
            "3. Check your current limits at https://platform.openai.com/account/rate-limits\n"
            f"{rate_limit_info}\n"
            f"_Error details: {str(error)[:200]}_"
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def process_request_async(
        self,
        user_message: str,
        map_state: "MapState",
    ) -> "tuple[str, MapState]":
        """Process user request asynchronously through SDK agents with guardrails.

        Errors raised while running the agents or applying modifications are
        caught and reported in the returned message rather than propagated.

        Args:
            user_message: User's natural language request
            map_state: Current map state with modifications (mutated in place
                when modifications are captured)

        Returns:
            tuple: (response_message, updated_map_state)
        """
        global _current_context

        # Pre-request guardrails check
        error_response = self.guardrails.pre_request_check(user_message)
        if error_response:
            logger.warning(f"Pre-request check failed: {error_response.error}")
            return self._format_guardrail_error("⚠️ ", error_response), map_state

        # Lightweight context; the full map_state is only referenced, and is
        # published module-wide so tools can fetch the style on demand.
        context = MapContext(
            base_theme=map_state.base_theme,
            base_language=map_state.base_language,
            modification_count=len(map_state.modifications),
            modifications=[],
            _map_state=map_state,
        )
        _current_context = context

        captured_modifications = []
        hooks = self._build_capture_hooks(captured_modifications)
        rule = '=' * 70

        try:
            print(f"\n{rule}")
            print("🚀 OPENAI REQUEST STARTING")
            print(rule)
            print(f"User Message: {user_message}")
            print(f"Starting Agent: {self.triage_agent.name}")
            print(f"{rule}\n")
            logger.info(f"Processing request: {user_message[:100]}...")

            result = await self._run_with_retries(
                user_message, context, hooks, captured_modifications
            )
            if result is None:
                raise RuntimeError("Failed to get result from OpenAI API after retries")

            self._print_result_debug(result)

            # Extract response
            response_message = str(result.final_output) if result.final_output else "Request processed."

            print(f"\n{rule}")
            print("MODIFICATION SUMMARY")
            print(rule)
            print(f"Modifications captured: {len(captured_modifications)}")
            if captured_modifications:
                self._apply_modifications(map_state, captured_modifications)
            else:
                print("⚠️ NO MODIFICATIONS CAPTURED!")
                print(" This means the agent explained but didn't call modification tools.")
                print(" The map will NOT change.")
            print(f"{rule}\n")

            # Post-request guardrails check
            # NOTE(review): this runs after the modifications were already
            # applied and saved, so a failed check does not undo them -
            # confirm whether validation should happen before applying.
            post_error = self.guardrails.post_request_check(
                response_message,
                captured_modifications,
            )
            if post_error:
                logger.error(f"Post-request check failed: {post_error.error}")
                return (
                    self._format_guardrail_error(
                        "⚠️ Modification validation failed: ", post_error
                    ),
                    map_state,
                )

            logger.info("Request processed successfully")
            return response_message, map_state
        except Exception as e:
            print(f"\n{rule}")
            print("❌ OPENAI ERROR")
            print(rule)
            print(f"Error Type: {type(e).__name__}")
            print(f"Error Message: {str(e)}")
            print(f"{rule}\n")
            error_msg = f"Error processing request: {str(e)}"
            logger.error(f"SDK Orchestrator error: {error_msg}", exc_info=True)

            # Rate limits get a dedicated, more helpful message
            if self._is_rate_limit_error(e):
                print("⚠️ RATE LIMIT DETECTED")
                print("This is not a code issue - your OpenAI account has hit its rate limit.")
                print("The AI never processed your request.\n")
                return self._rate_limit_user_message(e), map_state
            return f"⚠️ {error_msg}", map_state

    def process_request(
        self,
        user_message: str,
        map_state: "MapState",
    ) -> "tuple[str, MapState]":
        """Synchronous wrapper for async process_request.

        Reuses the thread's current event loop, creating and installing a new
        one when none exists or the current one is closed. Must not be called
        from a thread whose event loop is already running, because
        ``run_until_complete`` raises RuntimeError in that case.

        Args:
            user_message: User's natural language request
            map_state: Current map state

        Returns:
            tuple: (response_message, updated_map_state)
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        # A closed loop cannot run anything; replace it instead of failing.
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(
            self.process_request_async(user_message, map_state)
        )

    def get_guardrails_stats(self) -> Dict[str, int]:
        """Get guardrails usage statistics.

        Returns:
            Dictionary with request and error counts
        """
        return self.guardrails.get_statistics()