import os
import pickle
import hashlib

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize

CACHE_DIR = "app/cache"
DATA_DIR = "app/data"
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500       # chunk length in words
CHUNK_OVERLAP = 100    # words shared between consecutive chunks

def compute_hash(files):
    # Hash the raw bytes of every data file so the embedding cache can be
    # invalidated whenever the dataset changes.
    h = hashlib.md5()
    for f in sorted(files):
        with open(f, "rb") as fp:
            h.update(fp.read())
    return h.hexdigest()


def load_documents():
    files = [
        os.path.join(DATA_DIR, f)
        for f in os.listdir(DATA_DIR)
        if f.endswith(".txt")
    ]
    if not files:
        raise RuntimeError("No .txt files found in app/data")
    texts = []
    for f in files:
        with open(f, encoding="utf-8", errors="ignore") as fp:
            texts.append(fp.read())
    return texts, files


def chunk_text(text, size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    # Sliding window over whitespace-split words: each chunk holds `size`
    # words and overlaps the previous chunk by `overlap` words.
    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunk = words[i:i + size]
        chunks.append(" ".join(chunk))
        i += size - overlap
    return chunks


def chunk_documents(texts):
    chunks = []
    for t in texts:
        chunks.extend(chunk_text(t))
    return chunks

def build_embeddings(chunks):
    model = SentenceTransformer(MODEL_NAME)
    # Three dense "heads" from the same encoder: a plain semantic embedding
    # plus two prompt-prefixed variants biased toward narrative context and
    # entity mentions. All are L2-normalized, so a dot product against a
    # normalized query vector is a cosine similarity.
    semantic = normalize(
        model.encode(chunks, batch_size=32, show_progress_bar=True)
    )
    narrative = normalize(
        model.encode(
            ["Story context: " + c for c in chunks],
            batch_size=32,
            show_progress_bar=True
        )
    )
    entity = normalize(
        model.encode(
            ["Entities mentioned: " + c for c in chunks],
            batch_size=32,
            show_progress_bar=True
        )
    )
    # Sparse lexical head: unigram/bigram TF-IDF for keyword matching.
    tfidf = TfidfVectorizer(
        ngram_range=(1, 2),
        stop_words="english"
    )
    tfidf_matrix = tfidf.fit_transform(chunks)
    return {
        "semantic": semantic,
        "narrative": narrative,
        "entity": entity,
        "tfidf": tfidf,
        "tfidf_matrix": tfidf_matrix,
        "model": model
    }

def save_cache(chunks, heads, dataset_hash):
    os.makedirs(CACHE_DIR, exist_ok=True)
    np.save(f"{CACHE_DIR}/semantic.npy", heads["semantic"])
    np.save(f"{CACHE_DIR}/narrative.npy", heads["narrative"])
    np.save(f"{CACHE_DIR}/entity.npy", heads["entity"])
    with open(f"{CACHE_DIR}/chunks.pkl", "wb") as f:
        pickle.dump(chunks, f)
    with open(f"{CACHE_DIR}/tfidf.pkl", "wb") as f:
        pickle.dump(heads["tfidf"], f)
    with open(f"{CACHE_DIR}/tfidf_matrix.pkl", "wb") as f:
        pickle.dump(heads["tfidf_matrix"], f)
    with open(f"{CACHE_DIR}/hash.txt", "w") as f:
        f.write(dataset_hash)


def load_cache():
    with open(f"{CACHE_DIR}/chunks.pkl", "rb") as f:
        chunks = pickle.load(f)
    heads = {
        "semantic": np.load(f"{CACHE_DIR}/semantic.npy"),
        "narrative": np.load(f"{CACHE_DIR}/narrative.npy"),
        "entity": np.load(f"{CACHE_DIR}/entity.npy"),
    }
    with open(f"{CACHE_DIR}/tfidf.pkl", "rb") as f:
        heads["tfidf"] = pickle.load(f)
    with open(f"{CACHE_DIR}/tfidf_matrix.pkl", "rb") as f:
        heads["tfidf_matrix"] = pickle.load(f)
    # model is loaded once here
    heads["model"] = SentenceTransformer(MODEL_NAME)
    return chunks, heads

def load_data():
    texts, files = load_documents()
    chunks = chunk_documents(texts)
    dataset_hash = compute_hash(files)
    hash_path = f"{CACHE_DIR}/hash.txt"
    cached_hash = None
    if os.path.exists(hash_path):
        with open(hash_path) as f:
            cached_hash = f.read().strip()
    # Reuse cached embeddings only when the stored hash matches the current data.
    if cached_hash == dataset_hash:
        print("Loading embeddings from cache")
        return load_cache()
    print("Building embeddings")
    heads = build_embeddings(chunks)
    save_cache(chunks, heads, dataset_hash)
    return chunks, heads

def retrieve_chunks(query, chunks, heads, k=5):
    model = heads["model"]
    # Encode the query once per dense head, mirroring the prefixes used at index time.
    q_sem = normalize(model.encode([query]))
    q_nav = normalize(model.encode(["Story question: " + query]))
    q_ent = normalize(model.encode(["Entities in question: " + query]))
    # Cosine similarity against each dense head (shape: n_chunks x 1).
    sem_score = heads["semantic"] @ q_sem.T
    nav_score = heads["narrative"] @ q_nav.T
    ent_score = heads["entity"] @ q_ent.T
    # Sparse keyword similarity from the TF-IDF head.
    q_tfidf = heads["tfidf"].transform([query])
    key_score = heads["tfidf_matrix"] @ q_tfidf.T
    # Weighted late fusion of the four scores.
    final_score = (
        0.40 * sem_score +
        0.30 * nav_score +
        0.15 * ent_score +
        0.15 * key_score.toarray()
    )
    top_idx = np.argsort(final_score.flatten())[::-1][:k]
    return [chunks[i] for i in top_idx]
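

# --- Usage sketch (illustrative only, not the Space's entry point) ---
# A minimal way to exercise load_data() and retrieve_chunks() locally;
# the sample query string below is made up for demonstration.
if __name__ == "__main__":
    chunks, heads = load_data()
    results = retrieve_chunks("Who helped the protagonist escape?", chunks, heads, k=3)
    for rank, chunk in enumerate(results, start=1):
        print(f"[{rank}] {chunk[:200]}...")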