Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| import io | |
| import re | |
| import json | |
| import base64 | |
| import traceback | |
| from typing import Optional, List, Dict, Any | |
| import cmath | |
| import requests | |
| import pandas as pd | |
| from PIL import Image | |
| # Optional deps | |
| try: | |
| import pdfplumber | |
| except Exception: | |
| pdfplumber = None | |
| try: | |
| import pytesseract | |
| except Exception: | |
| pytesseract = None | |
| try: | |
| import sympy as sp | |
| except Exception: | |
| sp = None | |
| try: | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| except Exception: | |
| plt = None | |
| try: | |
| from pint import UnitRegistry | |
| _ureg = UnitRegistry() | |
| except Exception: | |
| _ureg = None | |
| try: | |
| from dateutil import parser as dtparser | |
| from dateutil.relativedelta import relativedelta | |
| except Exception: | |
| dtparser = None | |
| relativedelta = None | |
| # LangChain bits | |
| from langchain_core.tools import tool | |
| from langchain_tavily.tavily_search import TavilySearch | |
| from langchain_community.document_loaders import WikipediaLoader, ArxivLoader | |
| from langchain_experimental.utilities import PythonREPL | |
| # ------------ helpers (formatting, env, truncation, errors) ------------ | |
def _env(name: str, default: Optional[str] = None) -> Optional[str]:
    """Look up an environment variable.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is not set.

    Returns:
        The variable's value, or ``default`` when it is unset.
    """
    return os.environ.get(name, default)
def _truncate(txt: str, max_len: int = 4000) -> str:
    """Shorten ``txt`` once it is longer than ``max_len`` characters.

    Text that already fits is returned unchanged. Otherwise the first
    ``max_len - 200`` characters are kept (200 characters are held back as
    headroom for the marker) and a marker stating how many characters were
    dropped is appended.

    Args:
        txt: Text to shorten; ``None`` is treated as empty text.
        max_len: Length above which the text is truncated.

    Returns:
        The original text, or its head followed by a truncation marker.
    """
    if txt is None:
        return ""
    if len(txt) <= max_len:
        return txt
    # Clamp at zero: with max_len < 200 a negative slice bound would trim
    # characters from the *end* of the text and misreport the dropped count.
    head = max(max_len - 200, 0)
    return txt[:head] + f"\n... [truncated {len(txt) - head} chars]"
def _fmt_block(tag: str, attrs: Dict[str, Any] | None, body: str) -> str:
    """Wrap ``body`` in an XML-like ``<tag ...>`` element.

    Args:
        tag: Element name used for the opening and closing tags.
        attrs: Optional attribute mapping; ``None`` values are rendered as
            empty strings. Values are inserted as-is, without escaping.
        body: Text placed between the tags, on its own line(s).

    Returns:
        The formatted block as a single string.
    """
    rendered = []
    for key, value in (attrs or {}).items():
        shown = "" if value is None else value
        rendered.append(f'{key}="{shown}"')
    # The tag name and any attributes share one space-separated opening tag.
    opening = " ".join([tag, *rendered])
    return f"<{opening}>\n{body}\n</{tag}>"
def _fmt_error(tool_name: str, err: Exception) -> str:
    """Render an exception as a ``<ToolError>`` block.

    The body is the exception message followed by the output of
    ``traceback.format_exc()``, passed through ``_truncate`` with a
    1600-character limit.

    Args:
        tool_name: Name of the tool that failed (stored in the ``tool`` attribute).
        err: The exception; its class name is stored in the ``type`` attribute.

    Returns:
        The formatted error block.
    """
    details = f"{err}\n{traceback.format_exc()}"
    attributes = {"tool": tool_name, "type": err.__class__.__name__}
    return _fmt_block("ToolError", attributes, _truncate(details, 1600))
def web_search(query: str) -> str:
    """Search the web with Tavily.

    Up to 3 results are rendered as <Document> blocks, preceded by an
    <Answer> block when the response carries one; everything is wrapped in a
    <WebResults> block. Exceptions are returned as a <ToolError> block
    rather than raised.

    Args:
        query: Search query text.

    Returns:
        A <WebResults> block, or a <ToolError> block on failure.
    """
    try:
        searcher = TavilySearch(
            tavily_api_key=_env("TAVILY_API_KEY"),
            max_results=3,
            include_answer=True,
            search_depth="advanced",
            topic="general",
            include_raw_content=False,
        )
        response = searcher._run(query=query)
        hits = response.get("results", []) or []
        documents: List[str] = [
            _fmt_block(
                "Document",
                {"source": hit.get("url", ""), "title": hit.get("title", "")},
                _truncate(hit.get("content", "") or ""),
            )
            for hit in hits
        ]

        sections = []
        answer = response.get("answer")
        if answer:
            sections.append(_fmt_block("Answer", {}, _truncate(str(answer), 1000)))
        if documents:
            sections.append("\n\n---\n\n".join(documents))

        content = "\n\n".join(sections) if sections else "No results."
        return _fmt_block("WebResults", {"query": query}, content)
    except Exception as exc:
        return _fmt_error("web_search", exc)
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 2 articles as <Document> blocks.

    Each article body is truncated to 4000 characters. The documents are
    wrapped in a <WikiResults> block; exceptions are returned as a
    <ToolError> block rather than raised.

    Args:
        query: Search query text.

    Returns:
        A <WikiResults> block, or a <ToolError> block on failure.
    """
    try:
        articles = WikipediaLoader(query=query, load_max_docs=2).load()
        rendered = []
        for article in articles:
            info = getattr(article, "metadata", {}) or {}
            attributes = {
                "source": info.get("source", ""),
                "page": str(info.get("page", "")),
            }
            text = _truncate(article.page_content or "", 4000)
            rendered.append(_fmt_block("Document", attributes, text))

        content = "\n\n---\n\n".join(rendered) if rendered else "No results."
        return _fmt_block("WikiResults", {"query": query}, content)
    except Exception as exc:
        return _fmt_error("wiki_search", exc)
def academic_search(query: str) -> str:
    """Search arXiv and return up to 3 papers as <Document> blocks.

    Title and publication date are read from the document metadata (either
    capitalised or lower-case keys). Each paper body is truncated to 3000
    characters. Exceptions are returned as a <ToolError> block rather than
    raised.

    Args:
        query: Search query text.

    Returns:
        An <ArxivResults> block, or a <ToolError> block on failure.
    """
    try:
        papers = ArxivLoader(query=query, load_max_docs=3).load()
        rendered = []
        for paper in papers:
            info = getattr(paper, "metadata", {}) or {}
            paper_title = info.get("Title") or info.get("title") or ""
            paper_date = info.get("Published") or info.get("published") or ""
            attributes = {"title": str(paper_title), "date": str(paper_date)}
            text = _truncate(paper.page_content or "", 3000)
            rendered.append(_fmt_block("Document", attributes, text))

        content = "\n\n---\n\n".join(rendered) if rendered else "No results."
        return _fmt_block("ArxivResults", {"query": query}, content)
    except Exception as exc:
        return _fmt_error("academic_search", exc)
def python_code(code: str) -> str:
    """Execute Python code with LangChain's ``PythonREPL`` and return its output.

    Input should be valid Python code.

    NOTE(security): ``PythonREPL`` runs the code in the host interpreter and
    is not a sandbox. Only expose this tool where running arbitrary code is
    acceptable.

    Args:
        code: Python source to run.

    Returns:
        The text reported by the REPL with surrounding newlines stripped,
        ``"(no output)"`` when the code produced nothing, or a
        ``<ToolError>`` string when no code was given or the call failed.
    """
    if code is None:
        return "<ToolError>No code provided to python_code tool.</ToolError>"
    try:
        output = PythonREPL().run(str(code))
    except Exception as e:
        return f"<ToolError>{type(e).__name__}: {e}</ToolError>"
    text = "" if output is None else str(output).strip("\n")
    # run() reports results as a string, so code that prints nothing yields
    # "" rather than None; say so explicitly instead of returning nothing.
    return text or "(no output)"
def image_info(path: str) -> str:
    """Return basic info about an image (width x height, format).

    Args:
        path: Location of the image file, opened with ``PIL.Image.open``.

    Returns:
        An <ImageInfo> block such as ``640x480 (PNG)``, or a <ToolError>
        block when the image cannot be opened.
    """
    try:
        with Image.open(path) as picture:
            summary = f"{picture.width}x{picture.height} ({picture.format})"
        return _fmt_block("ImageInfo", {"path": path}, summary)
    except Exception as exc:
        return _fmt_error("image_info", exc)
def read_mp3_transcript(path: str) -> str:
    """Transcribe an MP3 file (placeholder). Replace with actual ASR.

    No audio is read: the function only reports that transcription is not
    implemented.

    Args:
        path: Location of the MP3 file; echoed back in the block attributes.

    Returns:
        An <AudioTranscript> block holding the placeholder message.
    """
    try:
        placeholder = "Transcription not implemented."
        return _fmt_block("AudioTranscript", {"path": path}, placeholder)
    except Exception as exc:
        return _fmt_error("read_mp3_transcript", exc)
def ocr_image(path: str) -> str:
    """Run OCR on an image and return the extracted text.

    Requires pytesseract and the Tesseract binary to be installed.

    Args:
        path: Location of the image file.

    Returns:
        An <OCRText> block with the stripped text (or ``(no text)`` when
        nothing was recognised), truncated to 4000 characters; a
        <ToolError> block on failure.
    """
    try:
        if pytesseract is None:
            raise RuntimeError("pytesseract not installed or tesseract binary missing")
        with Image.open(path) as picture:
            extracted = pytesseract.image_to_string(picture)
        cleaned = extracted.strip()
        if not cleaned:
            cleaned = "(no text)"
        return _fmt_block("OCRText", {"path": path}, _truncate(cleaned, 4000))
    except Exception as exc:
        return _fmt_error("ocr_image", exc)
def math_solver(expr: str) -> str:
    """Solve/compute a math expression with SymPy. Examples:
    - 'integrate(sin(x)/x, (x, 0, 1))'
    - 'solve(x**2 - 5*x + 6, x)'
    - 'simplify((x**2 - 1)/(x - 1))'

    The parsed result is numerically evaluated with ``evalf()`` when it
    offers that method; otherwise it is returned as parsed.

    NOTE(review): ``sympify`` evaluates the string it is given; SymPy's
    documentation cautions against passing untrusted input. Confirm this is
    acceptable where the tool is exposed.

    Args:
        expr: Expression text handed to ``sympy.sympify``.

    Returns:
        A <MathResult> block (truncated to 4000 characters), or a
        <ToolError> block on failure.
    """
    try:
        if sp is None:
            raise RuntimeError("sympy not installed")
        # Pre-declare every lowercase letter so expressions can use them as symbols.
        letters = "abcdefghijklmnopqrstuvwxyz"
        namespace = {letter: sp.symbols(letter) for letter in letters}
        parsed = sp.sympify(expr, locals=namespace)
        value = parsed.evalf() if hasattr(parsed, "evalf") else parsed
        return _fmt_block("MathResult", {}, _truncate(str(value), 4000))
    except Exception as exc:
        return _fmt_error("math_solver", exc)
def plot_data_tool(args: str) -> str:
    """Create a simple plot from CSV.

    Usage (JSON string):
    {
        "path": "data.csv",
        "x": "col_x",    # optional; the row position is used when omitted
        "y": "col_y",    # required
        "kind": "line"   # 'line' or 'scatter'
    }

    Any ``kind`` other than 'scatter' is drawn as a line plot.

    Args:
        args: JSON document with the keys described above.

    Returns:
        An <ImageBase64> block holding the PNG as base64 text, or a
        <ToolError> block on failure.
    """
    try:
        if plt is None:
            raise RuntimeError("matplotlib not available")
        cfg = json.loads(args)
        path = cfg.get("path")
        xcol = cfg.get("x")
        ycol = cfg["y"]
        kind = (cfg.get("kind") or "line").lower()

        df = pd.read_csv(path)
        # Resolve the columns before creating a figure so that a bad column
        # name fails without allocating anything that needs cleanup.
        y_values = df[ycol]
        x_values = df[xcol] if xcol else range(len(y_values))

        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            draw = ax.scatter if kind == "scatter" else ax.plot
            draw(x_values, y_values)
            ax.set_xlabel(xcol or "index")
            ax.set_ylabel(ycol)
            ax.set_title(f"{os.path.basename(path)}: {ycol} vs {xcol or 'index'}")
            buf = io.BytesIO()
            fig.tight_layout()
            fig.savefig(buf, format="png")
        finally:
            # Always release the figure, even when drawing or saving fails;
            # pyplot otherwise keeps it registered.
            plt.close(fig)

        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return _fmt_block("ImageBase64", {"format": "png"}, b64)
    except Exception as e:
        return _fmt_error("plot_data_tool", e)
def unit_converter(query: str) -> str:
    """Convert units. Examples:
    - '12 inch to cm'
    - '5 miles to kilometer'
    - '32 degF to degC'
    - '9.81 m/s^2 to ft/s^2'

    Args:
        query: Text of the form ``<value> <from_units> to <to_units>``.

    Returns:
        A <UnitConversion> block with the converted magnitude and units, or
        a <ToolError> block when the query cannot be parsed or converted.
    """
    try:
        if _ureg is None:
            raise RuntimeError("pint not installed")
        # Number (optional sign, decimals, exponent), then the source units,
        # the word "to", and the target units. Unit text is left to pint to
        # validate so digits and underscores (e.g. "m/s^2") are accepted.
        m = re.match(
            r"\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)\s*(.+?)\s+to\s+(.+?)\s*$",
            query,
        )
        if not m:
            raise ValueError("Format: '<value> <from_units> to <to_units>'")
        val = float(m.group(1))
        from_u = m.group(2).strip()
        to_u = m.group(3).strip()
        # Build the quantity with the constructor rather than by
        # multiplication: pint rejects multiplying offset units such as
        # degF/degC, which broke temperature conversions.
        qty = _ureg.Quantity(val, from_u)
        conv = qty.to(to_u)
        return _fmt_block("UnitConversion", {"from": from_u, "to": to_u}, f"{conv.magnitude} {conv.units}")
    except Exception as e:
        return _fmt_error("unit_converter", e)
def date_time_calculator(query: str) -> str:
    """Date/time math. Examples:
    - 'diff 2024-01-01 2025-08-14' → difference (y,m,d)
    - 'add 2025-08-14 + 3 days' → add delta
    - 'add 2025-08-14 - 2 weeks' → subtract delta
    Accepts ISO dates; units: years, months, weeks, days, hours, minutes.

    Args:
        query: Command text starting with ``diff`` or ``add``.

    Returns:
        A <DateDiff> or <DateAdd> block, or a <ToolError> block when the
        command is malformed or dateutil is unavailable.
    """
    try:
        if dtparser is None or relativedelta is None:
            raise RuntimeError("python-dateutil not installed")
        text = query.strip()
        command = text.lower()

        if command.startswith("diff"):
            tokens = text.split()
            if len(tokens) != 3:
                raise ValueError("Use: diff YYYY-MM-DD YYYY-MM-DD")
            _, start_raw, end_raw = tokens
            start = dtparser.parse(start_raw)
            end = dtparser.parse(end_raw)
            gap = relativedelta(end, start)
            summary = (
                f"{gap.years} years, {gap.months} months, {gap.days} days, "
                f"{gap.hours} hours, {gap.minutes} minutes"
            )
            return _fmt_block("DateDiff", {"from": start_raw, "to": end_raw}, summary)

        if command.startswith("add"):
            # e.g., "add 2025-08-14 + 3 days" or "add 2025-08-14 - 2 weeks"
            match = re.match(
                r"add\s+(\S+)\s*([+-])\s*(\d+)\s+(years?|months?|weeks?|days?|hours?|minutes?)",
                text,
                re.I,
            )
            if not match:
                raise ValueError("Use: add <date> +/- <n> <unit>")
            date_raw, sign_char, count_raw, unit_raw = match.groups()
            base = dtparser.parse(date_raw)
            n = int(count_raw) if sign_char == "+" else -int(count_raw)
            unit = unit_raw.lower()
            # Map the singular/plural unit to the matching relativedelta keyword.
            delta_kwargs = {}
            for stem in ("year", "month", "week", "day", "hour", "minute"):
                if stem in unit:
                    delta_kwargs[stem + "s"] = n
                    break
            shifted = base + relativedelta(**delta_kwargs)
            return _fmt_block(
                "DateAdd",
                {"base": base.isoformat(), "delta": f"{n} {unit}"},
                shifted.isoformat(),
            )

        raise ValueError("Start with 'diff' or 'add'.")
    except Exception as exc:
        return _fmt_error("date_time_calculator", exc)
def api_request_tool(args: str) -> str:
    """Call a JSON REST API.

    Usage (JSON string):
    {
        "method": "GET",
        "url": "https://api.example.com/items",
        "headers": {"Authorization": "Bearer ..."},
        "params": {"q": "search"},
        "json": {"k": "v"},
        "timeout": 20
    }

    Only "url" is required; the method defaults to GET and the timeout to 20.

    Args:
        args: JSON document with the keys described above.

    Returns:
        An <APIResponse> block carrying the status code, the URL, and the
        response body (pretty-printed when it is JSON), truncated to 4000
        characters; a <ToolError> block on failure.
    """
    try:
        cfg = json.loads(args)
        method = (cfg.get("method") or "GET").upper()
        url = cfg["url"]
        headers = cfg.get("headers") or {}
        params = cfg.get("params") or {}
        json_body = cfg.get("json")
        timeout = cfg.get("timeout", 20)
        resp = requests.request(method, url, headers=headers, params=params, json=json_body, timeout=timeout)
        meta = {"status": resp.status_code, "url": url}
        # Pretty-print JSON bodies when possible; any failure falls back to
        # the raw text. Both paths go through _truncate so that a cut body
        # always carries the truncation marker.
        try:
            text = _truncate(json.dumps(resp.json(), indent=2), 4000)
        except Exception:
            text = _truncate(resp.text, 4000)
        return _fmt_block("APIResponse", meta, text)
    except Exception as e:
        return _fmt_error("api_request_tool", e)
def html_table_extractor(url: str) -> str:
    """Extract the first HTML table from a webpage and return a CSV preview.

    Args:
        url: Page address handed to ``pandas.read_html``.

    Returns:
        An <HTMLTable> block with the table shape, column names, and the
        first 15 rows as CSV (truncated to 4000 characters); a <ToolError>
        block on failure.
    """
    try:
        found = pd.read_html(url)
        if not found:
            return _fmt_block("HTMLTable", {"url": url}, "No tables found.")
        df = found[0]
        row_count, col_count = df.shape
        preview = df.head(15).to_csv(index=False)
        summary = (
            f"Shape: {row_count} rows x {col_count} cols\n"
            f"Columns: {list(df.columns)}\n\n"
            f"Head(15):\n{preview}"
        )
        return _fmt_block("HTMLTable", {"url": url}, _truncate(summary, 4000))
    except Exception as exc:
        return _fmt_error("html_table_extractor", exc)
def multiply(a: float, b: float) -> float:
    """Return the product of two numbers.

    Args:
        a (float): the first factor
        b (float): the second factor
    """
    product = a * b
    return product
def add(a: float, b: float) -> float:
    """Return the sum of two numbers.

    Args:
        a (float): the first addend
        b (float): the second addend
    """
    total = a + b
    return total
def subtract(a: float, b: float) -> float:
    """
    Subtracts the second number from the first.
    Args:
        a (float): the number to subtract from
        b (float): the number to subtract
    Returns:
        The difference ``a - b``.
    """
    # The return annotation is float (not int): float operands yield a float.
    return a - b
def divide(a: float, b: float) -> float:
    """
    Divides the first number by the second.
    Args:
        a (float): the dividend
        b (float): the divisor
    Returns:
        The quotient ``a / b``.
    Raises:
        ValueError: if ``b`` is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
def modulus(a: int, b: int) -> int:
    """Return the remainder of ``a`` divided by ``b`` (``a % b``).

    Args:
        a (int): the dividend
        b (int): the divisor
    """
    remainder = a % b
    return remainder
def power(a: float, b: float) -> float:
    """Raise ``a`` to the power ``b``.

    Args:
        a (float): the base
        b (float): the exponent
    """
    return pow(a, b)
def square_root(a: float) -> float | complex:
    """Return the square root of a number.

    Values that satisfy ``a >= 0`` are raised to the power 0.5; every other
    value is handed to ``cmath.sqrt``, which yields a complex result.

    Args:
        a (float): the number to take the square root of
    """
    return a**0.5 if a >= 0 else cmath.sqrt(a)