def demo():
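    """Build and launch the Gradio UI for the RAG PDF chatbot.

    Assumes names defined earlier in the module: `import gradio as gr`,
    `list_llm_simple`, `initialize_database`, `initialize_LLM`, and
    `conversation`.
    """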
custom_css = """
.horizontal-container {
display: flex !important;
flex-direction: row !important;
flex-wrap: nowrap !important;
width: 100% !important;
}
.column-1 {
min-width: 300px !important;
max-width: 35% !important;
flex: 1 !important;
}
.column-2 {
min-width: 500px !important;
flex: 2 !important;
}
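    /* On narrow viewports, relax the column constraints rather than stacking. */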
@media (max-width: 900px) {
.column-1 { max-width: 40% !important; }
.column-2 { min-width: 400px !important; }
}
"""
    with gr.Blocks(
        theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky"),
        css=custom_css,
    ) as demo:
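        # Per-session state: the vector store built from the uploaded PDFs
        # and the initialized question-answering chain.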
vector_db = gr.State()
qa_chain = gr.State()
gr.HTML("
RAG PDF chatbot
")
gr.Markdown("""Query your PDF documents! This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents. The app is hosted on Hugging Face Hub for the sole purpose of demonstration. \
Please do not upload confidential documents.
""")
with gr.Row(elem_classes="horizontal-container"):
with gr.Column(elem_classes="column-1"):
gr.Markdown("Step 1 - Upload PDF documents and Initialize RAG pipeline")
with gr.Row():
document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True, label="Upload PDF documents")
with gr.Row():
db_btn = gr.Button("Create vector database")
with gr.Row():
db_progress = gr.Textbox(value="Not initialized", show_label=False)
gr.Markdown("Select Large Language Model (LLM) and input parameters")
with gr.Row():
llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
with gr.Row():
with gr.Accordion("LLM input parameters", open=False):
with gr.Row():
slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
with gr.Row():
                            slider_maxtokens = gr.Slider(minimum=128, maximum=8192, value=4096, step=128, label="Max New Tokens", info="Maximum number of new tokens to generate", interactive=True)
with gr.Row():
                            slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k", info="Number of highest-probability tokens to sample the next token from", interactive=True)
with gr.Row():
qachain_btn = gr.Button("Initialize Question Answering Chatbot")
with gr.Row():
llm_progress = gr.Textbox(value="Not initialized", show_label=False)
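            # Right column: chat window plus the retrieved source passages.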
with gr.Column(elem_classes="column-2"):
gr.Markdown("Step 2 - Chat with your Document")
chatbot = gr.Chatbot(height=505)
with gr.Accordion("Relevant context from the source document", open=False):
with gr.Row():
doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
source1_page = gr.Number(label="Page", scale=1)
with gr.Row():
doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
source2_page = gr.Number(label="Page", scale=1)
with gr.Row():
doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
source3_page = gr.Number(label="Page", scale=1)
with gr.Row():
msg = gr.Textbox(placeholder="Ask a question", container=True)
with gr.Row():
submit_btn = gr.Button("Submit")
clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
        # Event handlers: wire the UI controls to the backend functions.
db_btn.click(initialize_database,
inputs=[document],
outputs=[vector_db, db_progress])
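        # Initializing the LLM chain also resets the chat window and the
        # reference boxes (the `.then` step below).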
qachain_btn.click(initialize_LLM,
inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
inputs=None,
outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
queue=False)
msg.submit(conversation,
inputs=[qa_chain, msg, chatbot],
outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
queue=False)
submit_btn.click(conversation,
inputs=[qa_chain, msg, chatbot],
outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
queue=False)
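        # The Clear button additionally empties the reference boxes and page numbers.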
clear_btn.click(lambda:[None,"",0,"",0,"",0],
inputs=None,
outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
queue=False)
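    # queue() enables request queuing for concurrent users; debug=True
    # surfaces errors in the console while the app runs.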
demo.queue().launch(debug=True)
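

# Entry point (assumed; the surrounding module may invoke demo() differently).
if __name__ == "__main__":
    demo()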