| | """Define an Agent Supervisor graph with specialized worker agents. |
| | |
| | The supervisor routes tasks to specialized agents based on the query type. |
| | """ |
| |
|
| | from typing import Dict, List, Literal, Optional, Union, Type, cast |
| |
|
| | from langchain_core.messages import HumanMessage, SystemMessage, AIMessage |
| | from langchain_core.prompts import ChatPromptTemplate |
| | from langgraph.graph import StateGraph, START, END |
| | |
| | from langgraph.prebuilt import create_react_agent |
| | from langgraph.types import Command |
| |
|
| | from react_agent.configuration import Configuration |
| | from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router, Plan, PlanStep, CriticVerdict |
| | from react_agent.tools import TOOLS, tavily_tool, python_repl_tool |
| | from react_agent.utils import load_chat_model, format_system_prompt, get_message_text |
| | from react_agent import prompts |
| | from react_agent.supervisor_node import supervisor_node |
| |
|
| |
|
| | |
# Nodes the supervisor may route to; "__end__" is the value of LangGraph's END sentinel.
SupervisorDestinations = Literal[
    "planner",
    "critic",
    "researcher",
    "coder",
    "final_answer",
    "__end__",
]
# Destination of nodes that always hand control back to the supervisor
# (used by planner_node and the worker nodes).
WorkerDestination = Literal["supervisor"]
| |
|
| |
|
| | |
def is_user_message(message) -> bool:
    """Return True if ``message`` was written by the human user.

    Accepts a dict-style message (``{"role": ..., "content": ...}``) or a
    LangChain ``HumanMessage``. The planner, critic and worker nodes of this
    graph also emit ``HumanMessage`` objects, but always with a ``name`` set
    ("planner", "critic", or the worker type). Those are treated as agent
    messages, so callers looking for "the last user message" get the user's
    question rather than, e.g., the latest worker output or critic verdict.

    NOTE: a user message that itself carries a ``name`` is therefore not
    recognised as a user message.

    Args:
        message: A dict-style message or a LangChain message object.

    Returns:
        True for an unnamed user/human message, False otherwise.
    """
    if isinstance(message, dict):
        # LangChain treats both "user" and "human" roles as human messages.
        return message.get("role") in ("user", "human") and not message.get("name")
    if isinstance(message, HumanMessage):
        return not getattr(message, "name", None)
    return False
| |
|
| |
|
| | |
def get_message_content(message) -> str:
    """Return a message's text content, whatever the message format.

    Dict messages are read via their ``"content"`` key and message objects via
    their ``content`` attribute. Multi-part content (a list of strings and/or
    ``{"type": "text", "text": ...}`` dicts, as some chat models produce) is
    flattened to its text parts joined by newlines. Non-text parts are skipped.

    Args:
        message: A dict-style message, a message object, or anything else.

    Returns:
        The text content, or "" when there is none.
    """
    if isinstance(message, dict):
        content = message.get("content", "")
    else:
        content = getattr(message, "content", "")

    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type", "text") == "text":
                texts.append(str(part.get("text", "")))
        return "\n".join(text for text in texts if text)
    return str(content)
| |
|
| |
|
| | |
| |
|
def planner_node(state: State) -> Command[WorkerDestination]:
    """Ask the planning model for a step-by-step plan, then return to the supervisor.

    Args:
        state: Graph state. ``messages`` is scanned for the most recent user
            message (the question to plan for). ``steps_taken`` is incremented.

    Returns:
        Command routing to ``"supervisor"``. The update contains the new ``plan``,
        ``current_step_index`` reset to 0, a status message named "planner",
        and the incremented step counter.
    """
    cfg = Configuration.from_context()
    model = load_chat_model(cfg.planner_model)

    step_count = state.get("steps_taken", 0) + 1

    # Latest user-authored message; fall back to a generic request if none exists.
    question = next(
        (get_message_content(msg) for msg in reversed(state["messages"]) if is_user_message(msg)),
        "Help me",
    )

    template = ChatPromptTemplate.from_messages(
        [("system", prompts.PLANNER_PROMPT), ("user", "{question}")]
    )
    prompt_messages = template.format_messages(
        question=question,
        system_time=format_system_prompt("{system_time}"),
        workers=", ".join(WORKERS),
        worker_options=", ".join(f'"{name}"' for name in WORKERS),
        example_worker_1=WORKERS[0] if WORKERS else "researcher",
        example_worker_2=WORKERS[1] if len(WORKERS) > 1 else "coder",
    )

    plan = model.with_structured_output(Plan).invoke(prompt_messages)

    status = HumanMessage(
        content=f"Created plan with {len(plan['steps'])} steps",
        name="planner",
    )
    return Command(
        goto="supervisor",
        update={
            "plan": plan,
            "current_step_index": 0,
            "messages": [status],
            "steps_taken": step_count,
        },
    )
| |
|
| |
|
| | |
| |
|
def _extract_final_answer(text: str) -> str:
    """Return the text following "FINAL ANSWER:" on its line.

    The marker is matched case-insensitively, and the captured value is stripped.
    If the marker is absent, ``text`` is returned unchanged.
    """
    import re

    match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", text, re.IGNORECASE)
    return match.group(1).strip() if match else text


def _generate_final_answer(state: State, configuration: Configuration) -> str:
    """Ask the final-answer model to answer the user's question from worker context.

    Args:
        state: Graph state. ``messages`` supplies the question and ``context``
            supplies the per-worker findings.
        configuration: Configuration providing ``final_answer_model``.

    Returns:
        The model's reply as plain text. List-style (multi-part) content is
        flattened by concatenating its text parts.
    """
    final_llm = load_chat_model(configuration.final_answer_model)

    user_messages = [m for m in state["messages"] if is_user_message(m)]
    original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"

    # Researcher and coder findings come first with descriptive labels. Any other
    # worker's output follows, labelled with the worker's name.
    context = state.get("context") or {}
    context_list = []
    if "researcher" in context:
        context_list.append(f"Research information: {context['researcher']}")
    if "coder" in context:
        context_list.append(f"Calculation results: {context['coder']}")
    context_list.extend(
        f"{worker.capitalize()}: {content}"
        for worker, content in context.items()
        if worker not in ("researcher", "coder")
    )

    final_prompt = ChatPromptTemplate.from_messages([
        ("system", prompts.FINAL_ANSWER_PROMPT),
        ("user", prompts.FINAL_ANSWER_USER_PROMPT),
    ])
    formatted_messages = final_prompt.format_messages(
        question=original_question,
        context="\n\n".join(context_list),
    )

    content = final_llm.invoke(formatted_messages).content
    if isinstance(content, str):
        return content
    # Some chat models return a list of content parts instead of a plain string.
    return "".join(
        part if isinstance(part, str) else str(part.get("text", ""))
        for part in content
        if isinstance(part, (str, dict))
    )


def final_answer_node(state: State) -> Command[Literal["__end__"]]:
    """Produce the final GAIA-style answer and end the graph.

    If ``draft_answer`` already starts with "FINAL ANSWER:", it is used as is and
    no model is called. Otherwise the final-answer model writes an answer from the
    accumulated worker ``context``. The value after "FINAL ANSWER:" becomes the
    answer. If that marker is missing, the whole reply is used. If the result is
    empty, "unknown" or "insufficient information" and
    ``allow_agent_to_extract_answers`` is enabled,
    ``extract_best_answer_from_context`` is tried as a fallback.

    Args:
        state: The current state with messages and context.

    Returns:
        Command to END whose update carries the answer as an AIMessage and in
        ``gaia_answer`` / ``submitted_answer``, sets ``next`` to "FINISH" and
        ``status`` to "final_answer_generated", and increments ``steps_taken``.
    """
    configuration = Configuration.from_context()
    steps_taken = state.get("steps_taken", 0) + 1

    draft_answer = state.get("draft_answer")
    if draft_answer and draft_answer.startswith("FINAL ANSWER:"):
        raw_answer = draft_answer
    else:
        raw_answer = _generate_final_answer(state, configuration)

    gaia_answer = _extract_final_answer(raw_answer)

    if configuration.allow_agent_to_extract_answers and (
        not gaia_answer or gaia_answer.lower() in ("unknown", "insufficient information")
    ):
        from react_agent.supervisor_node import extract_best_answer_from_context

        extracted_answer = extract_best_answer_from_context(state.get("context", {}))
        if extracted_answer != "unknown":
            gaia_answer = extracted_answer

    return Command(
        goto=END,
        update={
            "messages": [
                AIMessage(
                    content=f"FINAL ANSWER: {gaia_answer}",
                    name="supervisor",
                )
            ],
            "next": "FINISH",
            "gaia_answer": gaia_answer,
            "submitted_answer": gaia_answer,
            "status": "final_answer_generated",
            "steps_taken": steps_taken,
        },
    )
| |
|
| |
|
| | |
| |
|
def critic_node(state: State) -> Command[Union[WorkerDestination, SupervisorDestinations]]:
    """Have the critic model judge whether the draft answer satisfies the request.

    Args:
        state: Graph state. ``messages`` supplies the user's question,
            ``draft_answer`` the answer under review, and ``steps_taken`` is
            incremented.

    Returns:
        Command to ``"final_answer"`` when the verdict equals the "correct"
        verdict, otherwise to ``"supervisor"``. The update stores the verdict in
        ``critic_verdict`` and adds a summary message named "critic".
    """
    configuration = Configuration.from_context()
    critic_llm = load_chat_model(configuration.critic_model)

    steps_taken = state.get("steps_taken", 0) + 1

    user_messages = [m for m in state["messages"] if is_user_message(m)]
    original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"

    # `state.get(key, default)` still returns None when the key is present but unset.
    draft_answer = state.get("draft_answer") or "No answer provided."

    # Use the same verdict labels for the prompt and for the comparison below,
    # including the fallbacks when VERDICTS is shorter than expected.
    correct_verdict = VERDICTS[0] if VERDICTS else "CORRECT"
    retry_verdict = VERDICTS[1] if len(VERDICTS) > 1 else "RETRY"

    critic_prompt_template = ChatPromptTemplate.from_messages([
        ("system", prompts.CRITIC_PROMPT),
        ("user", prompts.CRITIC_USER_PROMPT),
    ])
    formatted_messages = critic_prompt_template.format_messages(
        question=original_question,
        answer=draft_answer,
        system_time=format_system_prompt("{system_time}"),
        correct_verdict=correct_verdict,
        retry_verdict=retry_verdict,
    )

    verdict = critic_llm.with_structured_output(CriticVerdict).invoke(formatted_messages)

    if verdict["verdict"] == correct_verdict:
        verdict_message = "Answer is complete, accurate, and properly formatted for GAIA."
        goto = "final_answer"
    else:
        verdict_message = f"Answer needs improvement. Reason: {verdict.get('reason', 'Unknown')}"
        goto = "supervisor"

    return Command(
        goto=goto,
        update={
            "critic_verdict": verdict,
            "messages": [
                HumanMessage(
                    content=verdict_message,
                    name="critic",
                )
            ],
            "steps_taken": steps_taken,
        },
    )
| |
|
| |
|
| | |
| |
|
def create_worker_node(worker_type: str):
    """Build a graph node that delegates the supervisor's latest task to a ReAct agent.

    The researcher uses ``researcher_model`` with the Tavily tool, and the coder
    uses ``coder_model`` with the Python REPL tool. Any other worker uses the
    default ``model`` with all ``TOOLS`` and ``prompts.<WORKER>_PROMPT``, falling
    back to ``prompts.SYSTEM_PROMPT``. The model and agent are created once, when
    this factory is called, from ``Configuration.from_context()`` at that moment.

    Args:
        worker_type: Name of the worker; must be one of ``WORKERS``.

    Returns:
        A node function ``worker_node(state) -> Command`` that always routes back
        to ``"supervisor"``.

    Raises:
        ValueError: If ``worker_type`` is not in ``WORKERS``.
    """
    if worker_type not in WORKERS:
        raise ValueError(f"Unknown worker type: {worker_type}")

    configuration = Configuration.from_context()

    if worker_type == "researcher":
        llm = load_chat_model(configuration.researcher_model)
        worker_prompt = prompts.RESEARCHER_PROMPT
        worker_tools = [tavily_tool]
    elif worker_type == "coder":
        llm = load_chat_model(configuration.coder_model)
        worker_prompt = prompts.CODER_PROMPT
        worker_tools = [python_repl_tool]
    else:
        llm = load_chat_model(configuration.model)
        worker_prompt = getattr(prompts, f"{worker_type.upper()}_PROMPT", prompts.SYSTEM_PROMPT)
        worker_tools = TOOLS

    worker_agent = create_react_agent(
        llm,
        tools=worker_tools,
        prompt=format_system_prompt(worker_prompt),
    )

    def worker_node(state: State) -> Command[WorkerDestination]:
        """Run the worker on the supervisor's most recent task and report back.

        Args:
            state: The current conversation state.

        Returns:
            Command to ``"supervisor"``. On success, the update carries the
            worker's result as a ``HumanMessage`` named after the worker, records
            it in ``context`` and ``worker_results``, and advances
            ``current_step_index``. If no supervisor message exists, it carries
            an error message instead.
        """
        steps_taken = state.get("steps_taken", 0) + 1
        messages = state.get("messages") or []

        # The newest message authored by the supervisor is this worker's task.
        task_message = next(
            (msg for msg in reversed(messages) if getattr(msg, "name", None) == "supervisor"),
            None,
        )
        if task_message is None:
            return Command(
                goto="supervisor",
                update={
                    "messages": [
                        HumanMessage(
                            content=f"Error: No task message found for {worker_type}",
                            name=worker_type,
                        )
                    ],
                    "steps_taken": steps_taken,
                },
            )

        # Give the agent only the first conversation message (presumably the
        # user's question) plus the task, not the whole multi-agent transcript.
        result = worker_agent.invoke({"messages": [messages[0], task_message]})
        result_content = extract_worker_result(worker_type, result, state)

        # Build new containers instead of mutating those held by `state`. A shallow
        # dict copy alone would share, and then append to, the stored per-worker lists.
        context_update = {**(state.get("context") or {}), worker_type: result_content}
        worker_results = dict(state.get("worker_results") or {})
        worker_results[worker_type] = [*worker_results.get(worker_type, []), result_content]

        return Command(
            update={
                "messages": [
                    HumanMessage(content=result_content, name=worker_type)
                ],
                "current_step_index": state.get("current_step_index", 0) + 1,
                "context": context_update,
                "worker_results": worker_results,
                "steps_taken": steps_taken,
            },
            goto="supervisor",
        )

    return worker_node
| |
|
| |
|
def extract_worker_result(worker_type: str, result: dict, state: "State") -> str:
    """Reduce a worker agent's raw output to a concise result string.

    The last message in ``result["messages"]`` is taken as the worker's answer.

    For the coder:
        - If the reply contains a code fence and a non-empty ``Stdout:`` value,
          a numeric value (optionally negative) is returned bare.
        - Any other stdout is prefixed with "Code executed with result: ".
        - Otherwise the first "Result:/Output:/Answer:" section is returned when
          present.

    For the researcher, replies longer than 800 characters are reduced to an
    explicit summary/conclusion section when one exists.

    Args:
        worker_type: The worker that produced ``result`` (e.g. "researcher", "coder").
        result: The raw value returned by the worker agent's ``invoke``.
        state: The current state (kept for interface compatibility; unused).

    Returns:
        The extracted text, or a "No output/No content from <worker>" placeholder.
    """
    # Imported once at the top. A branch-local `import re` makes `re` a local name
    # everywhere in the function and breaks paths that skip that branch.
    import re

    if not result or "messages" not in result or not result["messages"]:
        return f"No output from {worker_type}"

    content = getattr(result["messages"][-1], "content", None)
    if isinstance(content, list):
        # Multi-part content, e.g. [{"type": "text", "text": ...}, ...]: keep the text parts.
        content = "\n".join(
            part if isinstance(part, str) else str(part.get("text", ""))
            for part in content
            if isinstance(part, str) or (isinstance(part, dict) and part.get("type", "text") == "text")
        ).strip()
    if not content:
        return f"No content from {worker_type}"
    result_content = str(content)

    if worker_type == "coder":
        if "```" in result_content:
            stdout_match = re.search(r"Stdout:\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL)
            execution_result = stdout_match.group(1).strip() if stdout_match else ""
            if execution_result:
                if re.fullmatch(r"-?\d+(?:\.\d+)?", execution_result):
                    return execution_result
                return f"Code executed with result: {execution_result}"

        # Fallback: an explicit "Result:/Output:/Answer:" section in the reply.
        result_match = re.search(r"(?:Result|Output|Answer):\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL)
        if result_match:
            return result_match.group(1).strip()

    elif worker_type == "researcher" and len(result_content) > 800:
        summary_match = re.search(
            r"(?:Summary|Conclusion|To summarize|In summary):(.*?)(?:\n\n|$)",
            result_content,
            re.IGNORECASE | re.DOTALL,
        )
        if summary_match:
            return summary_match.group(1).strip()

    return result_content
| |
|
| |
|
| | |
| |
|
def create_agent_supervisor_graph() -> StateGraph:
    """Build (but do not compile) the agent supervisor graph.

    Nodes: planner, supervisor, critic, final_answer and one node per entry in
    ``WORKERS``. The planner, critic, final_answer and worker nodes defined in
    this module route themselves by returning a ``Command`` with ``goto``. The
    imported supervisor_node is presumably doing the same, since no other edge
    reaches the workers. The only static edge is therefore the entry edge
    START -> supervisor.

    Returns:
        The uncompiled ``StateGraph`` builder; compile it (see
        ``get_compiled_graph``) before execution.
    """
    builder = StateGraph(State)

    builder.add_node("planner", planner_node)
    builder.add_node("supervisor", supervisor_node)
    builder.add_node("critic", critic_node)
    builder.add_node("final_answer", final_answer_node)
    for worker_type in WORKERS:
        builder.add_node(worker_type, create_worker_node(worker_type))

    # Do NOT add static edges out of Command-returning nodes: a static edge fires in
    # addition to Command.goto, so e.g. critic -> supervisor plus critic -> final_answer
    # would run both targets after every critic verdict.
    builder.add_edge(START, "supervisor")

    return builder
| |
|
| |
|
| | |
| |
|
def get_compiled_graph(checkpointer=None):
    """Compile the agent supervisor graph, optionally with a checkpointer.

    Args:
        checkpointer: Optional checkpointer for persistence. Falsy values are
            treated as "no checkpointer".

    Returns:
        The compiled graph, named "Structured Reasoning Loop", with
        ``recursion_limit`` and ``max_iterations`` from the current
        ``Configuration`` bound via ``with_config``.
    """
    configuration = Configuration.from_context()
    builder = create_agent_supervisor_graph()

    graph = builder.compile(
        checkpointer=checkpointer or None,
        name="Structured Reasoning Loop",
    )

    # NOTE(review): "max_iterations" is not a standard RunnableConfig key. Confirm
    # whatever reads it finds it here (it may belong under "configurable").
    return graph.with_config({
        "recursion_limit": configuration.recursion_limit,
        "max_iterations": configuration.max_iterations,
    })
| |
|
| |
|
| | |
# Module-level compiled graph, built once at import time without a checkpointer.
graph = get_compiled_graph(checkpointer=None)
| |
|