import os
import io
import glob
import numpy as np
import pandas as pd
import requests
import gradio as gr
from bs4 import BeautifulSoup
from PyPDF2 import PdfReader
from docx import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
import faiss
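
# All Hugging Face caches are routed to one writable directory; on hosted
# platforms such as Spaces the default cache location may not be writable.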
HF_HOME = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.makedirs(HF_HOME, exist_ok=True)

os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = HF_HOME
os.environ["SENTENCE_TRANSFORMERS_HOME"] = HF_HOME
os.environ["HF_DATASETS_CACHE"] = HF_HOME
os.environ["XDG_CACHE_HOME"] = HF_HOME
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Clear stale hub lock files left behind by interrupted downloads.
locks_dir = os.path.join(HF_HOME, "hub", ".locks")
if os.path.isdir(locks_dir):
    for p in glob.glob(os.path.join(locks_dir, "*.lock")):
        try:
            os.remove(p)
        except OSError:
            pass
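
# Two models: all-MiniLM-L6-v2 embeds text into 384-dim vectors for FAISS
# retrieval, and the Mistral-7B variant below generates the answers.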
MODEL_ID = "MehdiHosseiniMoghadam/AVA-Mistral-7B-V2"

embedder = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=HF_HOME)
config = AutoConfig.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)
# device_map is a load-time option, so it goes on from_pretrained rather than
# on the pipeline, which here only wraps the already-placed model.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True, device_map="auto"
)
# Generation length is set per call in answer_question.
llm = pipeline("text-generation", model=model, tokenizer=tokenizer,
               do_sample=True, temperature=0.2)

def load_file_text(file):
    # Gradio's File component yields a tempfile-like object; read the bytes
    # from its path rather than relying on the handle's read position.
    path = getattr(file, "name", file)
    name = str(path).lower()
    with open(path, "rb") as fh:
        data = fh.read()
    if name.endswith(".pdf"):
        reader = PdfReader(io.BytesIO(data))
        return "".join(page.extract_text() or "" for page in reader.pages)
    elif name.endswith(".docx"):
        doc = Document(io.BytesIO(data))
        return " ".join(p.text for p in doc.paragraphs)
    elif name.endswith(".csv"):
        for enc in ("utf-8", "latin-1"):
            try:
                df = pd.read_csv(io.BytesIO(data), encoding=enc)
                return " ".join(df.astype(str).values.flatten().tolist())
            except Exception:
                continue
        return ""
    elif name.endswith(".txt"):
        # decode(..., errors="ignore") never raises, so no retry loop is needed.
        return data.decode("utf-8", errors="ignore")
    else:
        return ""

def fetch_web_text(url):
    try:
        headers = {"User-Agent": "Mozilla/5.0"}
        resp = requests.get(url, headers=headers, timeout=10)
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        return " ".join(soup.get_text(separator=" ").split())
    except Exception:
        return ""
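
# Indexing: each document is split into overlapping windows (1000 characters
# with 120 overlap) so sentences straddling a boundary stay retrievable.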
def chunk_docs(docs, chunk_size=1000, chunk_overlap=120):
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = []
    for doc in docs:
        for idx, chunk in enumerate(splitter.split_text(doc["text"])):
            chunks.append({"source": doc["source"],
                           "chunk_id": f"{doc['source']}_chunk{idx}",
                           "content": chunk})
    return chunks

def build_index_and_chunks(docs):
    chunks = chunk_docs(docs)
    texts = [chunk["content"] for chunk in chunks]
    if not texts:
        return None, []
    embeddings = embedder.encode(texts, show_progress_bar=True, convert_to_numpy=True)
    embeddings = np.asarray(embeddings).astype("float32")
    dim = embeddings.shape[1]
    # Exact L2 search; normalizing the embeddings and using IndexFlatIP would
    # give cosine similarity, which sentence-transformer models are tuned for.
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    return index, chunks

def retrieve(query, index, chunks, top_k=3):
    if index is None or len(chunks) == 0:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True)
    q_emb = np.asarray(q_emb).astype("float32")
    distances, indices = index.search(q_emb, top_k)
    results = []
    for dist, idx in zip(distances[0], indices[0]):
        if 0 <= idx < len(chunks):
            results.append({"chunk": chunks[idx], "score": float(dist)})
    return results
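
# Generation: the retrieved chunks are inlined into the prompt, and the model
# is instructed to answer only from that context and cite chunk ids.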
def answer_question(query, index, chunks):
    results = retrieve(query, index, chunks)
    context_chunks = [r["chunk"] for r in results]
    context_text = "\n".join(f"[{c['chunk_id']}] {c['content']}" for c in context_chunks)
    prompt = (
        "Answer the following question using ONLY the provided context and cite the chunk ids used.\n"
        f"Question: {query}\nContext:\n{context_text}\nAnswer with citations:"
    )
    # max_new_tokens bounds the answer regardless of prompt length, and
    # return_full_text=False keeps the prompt out of the chat reply.
    generated = llm(prompt, max_new_tokens=256, num_return_sequences=1, return_full_text=False)
    answer = generated[0]["generated_text"].strip()
    sources = "\n".join(f"[{c['chunk_id']} from {c['source']}]" for c in context_chunks)
    return answer, sources
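
# Module-level state shared by both callbacks; note that concurrent sessions
# in a deployed app would share (and overwrite) this single index.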
state = {"index": None, "chunks": []}

def process(files, urls):
    docs = []
    if files:
        for f in files:
            text = load_file_text(f)
            if text:
                # Newer Gradio versions may hand back plain file paths, so
                # fall back to the string itself when there is no .name.
                docs.append({"source": getattr(f, "name", str(f)), "text": text})
    if urls:
        for url in urls.strip().splitlines():
            text = fetch_web_text(url.strip())
            if text:
                docs.append({"source": url.strip(), "text": text})
    if len(docs) == 0:
        return "No documents or URLs loaded."
    index, chunks = build_index_and_chunks(docs)
    state["index"], state["chunks"] = index, chunks
    return f"Loaded {len(docs)} docs, created {len(chunks)} chunks."

def chat_response(user_message, history):
    if state["index"] is None or len(state["chunks"]) == 0:
        bot_message = "Please upload documents or enter URLs, then press 'Load & Process' first."
    else:
        answer, sources = answer_question(user_message, state["index"], state["chunks"])
        bot_message = answer + "\n\nSources:\n" + sources
    history = history or []
    # The Chatbot component already renders the two roles, so the raw
    # messages go in without "User:"/"Assistant:" prefixes.
    history.append((user_message, bot_message))
    return "", history
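
# UI layout: ingestion controls (files, URLs, status) on the left, the chat
# panel on the right.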
with gr.Blocks() as demo:
    gr.Markdown("# 📚 RAG Chatbot with Mistral-7B and FAISS")

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload Files (PDF, DOCX, TXT, CSV)",
                                 file_types=[".pdf", ".docx", ".txt", ".csv"],
                                 file_count="multiple")
            url_input = gr.Textbox(label="Enter URLs (one per line)", lines=4)
            process_button = gr.Button("Load & Process Documents and URLs")
            output_log = gr.Textbox(label="Status")

        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            user_input = gr.Textbox(placeholder="Ask a question about the loaded documents...", show_label=False)
            submit_btn = gr.Button("Send")

    process_button.click(process, inputs=[file_input, url_input], outputs=output_log)
    submit_btn.click(chat_response, inputs=[user_input, chatbot], outputs=[user_input, chatbot])

demo.launch()