# GAIAgent — app.py (Hugging Face Space entry point)
import os
import json
import requests
import pandas as pd
import gradio as gr
import time
from typing import Dict, Any, List
# --- Import Agent and Tools (These must be present in agent.py and tools.py) ---
try:
    from agent import GAIAgent
except ImportError:
    # Fail fast: without the agent the app is unusable, and swallowing the
    # error would only defer the crash to a NameError when GAIAgent() is called.
    print("FATAL: Cannot import GAIAgent. Ensure agent.py exists and is in the PYTHONPATH.")
    raise
# --- PERSISTENCE LIBRARIES ---
# Required for persistent caching across container rebuilds
from huggingface_hub import HfApi, HfFileSystem
from huggingface_hub.utils import RepositoryNotFoundError
# --- Constants ---
# Endpoint of the Agents-course scoring service (questions + submission).
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Local JSON file holding [{"task_id": ..., "submitted_answer": ...}, ...].
CACHE_FILE = "answers_cache.json"
# HF Dataset repo used as durable storage for the cache across container rebuilds.
HF_DATASET_ID = os.environ.get("DATASET_ID", "antmontieri/gaia-agent-cache")
# --- RATE LIMIT & ROBUSTNESS CONFIGURATION ---
MAX_RETRIES = 2  # attempts per task before it is marked failed
REQUEST_DELAY = 10 # Grace Time (seconds) to wait between starting tasks
ERROR_DELAY = 60 # Pause after an API error (Backoff)
# NOTE(review): HF_REPO_ID appears unused below — submit_to_leaderboard reads
# SPACE_ID directly via os.getenv; confirm before removing.
HF_REPO_ID = os.environ.get("SPACE_ID", "local-test")
# =============================================================================
# 0. CACHE & PERSISTENCE HELPERS
# =============================================================================
def load_cache_as_dict() -> Dict[str, str]:
    """Load the local JSON cache into a {task_id: submitted_answer} dict.

    Returns an empty dict when the file is absent, unreadable, or malformed,
    so callers can treat "no cache" and "corrupt cache" identically.
    """
    if not os.path.exists(CACHE_FILE):
        return {}
    try:
        with open(CACHE_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
        return {item["task_id"]: item["submitted_answer"] for item in data}
    # Narrowed from a blanket Exception: these are the failure modes a
    # missing/corrupt/partially-written cache file can actually produce.
    except (OSError, json.JSONDecodeError, KeyError, TypeError) as e:
        print(f"Error reading local cache: {e}")
    return {}
def save_cache_local(data: List[Dict[str, str]]) -> None:
    """Write the answer list to the local JSON cache file (best effort).

    Failures are logged and swallowed on purpose: losing the local cache
    must never abort a generation run.
    """
    try:
        with open(CACHE_FILE, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
    except (OSError, TypeError) as e:
        print(f"Error saving local cache: {e}")
def load_persistent_cache() -> Dict[str, str]:
    """Return the answer cache, preferring the local file over the Dataset copy."""
    # Fast path: a local copy already exists from an earlier run in this container.
    cached = load_cache_as_dict()
    if cached:
        return cached

    # Slow path: pull the cache file from the Hugging Face Dataset, if present.
    fs = HfFileSystem()
    remote_path = f"datasets/{HF_DATASET_ID}/{CACHE_FILE}"
    try:
        if fs.exists(remote_path):
            print(f"📥 Loading cache from Dataset: {HF_DATASET_ID}")
            with fs.open(remote_path, "r") as handle:
                entries = json.load(handle)
            save_cache_local(entries)  # refresh the local copy for next time
            return {entry["task_id"]: entry["submitted_answer"] for entry in entries}
    except Exception as exc:
        print(f"⚠️ Error loading from Dataset: {exc}")
    return {}
def commit_cache_to_repo(data: List[Dict[str, str]]):
    """Persist answers locally, then push the cache file to the HF Dataset."""
    # Always keep the local copy current, even if the remote sync fails.
    save_cache_local(data)

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
    if not token:
        print("WARNING: HF_TOKEN secret missing. Cannot sync with Dataset.")
        return

    try:
        api = HfApi(token=token)
        print(f"☁️ Syncing {len(data)} answers to Dataset: {HF_DATASET_ID}...")
        api.upload_file(
            path_or_fileobj=CACHE_FILE,
            path_in_repo=CACHE_FILE,
            repo_id=HF_DATASET_ID,
            repo_type="dataset",
            commit_message=f"Agent update: {len(data)} answers generated.",
        )
        print("✅ Sync complete.")
    except RepositoryNotFoundError:
        print(f"❌ Error: Dataset {HF_DATASET_ID} not found. Did you create it?")
    except Exception as exc:
        print(f"❌ Sync Error: {exc}")
# --- API helpers (independent of the agent and cache logic) ---
def fetch_questions():
    """Download the benchmark task list from the scoring API.

    Returns a list of question dicts, or an empty list on any failure
    (network error, non-2xx status, or an unparsable body).
    """
    url = f"{DEFAULT_API_URL}/questions"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
        return resp.json()
    except Exception as exc:
        print(f"Error fetching questions: {exc}")
        return []
def is_valid_answer(answer: str) -> bool:
    """Heuristic filter: accept only short, direct answers.

    Rejects empty strings, agent/API error markers, over-long text, and
    conversational chatter ("I cannot...", "Based on...") — except that an
    answer containing "not known" is always accepted, since that is the
    agent's deliberate sentinel for unanswerable tasks.
    """
    if not answer:
        return False
    if "AGENT ERROR" in answer or "Error:" in answer:
        return False
    if len(answer) > 150:
        return False

    lowered = answer.lower()
    chatter_markers = ("i cannot", "i am sorry", "the answer is", "based on")
    if any(marker in lowered for marker in chatter_markers):
        # Chatter is only tolerated when it carries the "not known" sentinel.
        return "not known" in lowered
    return True
# =============================================================================
# 1. GENERATOR FUNCTION (With Grace Time & Retry)
# =============================================================================
def generate_answers(profile: gr.OAuthProfile | None, *args):
    """
    Runs the agent with strategic pauses and streams updates to the UI.
    Passes file_name to the agent for multimodal tasks.

    Generator: yields (status_message, results DataFrame) tuples that Gradio
    streams into the status textbox and results table.
    """
    # --- Setup & initialization ---
    if not profile:
        yield "⚠️ Please login with Hugging Face first!", pd.DataFrame()
        return
    # Load state from persistence
    existing_answers = load_persistent_cache()
    questions = fetch_questions()
    if not questions:
        yield "❌ Failed to fetch questions.", pd.DataFrame()
        return
    # Initialize Agent
    try:
        agent = GAIAgent()
    except Exception as e:
        yield f"❌ Error initializing agent: {e}", pd.DataFrame()
        return
    results_log = []    # rows for the UI table
    payload_cache = []  # answers accumulated for persistence/submission
    total = len(questions)
    yield f"🚀 Starting execution on {total} tasks (Grace Time: {REQUEST_DELAY}s)...", pd.DataFrame()
    for i, item in enumerate(questions):
        task_id = item.get("task_id")
        question_text = item.get("question")
        # ESSENTIAL: retrieve file_name so multimodal tasks get their attachment
        file_name = item.get("file_name")
        # --- CACHE CHECK: skip generation when a valid cached answer exists ---
        if task_id in existing_answers and is_valid_answer(existing_answers[task_id]):
            answer = existing_answers[task_id]
            # --- UPDATE DISPLAY, including the attachment name ---
            results_log.append({
                "Task ID": task_id,
                "Status": "⚡ Cached",
                "Question": question_text[:50] + "...",
                "File": file_name if file_name else "N/A",  # show attachment name
                "Answer": answer
            })
            payload_cache.append({"task_id": task_id, "submitted_answer": answer})
            yield f"⚡ Task {i+1} Cached.", pd.DataFrame(results_log)
            continue
        # --- REAL GENERATION (Slow) ---
        final_answer = "NOT KNOWN"  # sentinel submitted when all attempts fail
        success = False
        for attempt in range(MAX_RETRIES):
            try:
                status_msg = f"⏳ Task {i+1}/{total} (Try {attempt+1}/{MAX_RETRIES})..."
                yield status_msg, pd.DataFrame(results_log)
                # AGENT CALL: pass file_name to the agent's __call__ method
                answer = agent(question_text, file_name=file_name)
                # VALIDATION
                if is_valid_answer(answer):
                    final_answer = answer
                    success = True
                    # --- GRACE TIME AFTER SUCCESS (rate-limit friendliness) ---
                    wait_time = REQUEST_DELAY
                    yield f"✅ Success. Cooling down for {wait_time}s...", pd.DataFrame(results_log)
                    time.sleep(wait_time)
                    break
                else:
                    print(f"⚠️ Invalid answer on attempt {attempt+1}: {answer}")
                    # --- BACKOFF AFTER INVALID FORMAT ---
                    yield f"⚠️ Invalid format. Retrying in {ERROR_DELAY}s...", pd.DataFrame(results_log)
                    time.sleep(ERROR_DELAY)
            except Exception as e:
                print(f"❌ Error on attempt {attempt+1}: {e}")
                # --- BACKOFF AFTER ERROR (e.g., Rate Limit) ---
                yield f"❌ API Error. Pausing {ERROR_DELAY}s...", pd.DataFrame(results_log)
                time.sleep(ERROR_DELAY)
        # Save Result
        status_label = "✅ Generated" if success else "⚠️ Failed"
        # --- UPDATE DISPLAY, including the attachment name ---
        results_log.append({
            "Task ID": task_id,
            "Status": status_label,
            "Question": question_text[:50] + "...",
            "File": file_name if file_name else "N/A",  # added for debugging
            "Answer": final_answer
        })
        payload_cache.append({"task_id": task_id, "submitted_answer": final_answer})
        # Save progress to Git after every task so a crash/restart loses nothing
        commit_cache_to_repo(payload_cache)
        yield f"🏁 Logic Complete Task {i+1}", pd.DataFrame(results_log)
    yield "🎉 Generation Complete! Click 'Submit' to send results.", pd.DataFrame(results_log)
# =============================================================================
# 2. SUBMIT FUNCTION (Fast Process)
# =============================================================================
def submit_to_leaderboard(profile: gr.OAuthProfile | None, *args):
    """
    Reads the persistent cache and submits the final payload to the leaderboard.

    Returns a human-readable status string for the UI textbox.
    """
    if not profile:
        return "⚠️ Please login first!"

    # Pull the most recently committed answers.
    cached = load_persistent_cache()
    if not cached:
        return "⚠️ No answers found. Please generate answers first."

    space_id = os.getenv("SPACE_ID", "local-test")
    submission_data = {
        "username": profile.username,
        "agent_code": f"https://huggingface.co/spaces/{space_id}/tree/main",
        "answers": [
            {"task_id": task_id, "submitted_answer": answer}
            for task_id, answer in cached.items()
        ],
    }

    try:
        resp = requests.post(f"{DEFAULT_API_URL}/submit", json=submission_data, timeout=60)
        resp.raise_for_status()
        body = resp.json()
        return (
            f"✅ SUBMISSION SUCCESSFUL!\n\n"
            f"User: {body.get('username')}\n"
            f"Score: {body.get('score')}% | Correct: {body.get('correct_count')}/{body.get('total_attempted')}"
        )
    except Exception as e:
        return f"❌ Submission Failed: {str(e)}"
# =============================================================================
# 3. USER INTERFACE (Gradio Blocks)
# =============================================================================
# Gradio UI: HF login, a two-step flow (1. generate answers, 2. submit them),
# plus a live status box and a streaming results table.
with gr.Blocks(title="GAIAgent Evaluation") as demo:
    gr.Markdown("# 🤖 GAIAgent Evaluation Runner")
    gr.Markdown(
        """
This tool runs your agent on the benchmark tasks, designed to be robust against network timeouts.
**Status:** Grace Time: 10s | Max Retries: 2 | **Persistence:** Upload on HF Datasets
"""
    )
    with gr.Row(): login_btn = gr.LoginButton()
    with gr.Row():
        btn_gen = gr.Button("1. Generate Answers (Safe Mode)", variant="primary")
        btn_sub = gr.Button("2. Submit Results", variant="secondary")
    status = gr.Textbox(label="Current Status", lines=10, interactive=False)
    results_table = gr.DataFrame(label="Generated Answers", headers=["Task ID", "Status", "Question", "File", "Answer"], wrap=True)
    # --- Event Handlers ---
    # Passing login_btn as an input makes Gradio inject the OAuthProfile
    # (or None) as the first argument of each handler.
    btn_gen.click(fn=generate_answers, inputs=[login_btn], outputs=[status, results_table])
    btn_sub.click(fn=submit_to_leaderboard, inputs=[login_btn], outputs=[status])

if __name__ == "__main__":
    demo.launch()