import gradio as gr
import os
import json
from glob import glob
import requests
from langchain import FAISS
from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
from langchain import VectorDBQA
from langchain.chat_models import ChatOpenAI
from prompts import MyTemplate
from build_index.run import process_files
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains import QAGenerationChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain

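# Global configuration: the FAISS index produced during document processing is expected to be
# persisted under ./output/ and is lazily loaded into the module-level `docsearch` on the first
# chat request. API_URL and cohere_key are defined here but not referenced elsewhere in this file.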
API_URL = "https://api.openai.com/v1/chat/completions"
cohere_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
faiss_store = './output/'
docsearch = None


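# Document ingestion entry point. process_files (from build_index.run) is assumed to parse the
# uploaded PDFs, split them into chunks, embed them, and persist a FAISS index under ./output/,
# returning the chunk Documents used below for question generation and summarization.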
def process(files, openai_api_key, max_tokens, n_sample):
    """
    Summarize the uploaded documents, generate suggested questions, and build the document index.
    """
    os.environ['OPENAI_API_KEY'] = openai_api_key
    print('Displaying uploaded files')
    print(glob('/tmp/*'))
    docs = process_files([i.name for i in files], 'openai', max_tokens)
    print('Displaying FAISS index files')
    print(glob('./output/*'))
    question = get_question(docs, openai_api_key, max_tokens, n_sample)
    summary = get_summary(docs, openai_api_key, max_tokens, n_sample)
    return question, summary


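# Suggested-question generation: QAGenerationChain prompts the chat model with the system/user
# templates from prompts.MyTemplate to produce a question-answer pair per sampled chunk; only
# the questions are surfaced in the UI as suggested prompts.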
def get_question(docs, openai_api_key, max_tokens, n_sample=5):
    q_list = []
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens, temperature=0)
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(MyTemplate['qa_sys_template']),
            HumanMessagePromptTemplate.from_template(MyTemplate['qa_user_template']),
        ]
    )
    chain = QAGenerationChain.from_llm(llm, prompt=prompt)
    print('Generating questions from template')
    # Guard against non-integer slider values and fewer chunks than requested samples.
    for i in range(min(int(n_sample), len(docs))):
        qa = chain.run(docs[i].page_content)[0]
        print(qa)
        q_list.append(f"问题{i + 1}: {qa['question']}")
    return '\n'.join(q_list)


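# Map-reduce summarization: each sampled chunk is summarized with the map prompt, then the
# partial summaries are combined in a single "stuff" call. Note that the same summary_template
# is reused for both the map and the combine step.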
def get_summary(docs, openai_api_key, max_tokens, n_sample=5, verbose=None):
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    print('Generating summary from template')
    map_prompt = PromptTemplate(template=MyTemplate['summary_template'], input_variables=["text"])
    combine_prompt = PromptTemplate(template=MyTemplate['summary_template'], input_variables=["text"])
    map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
    reduce_chain = LLMChain(llm=llm, prompt=combine_prompt, verbose=verbose)
    combine_document_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name='text',
        verbose=verbose,
    )
    chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        combine_document_chain=combine_document_chain,
        document_variable_name='text',
        collapse_document_chain=None,
        verbose=verbose,
    )
    # Cast n_sample in case the slider passes a float.
    summary = chain.run(docs[:int(n_sample)])
    print(summary)
    return summary


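# Chat handler: lazily load the FAISS store, wrap it in a map_reduce VectorDBQA chain with the
# custom chat prompts, run the query, and yield the updated (chatbot, history, counter) state
# back to Gradio.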
def predict(inputs, openai_api_key, max_tokens, chat_counter, chatbot=[], history=[]):
    global docsearch
    print(f"chat_counter - {chat_counter}")
    print(f'History - {history}')
    print(f'chatbot - {chatbot}')

    history.append(inputs)
    if docsearch is None:
        print(f'loading faiss store from {faiss_store}')
        docsearch = FAISS.load_local(faiss_store, OpenAIEmbeddings(openai_api_key=openai_api_key))
    else:
        print('faiss already loaded')

    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    messages_combine = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
    messages_reduce = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
    chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
                                       k=4,
                                       chain_type_kwargs={"question_prompt": p_chat_reduce,
                                                          "combine_prompt": p_chat_combine})
    result = chain({"query": inputs})
    print(result)
    result = result['result']

    history.append(result)
    # Pair up user/assistant turns for the Chatbot component.
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    chat_counter += 1
    yield chat, history, chat_counter


def reset_textbox():
    return gr.update(value='')


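# Gradio Blocks UI: API key and parameter controls on top, PDF upload with summary/question
# panels in the middle, and a chatbot with an input box plus clear/ask buttons at the bottom.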
with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
                      #chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML("""<h1 align="center">🚀Smart Doc Reader🚀</h1>""")
    with gr.Column(elem_id="col_container"):
        openai_api_key = gr.Textbox(type='password', label="输入 API Key")

        with gr.Accordion("Parameters", open=True):
            with gr.Row():
                max_tokens = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, interactive=True,
                                       label="字数")
                chat_counter = gr.Number(value=0, precision=0, label='对话轮数')
                n_sample = gr.Slider(minimum=3, maximum=5, value=3, step=1, interactive=True,
                                     label="问题数")

        with gr.Row():
            with gr.Column():
                files = gr.File(file_count="multiple", file_types=[".pdf"], label='上传pdf文件')
                run = gr.Button('文档内容解读')

            with gr.Column():
                summary = gr.Textbox(type='text', label="一眼看尽 - 文档概览")
                question = gr.Textbox(type='text', label='推荐问题 - 问别的也行哟')

        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="这篇文档是关于什么的", label="针对文档你有哪些问题?")
        state = gr.State([])

        with gr.Row():
            clear = gr.Button("清空")
            start = gr.Button("提问")

    run.click(process, [files, openai_api_key, max_tokens, n_sample], [question, summary])
    inputs.submit(predict,
                  [inputs, openai_api_key, max_tokens, chat_counter, chatbot, state],
                  [chatbot, state, chat_counter])
    start.click(predict,
                [inputs, openai_api_key, max_tokens, chat_counter, chatbot, state],
                [chatbot, state, chat_counter])

    clear.click(reset_textbox, [], [inputs], queue=False)
    inputs.submit(reset_textbox, [], [inputs])

demo.queue().launch(debug=True)