import logging
import os
import subprocess
import sys
import tempfile
import warnings
from typing import Tuple, List, Optional

import gradio as gr
import pandas as pd

warnings.filterwarnings('ignore')

# Module-level state shared by the setup and chat callbacks below.
qa_system = None            # retrieval QA chain; assigned by setup_rag_system()
retriever = None            # retriever over the vector index; assigned by setup_rag_system()
setup_complete = False      # set to True once setup_rag_system() has finished
current_source = "default"  # "default" or "uploaded": where the indexed text came from

def install_requirements(packages: Optional[List[str]] = None) -> List[str]:
    """Install the packages this app needs with pip, on a best-effort basis.

    Every package gets its own ``pip install -q`` call through the running
    interpreter, so one failing package does not prevent the others from
    being attempted.

    Args:
        packages: Package names to install. When ``None`` the app's built-in
            requirement list is used.

    Returns:
        The names of the packages whose installation failed (an empty list
        when everything succeeded).
    """
    if packages is None:
        packages = [
            "gradio",
            "langchain",
            "torch",
            "transformers",
            "sentence-transformers",
            "datasets",
            "faiss-cpu",
            "langchain-community",
            "pandas",
            "PyPDF2",
            "pdfplumber",
            "python-docx",
            "openpyxl",
        ]

    failed = []
    for package in packages:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package])
        except (subprocess.CalledProcessError, OSError) as exc:
            # Stay best-effort, but leave a trace instead of silently swallowing the failure.
            logging.getLogger(__name__).warning("Could not install %s: %s", package, exc)
            failed.append(package)
    return failed

def extract_text_from_pdf(pdf_path: str) -> str:
    """Extract the text of a PDF, trying PyPDF2 first and pdfplumber second.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        The text of all pages, each page followed by a newline. In the
        pdfplumber fallback, pages without text are skipped.

    Raises:
        Exception: If the pdfplumber fallback also fails. The message starts
            with ``"Failed to extract PDF text:"`` and the fallback error is
            attached as the cause.
    """
    try:
        import PyPDF2

        with open(pdf_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            # A page may yield no text; treat that as an empty page instead of
            # failing on ``None + str``.
            return "".join(f"{page.extract_text() or ''}\n" for page in pdf_reader.pages)
    except Exception:
        # Any PyPDF2 problem (missing package, unreadable file, parse error)
        # falls through to the second extractor.
        try:
            import pdfplumber

            with pdfplumber.open(pdf_path) as pdf:
                page_texts = [page.extract_text() for page in pdf.pages]
            return "".join(f"{page_text}\n" for page_text in page_texts if page_text)
        except Exception as fallback_error:
            raise Exception(f"Failed to extract PDF text: {fallback_error}") from fallback_error

def process_uploaded_file(file_path: str) -> Tuple[str, List]:
    """Read an uploaded PDF, TXT or DOCX file and split its text into chunks.

    Args:
        file_path: Path of the uploaded file. The type is chosen from the
            file extension, case-insensitively.

    Returns:
        A ``(status_message, chunks)`` tuple. ``chunks`` is the list produced
        by ``RecursiveCharacterTextSplitter.split_documents`` (1000-character
        chunks with 200 characters of overlap). It is empty when no path was
        given, the type is unsupported, the file holds no text, or any error
        occurred; in the error case the message describes the failure.
    """
    if not file_path:
        return "No file uploaded", []

    try:
        file_extension = os.path.splitext(file_path)[1].lower()

        if file_extension == '.pdf':
            text = extract_text_from_pdf(file_path)
        elif file_extension == '.txt':
            # errors='replace' keeps a file with a few undecodable bytes usable
            # instead of rejecting the whole upload.
            with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
                text = f.read()
        elif file_extension == '.docx':
            # Aliased so it cannot be confused with the LangChain Document below.
            from docx import Document as DocxDocument

            word_doc = DocxDocument(file_path)
            text = "\n".join(paragraph.text for paragraph in word_doc.paragraphs)
        else:
            return f"Unsupported file type: {file_extension}", []

        if not text.strip():
            return "No text content found in the file", []

        from langchain_core.documents import Document
        from langchain.text_splitter import RecursiveCharacterTextSplitter

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

        # Wrap the whole text in one document so every chunk carries the source path.
        full_document = Document(page_content=text, metadata={"source": file_path})
        chunks = text_splitter.split_documents([full_document])

        return f"Successfully processed {len(chunks)} chunks from uploaded file", chunks

    except Exception as e:
        return f"Error processing file: {str(e)}", []

| | def setup_rag_system(use_default_data: bool = True, uploaded_file=None) -> Tuple[str, bool]: |
| | """Setup the RAG system and return status message""" |
| | global qa_system, retriever, setup_complete, current_source |
| | |
| | try: |
| | |
| | from langchain.text_splitter import RecursiveCharacterTextSplitter |
| | from langchain.embeddings import HuggingFaceEmbeddings |
| | from langchain.vectorstores import FAISS |
| | from transformers import pipeline |
| | from langchain import HuggingFacePipeline |
| | from langchain.chains import RetrievalQA |
| | from langchain_core.documents import Document |
| | |
| | status = "π Setting up RAG System...\n" |
| | |
| | |
| | if uploaded_file and not use_default_data: |
| | status += "π Processing uploaded file...\n" |
| | file_status, docs = process_uploaded_file(uploaded_file) |
| | status += f"{file_status}\n" |
| | current_source = "uploaded" |
| | |
| | if not docs: |
| | status += "β No valid documents found. Using default data instead.\n" |
| | use_default_data = True |
| | |
| | if use_default_data or not docs: |
| | status += "π Loading default dataset...\n" |
| | current_source = "default" |
| | try: |
| | |
| | dataset_url = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.parquet" |
| | df = pd.read_parquet(dataset_url) |
| | status += f"β
Dataset loaded! {len(df)} entries found.\n" |
| | |
| | |
| | docs = [] |
| | for i, row in df.iterrows(): |
| | if i < 200: |
| | content = f"Context: {row.get('context', '')}\nQuestion: {row.get('instruction', '')}\nResponse: {row.get('response', '')}" |
| | docs.append(Document(page_content=content, metadata=row.to_dict())) |
| | |
| | except Exception as e: |
| | status += f"β οΈ Dataset download failed: {e}\n" |
| | status += "π Using sample data...\n" |
| | |
| | sample_data = [ |
| | {'instruction': 'What is machine learning?', 'context': '', 'response': 'Machine learning is a method of data analysis that automates analytical model building using algorithms that iteratively learn from data.'}, |
| | {'instruction': 'What is artificial intelligence?', 'context': '', 'response': 'Artificial intelligence (AI) is the simulation of human intelligence processes by machines, especially computer systems.'}, |
| | {'instruction': 'What is deep learning?', 'context': '', 'response': 'Deep learning is a subset of machine learning that uses neural networks with multiple layers to analyze data.'}, |
| | {'instruction': 'What is data science?', 'context': '', 'response': 'Data science is an interdisciplinary field that uses scientific methods, processes, algorithms and systems to extract knowledge and insights from data.'}, |
| | {'instruction': 'What is Python programming?', 'context': '', 'response': 'Python is a high-level, interpreted programming language known for its simplicity and readability, widely used in data science and web development.'}, |
| | {'instruction': 'What is natural language processing?', 'context': '', 'response': 'Natural Language Processing (NLP) is a branch of artificial intelligence that helps computers understand and interpret human language.'}, |
| | ] |
| | |
| | docs = [] |
| | for item in sample_data: |
| | content = f"Context: {item['context']}\nQuestion: {item['instruction']}\nResponse: {item['response']}" |
| | docs.append(Document(page_content=content, metadata=item)) |
| | |
| | |
| | if current_source == "default": |
| | status += "βοΈ Splitting documents...\n" |
| | text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150) |
| | docs = text_splitter.split_documents(docs) |
| | |
| | status += f"β
Prepared {len(docs)} document chunks!\n" |
| | |
| | |
| | status += "π€ Loading embedding model...\n" |
| | try: |
| | embeddings = HuggingFaceEmbeddings( |
| | model_name="sentence-transformers/all-MiniLM-L6-v2", |
| | model_kwargs={'device': 'cpu'}, |
| | encode_kwargs={'normalize_embeddings': False} |
| | ) |
| | status += "β
Embedding model loaded (all-MiniLM-L6-v2)!\n" |
| | except Exception as e: |
| | |
| | embeddings = HuggingFaceEmbeddings( |
| | model_name="sentence-transformers/paraphrase-MiniLM-L6-v2", |
| | model_kwargs={'device': 'cpu'}, |
| | encode_kwargs={'normalize_embeddings': False} |
| | ) |
| | status += "β
Embedding model loaded (paraphrase-MiniLM-L6-v2)!\n" |
| | |
| | |
| | status += "ποΈ Creating vector database...\n" |
| | db = FAISS.from_documents(docs, embeddings) |
| | status += "β
Vector database created!\n" |
| | |
| | |
| | status += "π€ Loading QA model...\n" |
| | try: |
| | |
| | question_answerer = pipeline( |
| | "question-answering", |
| | model="distilbert-base-cased-distilled-squad", |
| | tokenizer="distilbert-base-cased-distilled-squad", |
| | return_tensors='pt', |
| | max_answer_len=512, |
| | max_seq_len=512 |
| | ) |
| | status += "β
QA model loaded (DistilBERT)!\n" |
| | except Exception as e: |
| | try: |
| | |
| | question_answerer = pipeline( |
| | "question-answering", |
| | model="bert-large-uncased-whole-word-masking-finetuned-squad", |
| | return_tensors='pt', |
| | max_answer_len=512 |
| | ) |
| | status += "β
QA model loaded (BERT-large)!\n" |
| | except Exception as e2: |
| | |
| | question_answerer = pipeline( |
| | "question-answering", |
| | model="deepset/roberta-base-squad2", |
| | return_tensors='pt' |
| | ) |
| | status += "β
QA model loaded (RoBERTa)!\n" |
| | |
| | llm = HuggingFacePipeline( |
| | pipeline=question_answerer, |
| | model_kwargs={"temperature": 0.1, "max_length": 512} |
| | ) |
| | |
| | |
| | status += "βοΈ Setting up retrieval system...\n" |
| | retriever = db.as_retriever(search_kwargs={"k": 4}) |
| | qa_system = RetrievalQA.from_chain_type( |
| | llm=llm, |
| | chain_type="stuff", |
| | retriever=retriever, |
| | return_source_documents=True |
| | ) |
| | status += "β
Retrieval system ready!\n" |
| | |
| | setup_complete = True |
| | source_info = "uploaded PDF" if current_source == "uploaded" else "default dataset" |
| | status += f"\nπ RAG System Setup Complete!\n" |
| | status += f"π Data source: {source_info}\n" |
| | status += f"π Vector database: {len(docs)} chunks\n" |
| | status += "You can now ask questions in the chat interface." |
| | |
| | return status, True |
| | |
| | except Exception as e: |
| | error_msg = f"β Error setting up RAG system: {str(e)}\n" |
| | error_msg += "Please check your internet connection and try again." |
| | return error_msg, False |
| |
|
def ask_question(question: str, history: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[str, str]]]:
    """Answer a chat question with the RAG system and append the turn to the history.

    Args:
        question: Text typed by the user.
        history: Previous ``(question, response)`` pairs; ``None`` is treated
            as an empty history. The given list is not modified.

    Returns:
        ``("", updated_history)``: the empty string clears the input box and
        the history gains exactly one pair holding the answer with up to
        three source previews, a hint, or an error message.
    """
    # Work on a copy so the caller's list is never mutated.
    history = list(history or [])

    if not setup_complete or qa_system is None:
        response = "❌ RAG system is not set up yet. Please click 'Setup RAG System' first."
        history.append((question, response))
        return "", history

    if not question or not question.strip():
        response = "Please ask a question!"
        history.append((question, response))
        return "", history

    try:
        # Prefer .invoke() when the chain offers it; fall back to calling the chain directly.
        run_chain = getattr(qa_system, "invoke", None) or qa_system
        result = run_chain({"query": question})

        answer = result.get('result', 'No answer found.')
        sources = result.get('source_documents', [])

        source_type = "uploaded document" if current_source == "uploaded" else "dataset"
        response = f"🤖 **Answer:** {answer}\n\n"

        if sources:
            response += f"📚 **Sources from {source_type}:**\n"
            for i, doc in enumerate(sources[:3], 1):
                content = doc.page_content
                # Long passages are cut to a 300-character preview.
                preview = content[:300] + "..." if len(content) > 300 else content
                response += f"{i}. {preview}\n\n"

        history.append((question, response))
        return "", history

    except Exception as e:
        error_response = f"❌ Error processing question: {str(e)}"
        history.append((question, error_response))
        return "", history

def clear_chat():
    """Return an empty history, which resets the chat display it is wired to."""
    return []

def get_sample_questions():
    """Return the example questions shown in the interface.

    Returns:
        A new list of six question strings on every call.
    """
    return [
        "What is machine learning?",
        "What is artificial intelligence?",
        "What is deep learning?",
        "What is data science?",
        "What is Python programming?",
        "What is natural language processing?",
    ]

def create_interface():
    """Build the Gradio Blocks UI and wire it to the callbacks in this module.

    The left column holds the file upload, the default-dataset checkbox, the
    setup button with its status box, tips and the sample questions. The
    right column holds the chat display, the question box and the
    send / clear buttons. A row of quick-start buttons sits below both.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is expected to launch it.
    """
    with gr.Blocks(title="RAG System with PDF Upload", theme=gr.themes.Soft()) as demo:
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>🤖 RAG System with PDF Upload</h1>
            <p>Upload your PDF documents or use the default dataset to ask questions</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("<h3>📁 Upload Your Document</h3>")
                file_upload = gr.File(
                    label="Upload PDF, TXT, or DOCX file",
                    file_types=[".pdf", ".txt", ".docx"],
                    type="filepath",
                )

                with gr.Row():
                    use_default = gr.Checkbox(
                        label="Use default dataset instead",
                        value=True,
                        info="Check this to use the default Databricks dataset",
                    )

                setup_btn = gr.Button("🚀 Setup RAG System", variant="primary", size="lg")

                setup_status = gr.Textbox(
                    label="Setup Status",
                    lines=12,
                    max_lines=20,
                    value="1. Upload a document OR check 'Use default dataset'\n2. Click 'Setup RAG System' to initialize\n3. Start asking questions!",
                    interactive=False,
                )

                gr.HTML("<hr>")

                gr.HTML("<h3>💡 Tips:</h3>")
                gr.HTML("""
                <ul>
                    <li><strong>PDF:</strong> Upload research papers, reports, manuals</li>
                    <li><strong>TXT:</strong> Plain text documents</li>
                    <li><strong>DOCX:</strong> Word documents</li>
                    <li><strong>Default:</strong> Use Databricks Dolly dataset for general questions</li>
                </ul>
                """)

                gr.HTML("<h3>📝 Sample Questions:</h3>")
                sample_questions = get_sample_questions()
                for i, question in enumerate(sample_questions, 1):
                    gr.HTML(f"<p><strong>{i}.</strong> {question}</p>")

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="💬 Chat with Your Documents",
                    height=500,
                    bubble_full_width=False,
                )

                with gr.Row():
                    msg = gr.Textbox(
                        label="Your Question",
                        placeholder="Ask a question about your document or the dataset...",
                        scale=4,
                    )
                    send_btn = gr.Button("Send", variant="primary", scale=1)

                with gr.Row():
                    clear_btn = gr.Button("Clear Chat", variant="secondary")

                gr.HTML("""
                <div style="margin-top: 10px; padding: 10px; background-color: #f0f0f0; border-radius: 5px;">
                    <strong>🔍 How it works:</strong>
                    <br>1. Upload your document or use default dataset
                    <br>2. The system creates embeddings and a searchable index
                    <br>3. Your questions are matched with relevant content
                    <br>4. An AI model generates answers based on the context
                </div>
                """)

        # setup_rag_system returns (status_text, success); the flag is received
        # by a State component because nothing in the UI displays it.
        setup_succeeded = gr.State()
        setup_btn.click(
            setup_rag_system,
            inputs=[use_default, file_upload],
            outputs=[setup_status, setup_succeeded],
        )

        # Pressing Enter and clicking Send run the same callback.
        msg.submit(
            ask_question,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
        )

        send_btn.click(
            ask_question,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
        )

        clear_btn.click(
            clear_chat,
            outputs=[chatbot],
        )

        with gr.Row():
            gr.HTML("<h3>🚀 Quick Start Questions:</h3>")

        with gr.Row():
            for question in sample_questions[:3]:
                btn = gr.Button(question, size="sm")
                # q=question binds the current value; the callback fills the
                # question box and resets the chat to an empty history.
                btn.click(
                    lambda q=question: (q, []),
                    outputs=[msg, chatbot],
                )

    return demo

if __name__ == "__main__":
    # NOTE(review): gradio and pandas are imported at the top of this module,
    # so they must already be installed for execution to reach this line.
    print("Installing requirements...")
    install_requirements()

    demo = create_interface()

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        debug=True,
    )