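"""Lightweight multi-head RAG retrieval pipeline.

Chunks the .txt corpus in app/data, encodes each chunk with three prompt
variants of a sentence-transformer plus a TF-IDF index, caches the results,
and scores queries against a weighted blend of all four signals.
"""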
import os
import pickle
import hashlib
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
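
# Configuration. CHUNK_SIZE and CHUNK_OVERLAP are measured in words (see chunk_text).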
CACHE_DIR = "app/cache"
DATA_DIR = "app/data"
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500
CHUNK_OVERLAP = 100


def compute_hash(files):
    """Return an MD5 digest over the raw bytes of all data files (used as a cache key)."""
    h = hashlib.md5()
    for f in sorted(files):
        with open(f, "rb") as fp:
            h.update(fp.read())
    return h.hexdigest()


def load_documents():
    """Read every .txt file in DATA_DIR and return (texts, file paths)."""
    files = [
        os.path.join(DATA_DIR, f)
        for f in os.listdir(DATA_DIR)
        if f.endswith(".txt")
    ]
    if not files:
        raise RuntimeError(f"No .txt files found in {DATA_DIR}")
    texts = []
    for f in files:
        with open(f, encoding="utf-8", errors="ignore") as fp:
            texts.append(fp.read())
    return texts, files


def chunk_text(text, size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split text into overlapping word windows of `size` words, stepping by size - overlap."""
    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunk = words[i:i + size]
        chunks.append(" ".join(chunk))
        i += size - overlap
    return chunks


def chunk_documents(texts):
    chunks = []
    for t in texts:
        chunks.extend(chunk_text(t))
    return chunks


def build_embeddings(chunks):
    """Encode chunks into three normalized embedding heads plus a sparse TF-IDF index."""
    model = SentenceTransformer(MODEL_NAME)
    # Plain semantic embeddings of the raw chunks.
    semantic = normalize(
        model.encode(chunks, batch_size=32, show_progress_bar=True)
    )
    # The same chunks re-encoded behind a narrative-oriented prefix.
    narrative = normalize(
        model.encode(
            ["Story context: " + c for c in chunks],
            batch_size=32,
            show_progress_bar=True
        )
    )
    # And behind an entity-oriented prefix.
    entity = normalize(
        model.encode(
            ["Entities mentioned: " + c for c in chunks],
            batch_size=32,
            show_progress_bar=True
        )
    )
    # Keyword index over unigrams and bigrams.
    tfidf = TfidfVectorizer(
        ngram_range=(1, 2),
        stop_words="english"
    )
    tfidf_matrix = tfidf.fit_transform(chunks)
    return {
        "semantic": semantic,
        "narrative": narrative,
        "entity": entity,
        "tfidf": tfidf,
        "tfidf_matrix": tfidf_matrix,
        "model": model
    }


def save_cache(chunks, heads, dataset_hash):
    """Persist chunks, embedding heads, the TF-IDF index, and the dataset hash."""
    os.makedirs(CACHE_DIR, exist_ok=True)
    np.save(f"{CACHE_DIR}/semantic.npy", heads["semantic"])
    np.save(f"{CACHE_DIR}/narrative.npy", heads["narrative"])
    np.save(f"{CACHE_DIR}/entity.npy", heads["entity"])
    with open(f"{CACHE_DIR}/chunks.pkl", "wb") as f:
        pickle.dump(chunks, f)
    with open(f"{CACHE_DIR}/tfidf.pkl", "wb") as f:
        pickle.dump(heads["tfidf"], f)
    with open(f"{CACHE_DIR}/tfidf_matrix.pkl", "wb") as f:
        pickle.dump(heads["tfidf_matrix"], f)
    with open(f"{CACHE_DIR}/hash.txt", "w") as f:
        f.write(dataset_hash)


def load_cache():
    """Load chunks, embedding heads, and the TF-IDF index from CACHE_DIR."""
    with open(f"{CACHE_DIR}/chunks.pkl", "rb") as f:
        chunks = pickle.load(f)
    heads = {
        "semantic": np.load(f"{CACHE_DIR}/semantic.npy"),
        "narrative": np.load(f"{CACHE_DIR}/narrative.npy"),
        "entity": np.load(f"{CACHE_DIR}/entity.npy"),
    }
    with open(f"{CACHE_DIR}/tfidf.pkl", "rb") as f:
        heads["tfidf"] = pickle.load(f)
    with open(f"{CACHE_DIR}/tfidf_matrix.pkl", "rb") as f:
        heads["tfidf_matrix"] = pickle.load(f)
    # The encoder itself is not cached; load it once here.
    heads["model"] = SentenceTransformer(MODEL_NAME)
    return chunks, heads


def load_data():
    """Chunk the corpus, then reuse cached embeddings if the dataset hash matches, else rebuild."""
    texts, files = load_documents()
    chunks = chunk_documents(texts)
    dataset_hash = compute_hash(files)
    hash_path = f"{CACHE_DIR}/hash.txt"
    cached_hash = None
    if os.path.exists(hash_path):
        with open(hash_path) as f:
            cached_hash = f.read().strip()
    if cached_hash == dataset_hash:
        print("Loading embeddings from cache")
        return load_cache()
    print("Building embeddings")
    heads = build_embeddings(chunks)
    save_cache(chunks, heads, dataset_hash)
    return chunks, heads


def retrieve_chunks(query, chunks, heads, k=5):
    """Score every chunk against the query with all heads and return the top-k chunks."""
    model = heads["model"]
    q_sem = normalize(model.encode([query]))
    q_nav = normalize(model.encode(["Story question: " + query]))
    q_ent = normalize(model.encode(["Entities in question: " + query]))
    # Embeddings are L2-normalized, so the dot product is cosine similarity per head.
    sem_score = heads["semantic"] @ q_sem.T
    nav_score = heads["narrative"] @ q_nav.T
    ent_score = heads["entity"] @ q_ent.T
    q_tfidf = heads["tfidf"].transform([query])
    key_score = heads["tfidf_matrix"] @ q_tfidf.T
    # Weighted blend of the three dense heads and the sparse keyword score.
    final_score = (
        0.40 * sem_score +
        0.30 * nav_score +
        0.15 * ent_score +
        0.15 * key_score.toarray()
    )
    top_idx = np.argsort(final_score.flatten())[::-1][:k]
    return [chunks[i] for i in top_idx]
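

# Minimal usage sketch (not part of the module above): assumes app/data already
# contains .txt files. The query string is purely illustrative.
if __name__ == "__main__":
    chunks, heads = load_data()
    for passage in retrieve_chunks("What happens in the opening chapter?", chunks, heads, k=3):
        print(passage[:200])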