# Author: bebechien
# Last commit: 3890a8c (verified) — "Upload folder using huggingface_hub"
# File size: 11.7 kB
import gradio as gr
import requests
import os
import pickle
import spaces
import torch
from bs4 import BeautifulSoup
from html_to_markdown import convert_to_markdown
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from transformers import pipeline, TextIteratorStreamer
from threading import Thread
from tqdm import tqdm
# --- 1. CONFIGURATION ---
# Centralized place for all settings and constants.

# Hugging Face & Model Configuration
# Hub token read from the environment (None when HF_TOKEN is not set).
HF_TOKEN = os.getenv('HF_TOKEN')
EMBEDDING_MODEL_ID = "google/embeddinggemma-300M"
LLM_MODEL_ID = "google/gemma-3-12B-it"

# Data Source Configuration
BASE_URL = "https://hollowknight.wiki"
# One record per game. "title" becomes the knowledge-base key and must match the
# game names offered in the UI. Each category has:
#   "entry": wiki category page path, relative to BASE_URL, whose linked pages are scraped
#   "cache": local pickle file holding the processed entries for that category
#   "label": category name stored in each entry's metadata
GAME_KNOWLEDGE_DATA = [
    {
        "title": "Hollow Knight",
        "category_list": [
            {
                "entry": "/w/Category:Bosses_(Hollow_Knight)",
                "cache": "hollow_knight_boss.pkl",
                "label": "Bosses",
            },
        ],
    },
    {
        "title": "Silksong",
        "category_list": [
            {
                "entry": "/w/Category:Areas_(Silksong)",
                "cache": "silksong_areas.pkl",
                "label": "Areas",
            },
            {
                "entry": "/w/Category:Bosses_(Silksong)",
                "cache": "silksong_bosses.pkl",
                "label": "Bosses",
            },
            {
                "entry": "/w/Category:Tools_and_Skills_(Silksong)",
                "cache": "silksong_tools_and_skills.pkl",
                "label": "Tools and Skills",
            },
            {
                "entry": "/w/Category:NPCs_(Silksong)",
                "cache": "silksong_npcs.pkl",
                "label": "NPCs",
            },
        ],
    },
]

# Gradio App Configuration
# Default minimum similarity score for a retrieved document to be used.
DEFAULT_SIMILARITY_THRESHOLD = 0.5
# Shown when no document is relevant. The knowledge base spans bosses, areas,
# tools/skills and NPCs across both games, so the hint must not point users at
# Hollow Knight bosses only (wrong whenever Silksong is selected).
DEFAULT_MESSAGE_NO_MATCH = (
    "I'm sorry, I can't find a relevant document to answer that question. "
    "Try asking about a specific boss, area, tool, skill, or NPC from the selected game."
)
# --- 2. HELPER FUNCTIONS ---
# Reusable functions for web scraping and data processing.
def _get_html(url: str, timeout: float = 30.0) -> str:
    """Fetch the HTML body of ``url``.

    Args:
        url: Absolute URL to request.
        timeout: Seconds passed to ``requests.get``. Without a timeout a stalled
            connection can block forever, hanging the whole scraping loop (and
            therefore app start-up).

    Returns:
        The response text, or ``""`` when the request fails (connection error,
        timeout, or a 4xx/5xx status). Failures are printed rather than raised,
        so one bad page does not abort a category scrape.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raises an HTTPError for bad responses (4xx or 5xx)
    except requests.exceptions.RequestException as e:
        print(f"Error fetching {url}: {e}")
        return ""
    else:
        return response.text
def _find_wiki_links(html_content: str) -> list[str]:
    """Collect the ``href`` of every link inside ``<div id="mw-pages">``.

    The caller passes a wiki category page, and this div holds its listing of
    linked pages. Only anchors with an ``href`` are kept, in document order. An
    empty list is returned when the div is missing (e.g. for ``""`` input).

    NOTE(review): any non-article links inside that div (e.g. pagination)
    would be returned too; confirm against larger categories.
    """
    listing = BeautifulSoup(html_content, 'html.parser').find('div', id='mw-pages')
    if not listing:
        return []
    anchors = listing.find_all('a', href=True)
    return [anchor['href'] for anchor in anchors]
def _get_markdown_from_html(html: str) -> str:
    """Convert an HTML document to Markdown; empty input yields ``""``."""
    if html:
        parsed = BeautifulSoup(html, 'html.parser')
        return convert_to_markdown(parsed)
    return ""
def _get_markdown_from_url(url: str) -> str:
    """Download ``url`` and return its content as Markdown (``""`` if the fetch failed)."""
    page_html = _get_html(url)
    return _get_markdown_from_html(page_html)
# --- 3. DATA PROCESSING & CACHING ---
# Scrapes data and generates embeddings, using a cache to avoid re-running.
def _clean_text(text: str) -> str:
"""Removes the references section from the raw text."""
return text.split("References\n----------\n", 1)[0].strip()
def _create_data_entry(text: str, doc_path: str, label: str, embedding_model) -> dict | None:
    """Build one knowledge-base record for a scraped wiki page.

    Args:
        text: Markdown of the page; its References section is removed first.
        doc_path: Wiki path relative to BASE_URL; its last path segment is used
            as the title.
        label: Category label stored in the metadata.
        embedding_model: Model whose ``encode`` produces the document embedding.

    Returns:
        ``{"text", "embedding", "metadata": {"category", "source", "title"}}``,
        or ``None`` when nothing remains after cleaning.
    """
    body = _clean_text(text)
    if not body:
        return None

    page_title = doc_path.rsplit('/', 1)[-1]
    metadata = {
        "category": label,
        "source": BASE_URL + doc_path,
        "title": page_title,
    }
    # The document prompt pairs the page title with its text.
    vector = embedding_model.encode(body, prompt=f"title: {page_title} | text: ")
    return {"text": body, "embedding": vector, "metadata": metadata}
def load_or_process_source(entry_point: str, cache_file: str, label: str, embedding_model):
"""
Loads processed data from a cache file if it exists. Otherwise, scrapes,
processes, generates embeddings, and saves to the cache.
"""
if os.path.exists(cache_file):
print(f"✅ Found cache for {label}. Loading data from '{cache_file}'...")
with open(cache_file, 'rb') as f:
return pickle.load(f)
print(f"ℹ️ No cache for {label}. Starting data scraping and processing...")
processed_data = []
main_page_html = _get_html(BASE_URL + entry_point)
data_entry = _create_data_entry(_get_markdown_from_html(main_page_html), entry_point, label, embedding_model)
if (data_entry):
processed_data.append(data_entry)
extracted_links = _find_wiki_links(main_page_html)
for doc_path in tqdm(extracted_links, desc=f"Processing {label} Pages"):
full_url = BASE_URL + doc_path
text = _get_markdown_from_url(full_url)
data_entry = _create_data_entry(text, doc_path, label, embedding_model)
if data_entry:
processed_data.append(data_entry)
print(f"✅ {label} processing complete. Saving {len(processed_data)} entries to '{cache_file}'...")
with open(cache_file, 'wb') as f:
pickle.dump(processed_data, f)
return processed_data
# --- 4. CORE AI LOGIC ---
# Functions for finding context and generating a response.
def find_best_context(model, query: str, contents: list[dict], similarity_threshold: float):
    """Return the text of the knowledge-base entry most similar to ``query``.

    Args:
        model: Embedding model exposing ``encode`` and ``similarity``.
        query: User question; an empty string yields ``None``.
        contents: Entries as built by ``_create_data_entry`` (each with
            ``"embedding"``, ``"text"`` and ``"metadata"["source"]``).
        similarity_threshold: Minimum score for the best entry to be used.

    Returns:
        The best entry's ``"text"`` if its score is at least
        ``similarity_threshold``, else ``None``. Also ``None`` for an empty
        ``contents``, because ``torch.stack`` raises on an empty list.
    """
    if not query or not contents:
        return None

    query_embedding = model.encode(query, prompt_name="query")
    contents_embeddings = torch.stack([torch.tensor(item["embedding"]) for item in contents])
    # One query, so the score matrix is read as row 0.
    similarities = model.similarity(query_embedding, contents_embeddings)
    best_index = similarities.argmax().item()
    best_score = similarities[0, best_index].item()
    print(f"Best similarity score: {best_score:.4f}")

    if best_score < similarity_threshold:
        return None
    best = contents[best_index]
    print(f"Using \"{best['metadata']['source']}\"...")
    return best["text"]
# Kept only for backward compatibility. Earlier versions cached the last
# retrieved context here, which is shared by every Gradio session and so leaked
# one user's context into other users' conversations. respond() no longer
# reads or writes it.
context = None


def _find_context_for_turn(message: str, history: list, contents: list[dict], similarity_threshold: float):
    """Retrieve context for this turn, falling back to this conversation's earlier questions.

    The current message is searched first. If it matches nothing (e.g. a
    follow-up like "how do I beat it?"), earlier user messages from ``history``
    are tried, most recent first. This keeps follow-ups on topic using only
    per-conversation data.
    """
    found = find_best_context(embedding_model, message, contents, similarity_threshold)
    if found:
        return found
    for turn in reversed(history):
        if not isinstance(turn, dict) or turn.get("role") != "user":
            continue
        previous = turn.get("content")
        if isinstance(previous, str):
            found = find_best_context(embedding_model, previous, contents, similarity_threshold)
            if found:
                return found
    return None


def _build_messages(message: str, history: list, context_text: str) -> list[dict]:
    """Assemble chat messages: a system prompt carrying the context, prior turns, then the question."""
    system_prompt = f"Answer the following QUESTION based only on the CONTEXT provided. If the answer cannot be found in the CONTEXT, write \"I don't know.\"\n---\nCONTEXT:\n{context_text}\n"
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(history)
    messages.append({"role": "user", "content": f"QUESTION:\n{message}"})
    return messages


@spaces.GPU
def respond(message: str, history: list, game: str, similarity_threshold: float):
    """Stream an LLM answer to ``message`` grounded in the selected game's wiki data.

    Args:
        message: The user's latest question.
        history: Prior turns as "messages"-format dicts (``role``/``content``),
            forwarded to the LLM after the system prompt.
        game: Game title used to select the knowledge base.
        similarity_threshold: Minimum similarity for a document to be used.

    Yields:
        The accumulated response text after each streamed chunk, or
        ``DEFAULT_MESSAGE_NO_MATCH`` once when no relevant context is found.
    """
    contents = _select_content(game)
    if not contents:
        yield DEFAULT_MESSAGE_NO_MATCH
        return

    turn_context = _find_context_for_turn(message, history, contents, similarity_threshold)
    if not turn_context:
        yield DEFAULT_MESSAGE_NO_MATCH
        return

    messages = _build_messages(message, history, turn_context)
    # Debug log of the conversation sent to the model (system prompt omitted).
    for item in messages[1:]:
        print(item['role'])
        print(item['content'])

    # Generation runs in a background thread; the streamer yields decoded text
    # chunks as they are produced.
    streamer = TextIteratorStreamer(llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=llm_pipeline,
        kwargs=dict(
            text_inputs=messages,
            streamer=streamer,
            max_new_tokens=512,
            do_sample=True,
            top_p=0.95,
        ),
    )
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response
    # The streamer is exhausted only after generation ends; reap the thread.
    thread.join()
# --- 5. INITIALIZATION ---
# Login, load models, and process data.
if HF_TOKEN:
    print("Logging into Hugging Face Hub...")
    login(token=HF_TOKEN)
else:
    # login() without a token falls back to an interactive prompt, which cannot
    # be answered in a headless Space. Skip it; model downloads then rely on
    # whatever credentials are already available locally.
    print("⚠️ HF_TOKEN is not set; skipping Hugging Face Hub login.")

print("Initializing embedding model...")
embedding_model = SentenceTransformer(EMBEDDING_MODEL_ID)

print("Initializing language model...")
# device_map/dtype "auto" let transformers choose placement and precision.
llm_pipeline = pipeline(
    "text-generation",
    model=LLM_MODEL_ID,
    device_map="auto",
    dtype="auto",
)
print("\n--- Processing Game Data ---")
# game title -> all entries from that game's categories, concatenated in the
# order the categories are listed in GAME_KNOWLEDGE_DATA.
knowledge_base = {}
for game_info in GAME_KNOWLEDGE_DATA:
    game_entries = []
    for cat in game_info['category_list']:
        game_entries.extend(
            load_or_process_source(cat['entry'], cat['cache'], cat['label'], embedding_model)
        )
    knowledge_base[game_info['title']] = game_entries
def _select_content(game: str) -> list[dict]:
    """Return the knowledge-base entries for ``game``.

    An unknown or missing game (e.g. ``None`` from an empty dropdown) yields an
    empty list instead of raising ``KeyError``. The caller treats an empty
    result as "no match".
    """
    return knowledge_base.get(game, [])
# --- 6. GRADIO UI ---
# Defines the web interface for the chatbot.

# Allow files under assets/ (e.g. the background image referenced in the CSS)
# to be served by the app.
gr.set_static_paths(paths=["assets/"])

# Theme: red primary accent, zinc neutrals, and a period serif font with
# system fallbacks.
silksong_theme = gr.themes.Default(
    font=[gr.themes.GoogleFont("IM Fell English"), "ui-sans-serif", "system-ui", "sans-serif"],
    neutral_hue=gr.themes.colors.zinc,
    secondary_hue=gr.themes.colors.zinc,
    primary_hue=gr.themes.colors.red,
)

# CSS: background image faded toward white (light) or black (dark), plus
# header and disclaimer styling.
silksong_css = """
.gradio-container {
background-image: linear-gradient(rgba(255,255,255, 0.5), rgba(255, 255, 255, 1.0)), url("/gradio_api/file=assets/background.jpg");
background-size: cover;
background-repeat: no-repeat;
background-position: center;
}
body.dark .gradio-container {
background-image: linear-gradient(rgba(0, 0, 0, 0.5), rgba(0, 0, 0, 1.0)), url("/gradio_api/file=assets/background.jpg");
}
.header-text { text-align: center; text-shadow: 2px 2px 5px #000; }
.header-text h1 { font-size: 2.5em; color: #dc2626; }
.dark .header-text { text-shadow: 2px 2px 5px #FFF; }
.disclaimer { text-align: center; color: var(--body-text-color-subdued); font-size: 0.9em; padding: 20px; }
.disclaimer ul { list-style: none; padding: 0; }
.disclaimer a { color: #dc2626; }
"""
with gr.Blocks(theme=silksong_theme, css=silksong_css) as demo:
    # Page header.
    gr.HTML("""
    <div class="header-text">
        <h1>A Weaver's Counsel</h1>
        <p>Speak, little traveler. What secrets of Pharloom do you seek?</p>
        <p style="font-style: italic;">(Note: This bot has limited knowledge.)</p>
    </div>
    """)

    # Chat UI. The extra inputs map, in order, to respond(..., game, similarity_threshold).
    gr.ChatInterface(
        respond,
        type="messages",
        chatbot=gr.Chatbot(type="messages", label=LLM_MODEL_ID),
        textbox=gr.Textbox(placeholder="Ask about the haunted kingdom...", container=False, submit_btn=True, scale=7),
        additional_inputs=[
            # Choices come from GAME_KNOWLEDGE_DATA so every option is a
            # knowledge-base key. The default is set explicitly rather than
            # relying on the component's default selection.
            gr.Dropdown(
                [game_info["title"] for game_info in GAME_KNOWLEDGE_DATA],
                value=GAME_KNOWLEDGE_DATA[0]["title"],
                label="Game",
            ),
            gr.Slider(minimum=0.1, maximum=1.0, value=DEFAULT_SIMILARITY_THRESHOLD, step=0.1, label="Similarity Threshold"),
        ],
        examples=[
            ["Where can I find the Moorwing?", "Silksong", DEFAULT_SIMILARITY_THRESHOLD],
            ["Who is the voice of Lace?", "Silksong", DEFAULT_SIMILARITY_THRESHOLD],
            ["How can I beat the False Knight?", "Hollow Knight", DEFAULT_SIMILARITY_THRESHOLD],
            ["Any achievement for Hornet Protector?", "Hollow Knight", DEFAULT_SIMILARITY_THRESHOLD],
        ],
    )

    # Attribution / disclaimer footer (the wiki content is CC BY-SA 3.0).
    gr.HTML("""
    <div class="disclaimer">
        <p><strong>Disclaimer:</strong></p>
        <ul style="list-style: none; padding: 0;">
            <li>This is a fan-made personal demonstration and not affiliated with any organization.<br>The bot is for entertainment purposes only.</li>
            <li>Factual information is sourced from the <a href="https://hollowknight.wiki" target="_blank" rel="noopener">Hollow Knight Wiki</a>.<br>Content is available under <a href="https://creativecommons.org/licenses/by-sa/3.0/" target="_blank" rel="noopener">Creative Commons Attribution-ShareAlike</a> unless otherwise noted.</li>
            <li>Built by <a href="https://huggingface.co/bebechien" target="_blank" rel="noopener">bebechien</a> with a 💖 for the world of Hollow Knight.</li>
        </ul>
    </div>
    """)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    demo.launch()