| | """Supervisor node implementation for the agent supervisor system.""" |
| |
|
| | from typing import Dict, List, Literal, Optional, Union, Type, cast |
| |
|
| | from langchain_core.messages import HumanMessage, SystemMessage, AIMessage |
| | from langchain_core.prompts import ChatPromptTemplate |
| | from langgraph.graph import StateGraph, START, END |
| | from langgraph.types import Command |
| |
|
| | from react_agent.configuration import Configuration |
| | from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router |
| | from react_agent.utils import load_chat_model, format_system_prompt, get_message_text |
| | from react_agent import prompts |
| |
|
| |
|
| | |
# Node names the supervisor may hand control to; used as the type parameter
# of the Command returned by supervisor_node.
SupervisorDestinations = Literal[
    "planner",
    "critic",
    "researcher",
    "coder",
    "final_answer",
    "__end__",
]
| |
|
| |
|
def supervisor_node(state: State) -> Command[SupervisorDestinations]:
    """Decide which specialized agent should act next.

    Rules are evaluated in order; the first match wins:

    1. The incremented step counter is within 5 of
       ``configuration.recursion_limit``: go to ``final_answer`` with an
       answer extracted from the workers' context.
    2. ``retry_count`` exceeds the retry budget: same as rule 1.
    3. No plan yet, or a plan without steps: go to ``planner``.
    4. A critic verdict is present: ``VERDICTS[0]`` goes to ``final_answer``;
       ``VERDICTS[1]`` is handled by ``_handle_critic_rejection`` (finish,
       reformat, or re-plan). Any other verdict value falls through.
    5. Every plan step has been dispatched: assemble a draft from the
       workers' context and go to ``critic``.
    6. Otherwise dispatch the current plan step to its worker.

    Every returned update carries the incremented ``steps_taken`` so the
    guard in rule 1 counts every supervisor pass, including critic-driven
    ones.

    Args:
        state: The current graph state.

    Returns:
        Command naming the next node and the state update to apply.
    """
    configuration = Configuration.from_context()
    max_retries = 2

    steps_taken = (state.get("steps_taken") or 0) + 1
    state_updates = {"steps_taken": steps_taken}

    # Stop a few steps short of the recursion limit so final_answer can still run.
    if steps_taken >= configuration.recursion_limit - 5:
        return _finish_with_best_answer(
            state,
            f"Maximum steps ({steps_taken}) reached. Extracting best answer from available information.",
            steps_taken,
        )

    if (state.get("retry_count") or 0) > max_retries:
        return _finish_with_best_answer(
            state,
            f"Maximum retries ({max_retries}) reached. Extracting best answer from available information.",
            steps_taken,
        )

    plan = state.get("plan")
    if not plan:
        return Command(goto="planner", update={**state_updates})

    steps = plan.get("steps")
    if not steps:
        return Command(
            goto="planner",
            update={
                "messages": [
                    HumanMessage(
                        content="Previous plan had 0 steps. Please create a plan with at least 1 step to solve the user's question.",
                        name="supervisor",
                    )
                ],
                "plan": None,
                **state_updates,
            },
        )

    critic_verdict = state.get("critic_verdict")
    if critic_verdict:
        verdict = critic_verdict.get("verdict")
        if verdict == VERDICTS[0]:
            return Command(
                goto="final_answer",
                update={
                    "messages": [
                        HumanMessage(
                            content="Answer approved by critic. Generating final response.",
                            name="supervisor",
                        )
                    ],
                    **state_updates,
                },
            )
        if verdict == VERDICTS[1]:
            return _handle_critic_rejection(state, critic_verdict, max_retries, steps_taken)

    # The re-plan branch stores None here, so a missing/None index means "start".
    current_step_index = state.get("current_step_index") or 0

    if current_step_index >= len(steps):
        context = state.get("context") or {}
        sections = [
            f"**{worker.title()}**: {context[worker]}"
            for worker in WORKERS
            if worker in context
        ]
        return Command(
            goto="critic",
            update={
                "draft_answer": "\n\n".join(sections),
                "messages": [
                    HumanMessage(
                        content="All steps completed. Evaluating the answer.",
                        name="supervisor",
                    )
                ],
                **state_updates,
            },
        )

    current_step = steps[current_step_index]
    worker = current_step["worker"]
    enhanced_instruction = _build_worker_instruction(
        worker, current_step["instruction"], state.get("context")
    )

    return Command(
        goto=cast(SupervisorDestinations, worker),
        update={
            "messages": [
                HumanMessage(
                    content=f"Step {current_step_index + 1}: {enhanced_instruction}",
                    name="supervisor",
                )
            ],
            "next": worker,
            **state_updates,
        },
    )


# Lower-case substrings that, when found in the lower-cased critic reason,
# mark a rejection as a formatting problem rather than a content problem.
_FORMAT_ISSUE_KEYWORDS = (
    "format",
    "concise",
    "explanation",
    "not formatted",
    "instead of just",
    "contains explanations",
    "final answer",
)


def _finish_with_best_answer(
    state: State, message: str, steps_taken: int
) -> Command[SupervisorDestinations]:
    """Route to ``final_answer`` with a draft built from the current context.

    Sets ``retry_exhausted`` and records ``steps_taken``; ``message`` is posted
    as a supervisor HumanMessage.
    """
    answer = extract_best_answer_from_context(state.get("context") or {})
    return Command(
        goto="final_answer",
        update={
            "messages": [HumanMessage(content=message, name="supervisor")],
            "draft_answer": f"FINAL ANSWER: {answer}",
            "retry_exhausted": True,
            "steps_taken": steps_taken,
        },
    )


def _handle_critic_rejection(
    state: State, critic_verdict: Dict, max_retries: int, steps_taken: int
) -> Command[SupervisorDestinations]:
    """Route after a ``VERDICTS[1]`` (rejecting) critic verdict.

    Finishes with the best available answer once ``max_retries`` retries have
    been used. Goes straight to ``final_answer`` when the rejection reason looks
    like a formatting problem and enough information has been gathered.
    Otherwise clears the plan/draft/verdict and sends control back to
    ``planner`` with an incremented ``retry_count``.
    """
    current_retry_count = state.get("retry_count") or 0
    if current_retry_count >= max_retries:
        return _finish_with_best_answer(
            state,
            f"Maximum retries ({max_retries}) reached. Proceeding with best available answer.",
            steps_taken,
        )

    reason = critic_verdict.get("reason", "")
    # An empty reason or a lone quote character counts as "no reason given".
    if not reason or reason.strip() == "\"":
        reason = "Answer did not meet format requirements"

    is_format_issue = any(keyword in reason.lower() for keyword in _FORMAT_ISSUE_KEYWORDS)
    if is_format_issue and has_sufficient_information(state):
        return Command(
            goto="final_answer",
            update={
                "messages": [
                    HumanMessage(
                        content="We have sufficient information but formatting issues. Generating properly formatted answer.",
                        name="supervisor",
                    )
                ],
                "retry_count": current_retry_count + 1,
                "steps_taken": steps_taken,
            },
        )

    next_retry_count = current_retry_count + 1
    return Command(
        goto="planner",
        update={
            "plan": None,
            "current_step_index": None,
            "draft_answer": None,
            "critic_verdict": None,
            # Passed through unchanged, presumably so gathered worker output
            # survives the re-plan.
            "context": state.get("context", {}),
            "worker_results": state.get("worker_results", {}),
            "retry_count": next_retry_count,
            "messages": [
                HumanMessage(
                    content=f"Retrying with new plan (retry #{next_retry_count}). Reason: {reason}",
                    name="supervisor",
                )
            ],
            "steps_taken": steps_taken,
        },
    )


def _build_worker_instruction(
    worker: str, instruction: str, context: Optional[Dict]
) -> str:
    """Append relevant prior context and a worker-specific hint to an instruction.

    A coder step gets the researcher's output (first sentence of its first 200
    chars when longer than 200); a researcher step gets the coder's output only
    when it is shorter than 100 chars.
    """
    context_info = ""
    if context:
        relevant_context = {}
        if worker == "coder" and "researcher" in context:
            relevant_context["researcher"] = context["researcher"]
        if worker == "researcher" and "coder" in context:
            coder_content = context["coder"]
            if len(coder_content) < 100:
                relevant_context["coder"] = coder_content

        context_items = []
        for key, value in relevant_context.items():
            if len(value) > 200:
                summary = value[:200]
                if "." in summary:
                    summary = summary.split(".")[0] + "."
                context_items.append(f"Previous {key} found: {summary}...")
            else:
                context_items.append(f"Previous {key} found: {value}")

        if context_items:
            context_info = "\n\nRelevant context: " + "\n".join(context_items)

    enhanced_instruction = f"{instruction}{context_info}"
    if worker == "coder":
        enhanced_instruction += "\nProvide both your calculation method AND the final result value."
    elif worker == "researcher":
        enhanced_instruction += "\nFocus on gathering factual information related to the task."
    return enhanced_instruction
| |
|
def extract_best_answer_from_context(context):
    """Extract the best available answer from worker context.

    Strategies are tried in order and the first one that yields something wins:

    1. A ``FINAL ANSWER: ...`` line (case-insensitive) in the coder output,
       ignored when nothing follows the marker.
    2. Bullet-list items in the researcher output: a line starting with ``-``,
       ``•`` or ``*`` followed by spaces/tabs. Each item runs up to the first
       ``:`` or end of line, and items are joined with ``,``.
    3. ``**bold**`` phrases in the researcher output. The lower-case words a,
       an, the, is, are, was, were, be and been are removed. Phrases of fewer
       than 30 characters are kept and joined with ``,``.
    4. The first integer or decimal number anywhere in the context.

    Bullets are anchored to the start of a line so that the closing ``** `` of
    bold text is not mistaken for a ``*`` bullet.

    Args:
        context: Mapping of worker name to its text output. ``None`` and
            non-string values are tolerated and skipped.

    Returns:
        The best answer found, or ``"unknown"`` if no strategy matched.
    """
    import re

    context = context or {}

    coder_content = context.get("coder")
    if isinstance(coder_content, str):
        answer_match = re.search(
            r"FINAL ANSWER:\s*(.*?)(?:\n|$)", coder_content, re.IGNORECASE
        )
        if answer_match and answer_match.group(1).strip():
            return answer_match.group(1).strip()

    researcher_content = context.get("researcher")
    if isinstance(researcher_content, str):
        list_items = re.findall(
            r"^[ \t]*[-•*][ \t]+([^:\n]+)", researcher_content, re.MULTILINE
        )
        if list_items:
            return ",".join(item.strip() for item in list_items)

        # Lookarounds (instead of consuming the surrounding whitespace) let
        # consecutive stopwords such as "is a" all be removed.
        stopwords = re.compile(r"(?<!\S)(?:a|an|the|is|are|was|were|be|been)(?!\S)")
        processed_items = []
        for item in re.findall(r"\*\*([^*]+)\*\*", researcher_content):
            clean_item = " ".join(stopwords.sub(" ", item).split())
            if clean_item and len(clean_item) < 30:
                processed_items.append(clean_item)
        if processed_items:
            return ",".join(processed_items)

    combined_content = " ".join(
        content for content in context.values() if isinstance(content, str)
    )
    numbers = re.findall(r"\b(\d+(?:\.\d+)?)\b", combined_content)
    if numbers:
        return numbers[0]

    return "unknown"
| |
|
def has_sufficient_information(state):
    """Determine whether the context looks complete enough to answer.

    Any one of these heuristics is enough:

    * both ``researcher`` and ``coder`` entries are present;
    * the researcher output is longer than 150 characters;
    * any string output contains ``"- "``, ``"•"``, ``"*"`` or ``":"``.

    NOTE(review): the ``":"`` marker makes the last check very permissive;
    most non-trivial text will satisfy it.

    Args:
        state: The current conversation state. A missing or ``None``
            ``context`` is treated as empty; non-string values are ignored.

    Returns:
        True if the heuristics consider the information sufficient, else False.
    """
    context = state.get("context") or {}

    if "researcher" in context and "coder" in context:
        return True

    researcher_content = context.get("researcher")
    if isinstance(researcher_content, str) and len(researcher_content) > 150:
        return True

    markers = ("- ", "•", "*", ":")
    return any(
        isinstance(content, str) and any(marker in content for marker in markers)
        for content in context.values()
    )