Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| from langgraph.graph import START, StateGraph, MessagesState, END | |
| from langgraph.prebuilt import tools_condition, ToolNode | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_groq import ChatGroq | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_community.document_loaders import WikipediaLoader, ArxivLoader | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
| from langchain_core.tools import tool | |
| from langchain_groq import ChatGroq | |
| from supabase import create_client | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import SupabaseVectorStore | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.documents import Document | |
| import json | |
| import pdfplumber | |
| import pandas as pd | |
| from transformers import BlipProcessor, BlipForConditionalGeneration | |
| from PIL import Image | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import cmath | |
| # from code_interpreter import CodeInterpreter | |
| import uuid | |
| import tempfile | |
| import requests | |
| from urllib.parse import urlparse | |
| from typing import Optional | |
| import io | |
| import contextlib | |
| import base64 | |
| import subprocess | |
| import sqlite3 | |
| import traceback | |
# Load variables from a local .env file (if present) into the process
# environment; the API keys and Supabase settings read later via os.getenv
# may come from there.
load_dotenv()
# ------------------- TOOL DEFINITIONS -------------------
def multiply(a: int, b: int) -> int:
    """Return the product of two numbers.

    Args:
        a (int): the first factor
        b (int): the second factor
    """
    product = a * b
    return product
def add(a: int, b: int) -> int:
    """Return the sum of two numbers.

    Args:
        a (int): the first addend
        b (int): the second addend
    """
    total = a + b
    return total
def subtract(a: int, b: int) -> int:
    """Return the difference ``a - b``.

    Args:
        a (int): the number to subtract from
        b (int): the number to subtract
    """
    difference = a - b
    return difference
def divide(a: int, b: int) -> float:
    """Return the true-division quotient ``a / b``.

    Args:
        a (int): the dividend
        b (int): the divisor; must be non-zero

    Raises:
        ValueError: if ``b`` is zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
def modulus(a: int, b: int) -> int:
    """Return the remainder of ``a`` divided by ``b`` (Python ``%`` semantics).

    Args:
        a (int): the dividend
        b (int): the divisor
    """
    remainder = a % b
    return remainder
| def square_root(a: float) -> float | complex: | |
| """ | |
| Get the square root of a number. | |
| Args: | |
| a (float): the number to get the square root of | |
| """ | |
| if a >= 0: | |
| return a**0.5 | |
| return cmath.sqrt(a) | |
def power(a: float, b: float) -> float:
    """
    Raise a number to a power.

    Args:
        a (float): the base
        b (float): the exponent
    """
    return pow(a, b)
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query (max 2 results).

    Args:
        query (str): the search terms

    Returns:
        str: the full text of each matching page, separated by blank lines.
    """
    loader = WikipediaLoader(query=query, load_max_docs=2)
    pages = loader.load()
    return "\n\n".join(page.page_content for page in pages)
def web_search(query: str) -> str:
    """Search the web using Tavily (max 3 results).

    Args:
        query (str): the search terms

    Returns:
        str: the text of each result, separated by blank lines. If Tavily
        hands back a plain string instead of a result list (e.g. an error
        message), that string is returned unchanged.
    """
    results = TavilySearchResults(max_results=3).invoke(query)
    # The Tavily tool may return a plain string (such as an error repr) rather
    # than a list of result dicts. Iterating that string character by character
    # would silently produce "" and hide the problem from the agent.
    if isinstance(results, str):
        return results
    texts = [
        doc.get("content", "") or doc.get("text", "")
        for doc in results
        if isinstance(doc, dict)
    ]
    return "\n\n".join(texts)
def arvix_search(query: str) -> str:
    """Search Arxiv for academic papers (max 3 results, truncated to 1000 characters each).

    Args:
        query (str): the search terms
    """
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    excerpts = [paper.page_content[:1000] for paper in papers]
    return "\n\n".join(excerpts)
def read_excel_file(path: str) -> str:
    """Read an Excel file and return the first few rows of each sheet as text.

    Args:
        path (str): local path to the workbook.

    Returns:
        str: for every sheet, a "Sheet: <name>" line followed by up to five
        rows rendered with ``DataFrame.to_string(index=False)``; sections are
        separated by blank lines. On failure, a message starting with
        "Error reading Excel file:".
    """
    try:
        # Use ExcelFile as a context manager so the workbook's file handle is
        # released even if parsing one of the sheets fails.
        with pd.ExcelFile(path) as xls:
            sections = [
                f"Sheet: {sheet}\n" + xls.parse(sheet).head(5).to_string(index=False)
                for sheet in xls.sheet_names
            ]
        return "\n\n".join(sections).strip()
    except Exception as e:
        return f"Error reading Excel file: {str(e)}"
def extract_text_from_pdf(path: str) -> str:
    """Extract text from a PDF file given its local path.

    Only the first five pages are read, so very large documents stay small.

    Args:
        path (str): local path to the PDF file.

    Returns:
        str: the non-empty page texts joined by blank lines, the message
        "No text extracted from PDF." when no page yields text, or a message
        starting with "Error reading PDF:" on failure.
    """
    try:
        with pdfplumber.open(path) as pdf:
            # Limit to the first 5 pages to avoid oversized output.
            page_texts = (page.extract_text() for page in pdf.pages[:5])
            chunks = [chunk for chunk in page_texts if chunk]
        if not chunks:
            return "No text extracted from PDF."
        return "\n\n".join(chunks).strip()
    except Exception as e:
        return f"Error reading PDF: {str(e)}"
# Load the BLIP image-captioning processor and model once, at import time.
# The first load may be slow (presumably while the weights are fetched/cached).
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
def blip_image_caption(image_path: str) -> str:
    """Generate a description for an image using BLIP.

    Args:
        image_path (str): local path to the image file.

    Returns:
        str: the decoded caption, or a message starting with
        "Failed to process image with BLIP:" if anything goes wrong.
    """
    try:
        rgb_image = Image.open(image_path).convert("RGB")
        model_inputs = processor(rgb_image, return_tensors="pt")
        # Inference only: no gradients are needed.
        with torch.no_grad():
            generated_ids = model.generate(**model_inputs)
        return processor.decode(generated_ids[0], skip_special_tokens=True)
    except Exception as e:
        return f"Failed to process image with BLIP: {str(e)}"
| # @tool | |
| # def execute_code_multilang(code: str, language: str = "python") -> str: | |
| # """Execute code in multiple languages (Python, Bash, SQL, C, Java) and return results. | |
| # Args: | |
| # code (str): The source code to execute. | |
| # language (str): The language of the code. Supported: "python", "bash", "sql", "c", "java". | |
| # Returns: | |
| # A string summarizing the execution results (stdout, stderr, errors, plots, dataframes if any). | |
| # """ | |
| # supported_languages = ["python", "bash", "sql", "c", "java"] | |
| # language = language.lower() | |
| # interpreter_instance = CodeInterpreter() | |
| # if language not in supported_languages: | |
| # return f"❌ Unsupported language: {language}. Supported languages are: {', '.join(supported_languages)}" | |
| # result = interpreter_instance.execute_code(code, language=language) | |
| # response = [] | |
| # if result["status"] == "success": | |
| # response.append(f"✅ Code executed successfully in **{language.upper()}**") | |
| # if result.get("stdout"): | |
| # response.append( | |
| # "\n**Standard Output:**\n```\n" + result["stdout"].strip() + "\n```" | |
| # ) | |
| # if result.get("stderr"): | |
| # response.append( | |
| # "\n**Standard Error (if any):**\n```\n" | |
| # + result["stderr"].strip() | |
| # + "\n```" | |
| # ) | |
| # if result.get("result") is not None: | |
| # response.append( | |
| # "\n**Execution Result:**\n```\n" | |
| # + str(result["result"]).strip() | |
| # + "\n```" | |
| # ) | |
| # if result.get("dataframes"): | |
| # for df_info in result["dataframes"]: | |
| # response.append( | |
| # f"\n**DataFrame `{df_info['name']}` (Shape: {df_info['shape']})**" | |
| # ) | |
| # df_preview = pd.DataFrame(df_info["head"]) | |
| # response.append("First 5 rows:\n```\n" + str(df_preview) + "\n```") | |
| # if result.get("plots"): | |
| # response.append( | |
| # f"\n**Generated {len(result['plots'])} plot(s)** (Image data returned separately)" | |
| # ) | |
| # else: | |
| # response.append(f"❌ Code execution failed in **{language.upper()}**") | |
| # if result.get("stderr"): | |
| # response.append( | |
| # "\n**Error Log:**\n```\n" + result["stderr"].strip() + "\n```" | |
| # ) | |
| # return "\n".join(response) | |
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a file in the system temp directory and return its path.

    The content is written as UTF-8. Only the final component of ``filename``
    is used, so the file always lands directly inside the temp directory.

    Args:
        content (str): the content to save to the file
        filename (str, optional): the name of the file. If not provided (or empty), a random name file will be created.
    """
    temp_dir = tempfile.gettempdir()
    # Strip any directory part ("../x", "/etc/x") from a caller-supplied name.
    # The name comes from model tool calls and must not escape temp_dir.
    safe_name = os.path.basename(filename) if filename else ""
    if safe_name:
        filepath = os.path.join(temp_dir, safe_name)
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(content)
    else:
        # Write through the temp file's own handle and close it via `with`,
        # so no file descriptor is left dangling.
        with tempfile.NamedTemporaryFile(
            "w", delete=False, dir=temp_dir, encoding="utf-8"
        ) as tmp:
            tmp.write(content)
            filepath = tmp.name
    return f"File saved to {filepath}. You can read this file to process its contents."
# Upper bound (seconds) on connecting to / waiting for data from the server.
_DOWNLOAD_TIMEOUT_SECONDS = 30


def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to a temporary location.

    Args:
        url (str): the URL of the file to download.
        filename (str, optional): the name of the file. If not provided, the name is taken from the URL path, or a random name is created.
    """
    try:
        # Take the name from the URL path when none is given. Keep only the
        # final path component so the file always lands inside the temp dir.
        name = os.path.basename(filename or urlparse(url).path)
        if not name:
            name = f"downloaded_{uuid.uuid4().hex[:8]}"
        filepath = os.path.join(tempfile.gettempdir(), name)
        # The timeout keeps a stalled server from hanging the agent. The
        # context manager releases the streamed connection when done.
        with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT_SECONDS) as response:
            response.raise_for_status()
            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return f"File downloaded to {filepath}. You can read this file to process its contents."
    except Exception as e:
        return f"Error downloading file: {str(e)}"
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Load a CSV file with pandas and return a summary of it.

    The summary is the row/column count, the column names and
    ``DataFrame.describe()``. ``query`` is accepted but currently not used,
    so the same summary is returned for any question.

    Args:
        file_path (str): the path to the CSV file.
        query (str): Question about the data
    """
    try:
        frame = pd.read_csv(file_path)
        n_rows, n_cols = len(frame), len(frame.columns)
        pieces = [
            f"CSV file loaded with {n_rows} rows and {n_cols} columns.\n",
            f"Columns: {', '.join(frame.columns)}\n\n",
            "Summary statistics:\n",
            str(frame.describe()),
        ]
        return "".join(pieces)
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
# Seconds allowed for running user code, and for compiling C/Java sources.
_RUN_TIMEOUT_SECONDS = 30
_COMPILE_TIMEOUT_SECONDS = 60


def _run_python(code: str, result: dict) -> None:
    """Exec Python code in a fresh namespace, capturing output, figures and DataFrames."""
    plt.switch_backend("Agg")  # non-interactive backend: render figures without a display
    stdout_buffer = io.StringIO()
    stderr_buffer = io.StringIO()
    globals_dict = {"pd": pd, "plt": plt, "Image": Image}
    try:
        # SECURITY: exec runs arbitrary model-generated code in this process with
        # no sandbox. Only expose this tool in a disposable environment.
        with contextlib.redirect_stdout(stdout_buffer), contextlib.redirect_stderr(stderr_buffer):
            exec(code, globals_dict)
        # Encode figures from memory rather than leaving PNG files in the temp dir.
        for fig_num in plt.get_fignums():
            png = io.BytesIO()
            plt.figure(fig_num).savefig(png, format="png")
            result["plots"].append(base64.b64encode(png.getvalue()).decode())
        for var_name, var_val in globals_dict.items():
            if isinstance(var_val, pd.DataFrame):
                result["dataframes"].append((var_name, var_val.head().to_string()))
        result["status"] = "success"
    finally:
        # Keep whatever output was produced even if exec raised. Close all
        # figures so the next call does not report them again.
        result["stdout"] = stdout_buffer.getvalue()
        result["stderr"] = stderr_buffer.getvalue()
        plt.close("all")


def _run_bash(code: str, result: dict) -> None:
    """Run code as a shell script (shell=True by design: the input is a script)."""
    completed = subprocess.run(
        code, shell=True, capture_output=True, text=True, timeout=_RUN_TIMEOUT_SECONDS
    )
    result["stdout"] = completed.stdout
    result["stderr"] = completed.stderr
    result["status"] = "success" if completed.returncode == 0 else "error"


def _run_sql(code: str, result: dict) -> None:
    """Execute a single SQL statement against a fresh in-memory SQLite database."""
    # closing() guarantees the connection is released even if execute() raises.
    with contextlib.closing(sqlite3.connect(":memory:")) as conn:
        cur = conn.cursor()
        cur.execute(code)
        # cursor.description is set for every row-returning statement
        # (SELECT, WITH ... SELECT, PRAGMA, ...), not only ones starting with "select".
        if cur.description is not None:
            cols = [desc[0] for desc in cur.description]
            df = pd.DataFrame(cur.fetchall(), columns=cols)
            result["dataframes"].append(("query_result", df.head().to_string()))
        conn.commit()
    result["status"] = "success"
    result["stdout"] = "SQL executed successfully."


def _compile_and_run(compile_cmd: list, run_cmd: list, result: dict) -> None:
    """Run compile_cmd; if it succeeds, run run_cmd and record its output."""
    comp = subprocess.run(
        compile_cmd, capture_output=True, text=True, timeout=_COMPILE_TIMEOUT_SECONDS
    )
    if comp.returncode != 0:
        result["stderr"] = comp.stderr
        return
    run = subprocess.run(run_cmd, capture_output=True, text=True, timeout=_RUN_TIMEOUT_SECONDS)
    result["stdout"] = run.stdout
    result["stderr"] = run.stderr
    result["status"] = "success" if run.returncode == 0 else "error"


def _run_c(code: str, result: dict) -> None:
    """Compile code with gcc in a scratch directory and run the binary."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "main.c")
        bin_path = os.path.join(tmp, "main")
        with open(src, "w") as f:
            f.write(code)
        _compile_and_run(["gcc", src, "-o", bin_path], [bin_path], result)


def _run_java(code: str, result: dict) -> None:
    """Compile Main.java with javac in a scratch directory and run class Main."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "Main.java")
        with open(src, "w") as f:
            f.write(code)
        _compile_and_run(["javac", src], ["java", "-cp", tmp, "Main"], result)


# Dispatch table from lower-cased language name to its runner.
_CODE_RUNNERS = {
    "python": _run_python,
    "bash": _run_bash,
    "sql": _run_sql,
    "c": _run_c,
    "java": _run_java,
}


def _format_exec_result(language: str, result: dict) -> str:
    """Render the result dict as the markdown summary returned to the agent."""
    lang = language.upper()
    if result["status"] == "success":
        summary = [f"✅ Code executed successfully in **{lang}**"]
        if result["stdout"]:
            summary.append(f"\n**Output:**\n```\n{result['stdout'].strip()}\n```")
        if result["stderr"]:
            summary.append(f"\n**Warnings/Errors:**\n```\n{result['stderr'].strip()}\n```")
        for name, df in result["dataframes"]:
            summary.append(f"\n**DataFrame `{name}` Preview:**\n```\n{df}\n```")
        if result["plots"]:
            summary.append(f"\n📊 {len(result['plots'])} plot(s) generated (base64-encoded).")
    else:
        summary = [f"❌ Execution failed for **{lang}**"]
        # Partial output is often the best clue to why a run failed.
        if result["stdout"]:
            summary.append(f"\n**Output:**\n```\n{result['stdout'].strip()}\n```")
        if result["stderr"]:
            summary.append(f"\n**Error:**\n```\n{result['stderr'].strip()}\n```")
    return "\n".join(summary)


def execute_code_multilang(code: str, language: str = "python") -> str:
    """
    Execute code in Python, Bash, SQL, C, or Java and return formatted results.

    Args:
        code (str): Source code. SQL is run as a single statement against a fresh in-memory SQLite database.
        language (str): Language of the code. One of: 'python', 'bash', 'sql', 'c', 'java'.

    Returns:
        str: Human-readable execution result.
    """
    language = language.lower()
    runner = _CODE_RUNNERS.get(language)
    if runner is None:
        return f"❌ Unsupported language: {language}."
    result = {
        "stdout": "",
        "stderr": "",
        "status": "error",
        "plots": [],
        "dataframes": [],
    }
    try:
        runner(code, result)
    except Exception:
        # Keep any stderr captured before the failure, followed by the traceback.
        tb = traceback.format_exc()
        captured = result["stderr"]
        result["stderr"] = f"{captured.rstrip()}\n{tb}" if captured.strip() else tb
        result["status"] = "error"
    return _format_exec_result(language, result)
# Every callable the LLM may invoke; bound to the model and served by ToolNode.
tools = [
    # arithmetic
    multiply,
    add,
    subtract,
    divide,
    modulus,
    # search
    wiki_search,
    web_search,
    arvix_search,
    # file readers
    read_excel_file,
    extract_text_from_pdf,
    blip_image_caption,
    # file I/O and analysis
    save_and_read_file,
    download_file_from_url,
    analyze_csv_file,
    # code execution
    execute_code_multilang,
]
# ------------------- SYSTEM PROMPT -------------------
# Start from the built-in default and override it with system_prompt.txt
# (looked up relative to the current working directory) when that file exists.
system_prompt_path = "system_prompt.txt"
system_prompt = (
    "You are an intelligent AI agent who can solve math, science, factual, and research-based problems. "
    "You can use tools like Wikipedia, Web search, or Arxiv when needed. Always give precise and helpful answers."
)
if os.path.exists(system_prompt_path):
    with open(system_prompt_path, "r", encoding="utf-8") as f:
        system_prompt = f.read()
sys_msg = SystemMessage(content=system_prompt)
# ------------------- GRAPH CONSTRUCTION -------------------
def _make_llm(provider: str):
    """Instantiate the chat model for ``provider``.

    Args:
        provider: one of "google", "groq", "huggingface", "openai".

    Raises:
        ValueError: if ``provider`` is unknown, or GROQ_API_KEY / OPENAI_API_KEY
            is unset for the provider that needs it.
    """
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    if provider == "groq":
        groq_key = os.getenv("GROQ_API_KEY")
        if not groq_key:
            raise ValueError("GROQ_API_KEY is not set.")
        return ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=groq_key)
    if provider == "huggingface":
        # HuggingFaceEndpoint takes the endpoint as `endpoint_url`. `url` is not
        # one of its fields, so the endpoint was never actually configured.
        # NOTE(review): confirm this model path exists and that the endpoint
        # accepts temperature=0.
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            )
        )
    if provider == "openai":
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            raise ValueError("OPENAI_API_KEY is not set.")
        return ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=openai_key)
    raise ValueError("Invalid provider")


def _make_retriever():
    """Connect to the Supabase "QA_db" vector table and return a top-1 retriever.

    Reads SUPABASE_URL and SUPABASE_SERVICE_KEY from the environment.
    """
    supabase = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    vectorstore = SupabaseVectorStore(
        client=supabase,
        embedding=embedding_model,
        table_name="QA_db",
    )

    def patched_fn(embedding, k=4, filter=None, **kwargs):
        """Call the `match_documents` RPC directly; return (Document, similarity) pairs.

        ``filter`` and extra kwargs are accepted for signature compatibility but ignored.
        """
        response = supabase.rpc(
            "match_documents",
            {"query_embedding": embedding, "match_count": k},
        ).execute()
        documents = []
        for row in response.data:
            metadata = row.get("metadata") or {}
            if isinstance(metadata, str):
                # Some rows store metadata as a JSON string; fall back to {} if malformed.
                try:
                    metadata = json.loads(metadata)
                except ValueError:
                    metadata = {}
            if not isinstance(metadata, dict):
                metadata = {}
            doc = Document(page_content=row["content"], metadata=metadata)
            documents.append((doc, row["similarity"]))
        return documents

    # Override the scoring method on this vectorstore instance so similarity
    # search goes straight through the `match_documents` RPC via supabase.rpc.
    vectorstore.similarity_search_by_vector_with_relevance_scores = patched_fn
    return vectorstore.as_retriever(search_kwargs={"k": 1})


def build_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    Flow: START -> retriever. If the retriever found a stored document, its
    content is the answer and the graph ends. Otherwise the assistant runs,
    looping through the tools node while ``tools_condition`` requests tool calls.

    Args:
        provider: "google", "groq", "huggingface" or "openai".

    Returns:
        The compiled LangGraph graph.

    Raises:
        ValueError: for an unknown provider or a missing API key.
    """
    llm_with_tools = _make_llm(provider).bind_tools(tools)
    retriever = _make_retriever()

    def assistant(state: MessagesState):
        # The model must see the system prompt first. Only its reply is returned;
        # the MessagesState reducer appends it to the history.
        response = llm_with_tools.invoke([sys_msg] + state["messages"])
        return {"messages": [response]}

    def qa_retriever_node(state: MessagesState):
        # NOTE(review): any non-empty hit short-circuits the LLM regardless of
        # its similarity score; consider a threshold if that is not intended.
        user_question = state["messages"][-1].content
        docs = retriever.invoke(user_question)
        if docs:
            return {"messages": [AIMessage(content=docs[0].page_content)]}
        return {"messages": []}

    def route_after_retriever(state: MessagesState):
        # The retriever appends an AIMessage only when it found an answer.
        return END if isinstance(state["messages"][-1], AIMessage) else "assistant"

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", qa_retriever_node)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_conditional_edges(
        "retriever",
        route_after_retriever,
        {"assistant": "assistant", END: END},
    )
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
# ------------------- LOCAL TEST -------------------
if __name__ == "__main__":
    # Smoke test: ask one question through the OpenAI-backed graph and print
    # every message in the resulting conversation.
    sample_question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    agent = build_graph(provider="openai")
    final_state = agent.invoke({"messages": [HumanMessage(content=sample_question)]})
    print("=== AI Agent Response ===")
    for message in final_state["messages"]:
        message.pretty_print()