# GAIA agent evaluation runner — Gradio app for a Hugging Face Space.
import os
import json
import requests
import pandas as pd
import gradio as gr
import time
from typing import Dict, Any, List

# --- Import Agent and Tools (These must be present in agent.py and tools.py) ---
try:
    from agent import GAIAgent
except ImportError:
    # Execution deliberately continues: the missing name surfaces later inside
    # generate_answers(), whose except block reports the init error to the UI.
    print("FATAL: Cannot import GAIAgent. Ensure agent.py exists and is in the PYTHONPATH.")

# --- PERSISTENCE LIBRARIES ---
# Required for persistent caching across container rebuilds
from huggingface_hub import HfApi, HfFileSystem
from huggingface_hub.utils import RepositoryNotFoundError
# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"  # scoring API base URL
CACHE_FILE = "answers_cache.json"  # local JSON cache of generated answers
# Dataset repo used as durable storage for the cache (survives Space rebuilds).
HF_DATASET_ID = os.environ.get("DATASET_ID", "antmontieri/gaia-agent-cache")

# --- RATE LIMIT & ROBUSTNESS CONFIGURATION ---
MAX_RETRIES = 2
REQUEST_DELAY = 10  # Grace Time (seconds) to wait between starting tasks
ERROR_DELAY = 60  # Pause after an API error (Backoff)
HF_REPO_ID = os.environ.get("SPACE_ID", "local-test")  # id of the running Space, or local fallback

# =============================================================================
# 0. CACHE & PERSISTENCE HELPERS
# =============================================================================
def load_cache_as_dict() -> Dict[str, str]:
    """Loads cache from local filesystem into a dictionary for quick lookup."""
    if not os.path.exists(CACHE_FILE):
        return {}
    try:
        with open(CACHE_FILE, "r") as fh:
            entries = json.load(fh)
        # Index by task_id for O(1) lookups during generation.
        return {entry["task_id"]: entry["submitted_answer"] for entry in entries}
    except Exception as exc:
        print(f"Error reading local cache: {exc}")
    return {}
def save_cache_local(data: List[Dict[str, str]]):
    """Saves data to local JSON file."""
    try:
        with open(CACHE_FILE, "w") as out:
            json.dump(data, out, indent=2)
    except Exception as exc:
        # Best-effort: a failed local save is logged, not fatal.
        print(f"Error saving local cache: {exc}")
def load_persistent_cache() -> Dict[str, str]:
    """
    Tries to load cache from the Hugging Face Dataset or local copy.
    """
    # 1. Prefer the local copy when it exists (fastest path).
    cached = load_cache_as_dict()
    if cached:
        return cached

    # 2. Otherwise fall back to the Dataset repo on the Hub.
    remote_path = f"datasets/{HF_DATASET_ID}/{CACHE_FILE}"
    fs = HfFileSystem()
    try:
        if fs.exists(remote_path):
            print(f"📥 Loading cache from Dataset: {HF_DATASET_ID}")
            with fs.open(remote_path, "r") as fh:
                entries = json.load(fh)
            save_cache_local(entries)  # refresh the local copy for next time
            return {entry["task_id"]: entry["submitted_answer"] for entry in entries}
    except Exception as exc:
        print(f"⚠️ Error loading from Dataset: {exc}")
    return {}
def commit_cache_to_repo(data: List[Dict[str, str]]):
    """Saves locally and pushes update to the Hugging Face Dataset."""
    # Always persist locally before attempting the remote sync.
    save_cache_local(data)

    # Remote sync requires a write token from the Space secrets.
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
    if not token:
        print("WARNING: HF_TOKEN secret missing. Cannot sync with Dataset.")
        return

    try:
        api = HfApi(token=token)
        print(f"☁️ Syncing {len(data)} answers to Dataset: {HF_DATASET_ID}...")
        api.upload_file(
            path_or_fileobj=CACHE_FILE,
            path_in_repo=CACHE_FILE,
            repo_id=HF_DATASET_ID,
            repo_type="dataset",
            commit_message=f"Agent update: {len(data)} answers generated.",
        )
        print("✅ Sync complete.")
    except RepositoryNotFoundError:
        print(f"❌ Error: Dataset {HF_DATASET_ID} not found. Did you create it?")
    except Exception as exc:
        print(f"❌ Sync Error: {exc}")
# ... (rest of the helpers; fetch_questions is external to the class logic) ...
def fetch_questions():
    """Fetches questions from the API."""
    url = f"{DEFAULT_API_URL}/questions"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
        return resp.json()
    except Exception as exc:
        # Network/HTTP/JSON failures all degrade to an empty task list.
        print(f"Error fetching questions: {exc}")
        return []
def is_valid_answer(answer: str) -> bool:
    """Checks if the answer meets minimum criteria (not chatter/error/too long)."""
    # Empty/None answers are never valid.
    if not answer:
        return False
    # Explicit error markers (case-sensitive, matching the agent's output).
    if "AGENT ERROR" in answer or "Error:" in answer:
        return False
    # Benchmark answers are short; anything longer is chatter.
    if len(answer) > 150:
        return False
    lowered = answer.lower()
    chatter_markers = ("i cannot", "i am sorry", "the answer is", "based on")
    if any(marker in lowered for marker in chatter_markers):
        # A deliberate "not known" reply is still accepted despite the marker.
        return "not known" in lowered
    return True
# =============================================================================
# 1. GENERATOR FUNCTION (With Grace Time & Retry)
# =============================================================================
def generate_answers(profile: gr.OAuthProfile | None, *args):
    """
    Runs the agent with strategic pauses and streams updates to the UI.
    Passes file_name to the agent for multimodal tasks.

    Args:
        profile: OAuth profile of the logged-in user, or None when logged out.
        *args: Unused extra inputs forwarded by the Gradio event handler.

    Yields:
        (status_message, results_dataframe) tuples consumed by the UI.
    """
    # ... (first part of setup and initialization omitted in the original) ...
    if not profile:
        yield "⚠️ Please login with Hugging Face first!", pd.DataFrame()
        return

    # Load state from persistence
    existing_answers = load_persistent_cache()
    questions = fetch_questions()
    if not questions:
        yield "❌ Failed to fetch questions.", pd.DataFrame()
        return

    # Initialize Agent
    try:
        agent = GAIAgent()
    except Exception as e:
        # Also covers the NameError raised when the agent.py import failed at startup.
        yield f"❌ Error initializing agent: {e}", pd.DataFrame()
        return

    results_log = []     # rows displayed in the UI table
    payload_cache = []   # answers accumulated for submission/persistence
    total = len(questions)
    yield f"🚀 Starting execution on {total} tasks (Grace Time: {REQUEST_DELAY}s)...", pd.DataFrame()

    for i, item in enumerate(questions):
        task_id = item.get("task_id")
        question_text = item.get("question")
        # --- ESSENTIAL ADDITION: retrieve file_name for multimodal tasks ---
        file_name = item.get("file_name")
        # ----------------------------------------------------

        # --- CACHE CHECK ---
        # Reuse a cached answer only if it still passes validation.
        if task_id in existing_answers and is_valid_answer(existing_answers[task_id]):
            answer = existing_answers[task_id]
            # --- UPDATE DISPLAY including the FILE column ---
            results_log.append({
                "Task ID": task_id,
                "Status": "⚡ Cached",
                "Question": question_text[:50] + "...",
                "File": file_name if file_name else "N/A",  # show the attached file name
                "Answer": answer
            })
            payload_cache.append({"task_id": task_id, "submitted_answer": answer})
            yield f"⚡ Task {i+1} Cached.", pd.DataFrame(results_log)
            continue

        # --- REAL GENERATION (Slow) ---
        final_answer = "NOT KNOWN"  # fallback submitted when every attempt fails
        success = False
        for attempt in range(MAX_RETRIES):
            try:
                status_msg = f"⏳ Task {i+1}/{total} (Try {attempt+1}/{MAX_RETRIES})..."
                yield status_msg, pd.DataFrame(results_log)
                # AGENT CALL: pass file_name through to the agent's __call__.
                answer = agent(question_text, file_name=file_name)
                # VALIDATION
                if is_valid_answer(answer):
                    final_answer = answer
                    success = True
                    # --- GRACE TIME AFTER SUCCESS ---
                    wait_time = REQUEST_DELAY
                    yield f"✅ Success. Cooling down for {wait_time}s...", pd.DataFrame(results_log)
                    time.sleep(wait_time)
                    break
                else:
                    print(f"⚠️ Invalid answer on attempt {attempt+1}: {answer}")
                    # --- BACKOFF AFTER INVALID FORMAT ---
                    yield f"⚠️ Invalid format. Retrying in {ERROR_DELAY}s...", pd.DataFrame(results_log)
                    time.sleep(ERROR_DELAY)
            except Exception as e:
                print(f"❌ Error on attempt {attempt+1}: {e}")
                # --- BACKOFF AFTER ERROR (e.g., Rate Limit) ---
                yield f"❌ API Error. Pausing {ERROR_DELAY}s...", pd.DataFrame(results_log)
                time.sleep(ERROR_DELAY)

        # Save Result
        status_label = "✅ Generated" if success else "⚠️ Failed"
        # --- UPDATE DISPLAY including the FILE column ---
        results_log.append({
            "Task ID": task_id,
            "Status": status_label,
            "Question": question_text[:50] + "...",
            "File": file_name if file_name else "N/A",  # added for debugging
            "Answer": final_answer
        })
        payload_cache.append({"task_id": task_id, "submitted_answer": final_answer})
        # Save progress to Git: commit the cache after every task so a crash loses nothing.
        commit_cache_to_repo(payload_cache)
        yield f"🏁 Logic Complete Task {i+1}", pd.DataFrame(results_log)

    yield "🎉 Generation Complete! Click 'Submit' to send results.", pd.DataFrame(results_log)
# =============================================================================
# 2. SUBMIT FUNCTION (Fast Process)
# =============================================================================
def submit_to_leaderboard(profile: gr.OAuthProfile | None, *args):
    """
    Reads the persistent cache and submits the final payload to the leaderboard.
    """
    if not profile:
        return "⚠️ Please login first!"

    # Load the latest committed data
    answers_dict = load_persistent_cache()
    if not answers_dict:
        return "⚠️ No answers found. Please generate answers first."

    # Build the submission payload from the cached {task_id: answer} mapping.
    space_id = os.getenv("SPACE_ID", "local-test")
    submission_data = {
        "username": profile.username,
        "agent_code": f"https://huggingface.co/spaces/{space_id}/tree/main",
        "answers": [
            {"task_id": task, "submitted_answer": ans}
            for task, ans in answers_dict.items()
        ],
    }

    try:
        resp = requests.post(f"{DEFAULT_API_URL}/submit", json=submission_data, timeout=60)
        resp.raise_for_status()
        res = resp.json()
        return (
            f"✅ SUBMISSION SUCCESSFUL!\n\n"
            f"User: {res.get('username')}\n"
            f"Score: {res.get('score')}% | Correct: {res.get('correct_count')}/{res.get('total_attempted')}"
        )
    except Exception as e:
        return f"❌ Submission Failed: {str(e)}"
# =============================================================================
# 3. USER INTERFACE (Gradio Blocks)
# =============================================================================
with gr.Blocks(title="GAIAgent Evaluation") as demo:
    gr.Markdown("# 🤖 GAIAgent Evaluation Runner")
    gr.Markdown(
        """
        This tool runs your agent on the benchmark tasks, designed to be robust against network timeouts.
        **Status:** Grace Time: 10s | Max Retries: 2 | **Persistence:** Upload on HF Datasets
        """
    )
    with gr.Row(): login_btn = gr.LoginButton()
    with gr.Row():
        btn_gen = gr.Button("1. Generate Answers (Safe Mode)", variant="primary")
        btn_sub = gr.Button("2. Submit Results", variant="secondary")
    # Shared status textbox; also receives the submission result message.
    status = gr.Textbox(label="Current Status", lines=10, interactive=False)
    # Column headers match the dict keys appended to results_log in generate_answers.
    results_table = gr.DataFrame(label="Generated Answers", headers=["Task ID", "Status", "Question", "File", "Answer"], wrap=True)

    # --- Event Handlers ---
    # NOTE(review): login_btn as an input appears to supply the OAuthProfile
    # (or None) to the handlers — confirm against the Gradio OAuth docs.
    btn_gen.click(fn=generate_answers, inputs=[login_btn], outputs=[status, results_table])
    btn_sub.click(fn=submit_to_leaderboard, inputs=[login_btn], outputs=[status])

if __name__ == "__main__":
    demo.launch()