Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import faiss | |
| import pickle | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
| import openai | |
| from data_setup import setup_knowledge_base | |
| # 檔案路徑 | |
| PDF_FILE_PART_1 = '噶哈巫語參考語法(上).pdf' | |
| PDF_FILE_PART_2 = '噶哈巫語參考語法(下).pdf' | |
| CHUNKS_PKL = 'chunks.pkl' | |
| FAISS_INDEX_BIN = 'faiss_index.bin' | |
| # 模型設定 | |
| RETRIEVER_MODEL_NAME = 'paraphrase-multilingual-mpnet-base-v2' | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| # 確保API金鑰是從環境變數中安全讀取 | |
| # 這個環境變數需要在 Hugging Face Space 的 Settings -> Repository secrets 中設定為 OPENAI_API_KEY | |
| openai.api_key = os.environ.get("OPENAI_API_KEY") | |
| # 全局變數,用於儲存載入的檢索模型和知識庫資料 | |
| retriever_model_global = None | |
| index_global = None | |
| chunks_global = None | |
| def load_retriever_components(): | |
| """在應用程式啟動時,一次性載入檢索器和知識庫。""" | |
| global retriever_model_global, index_global, chunks_global | |
| print("--- 檢查知識庫檔案 ---") | |
| if not os.path.exists(CHUNKS_PKL) or not os.path.exists(FAISS_INDEX_BIN): | |
| print("知識庫檔案不存在,正在重新建立...") | |
| setup_knowledge_base(PDF_FILE_PART_1, PDF_FILE_PART_2) | |
| print("--- 載入 Sentence Transformer 模型 ---") | |
| retriever_model_global = SentenceTransformer(RETRIEVER_MODEL_NAME, device=DEVICE) | |
| print("--- 載入 FAISS 索引和語塊 ---") | |
| index_global = faiss.read_index(FAISS_INDEX_BIN) | |
| with open(CHUNKS_PKL, 'rb') as f: | |
| chunks_global = pickle.load(f) | |
| print("知識庫載入成功!") | |
| # 這個函數現在只會更新全局變數,並返回一個狀態字串 | |
| return "檢索模型及知識庫載入完畢,請開始提問。" | |
| def ask_question_gradio(question): | |
| """Gradio介面呼叫的問答函數。""" | |
| if not openai.api_key: | |
| return "API 金鑰未設定。請聯繫開發者確保金鑰已在 Hugging Face Secrets 中設定。", "錯誤" | |
| # 確保全局變數已載入 | |
| global retriever_model_global, index_global, chunks_global | |
| if retriever_model_global is None: | |
| # 如果模型未載入 (不應該發生,因為 demo.load 會先執行), 則嘗試載入 | |
| load_retriever_components() | |
| if retriever_model_global is None: # 如果仍然為空,表示載入失敗 | |
| return "檢索模型或知識庫載入失敗。請檢查日誌。", "錯誤" | |
| print(f"\n使用者提問: {question}") | |
| # 1. 檢索相關語塊 | |
| question_embedding = retriever_model_global.encode([question], convert_to_tensor=True).cpu().numpy() | |
| distances, indices = index_global.search(question_embedding, 5) | |
| retrieved_chunks = [chunks_global[i] for i in indices[0]] | |
| # 2. 格式化提示詞給ChatGPT API | |
| context = "" | |
| for i, chunk in enumerate(retrieved_chunks): | |
| context += f"參考資料 {i+1}: {chunk['content']}\n" | |
| prompt = f""" | |
| 你是一個專業的噶哈巫語法老師,請根據以下提供的參考資料,精確地、簡潔地回答問題。 | |
| 你的回答必須: | |
| 1. 嚴格根據提供的參考資料,如果資料中沒有足夠的資訊,請禮貌地表示不知道,不要捏造任何內容。 | |
| 2. 以條列式或段落方式,清晰地呈現答案。 | |
| 3. 在回答中,用 [頁碼 x] 標籤標註資訊來自哪一份參考資料的頁碼。 | |
| 以下是參考資料: | |
| {context} | |
| 問題: {question} | |
| 答案: | |
| """ | |
| # 3. 呼叫ChatGPT API | |
| try: | |
| # 從 OpenAI 1.0.0 版本開始,API 呼叫方式有所改變 | |
| client = openai.OpenAI(api_key=openai.api_key) # 使用新的客戶端模式 | |
| response = client.chat.completions.create( | |
| model="gpt-3.5-turbo", | |
| messages=[ | |
| {"role": "system", "content": "你是一個專業的噶哈巫語法老師,請根據提供的參考資料回答問題。"}, | |
| {"role": "user", "content": prompt} | |
| ] | |
| ) | |
| answer = response.choices[0].message.content | |
| # 4. 格式化最終輸出 | |
| formatted_output = f"**問題**: {question}\n\n" | |
| formatted_output += f"**答案**: {answer}\n\n" | |
| formatted_output += "**引用來源**:\n" | |
| # 確保引用來源的頁碼是清晰的 | |
| for i, chunk in enumerate(retrieved_chunks): | |
| formatted_output += f"- [頁碼 {', '.join(map(str, chunk['page_numbers']))}]\n" # 顯示所有頁碼 | |
| return formatted_output, "提問完成。" | |
| except openai.AuthenticationError: | |
| return "API 金鑰無效,請檢查 Hugging Face Secrets 中的設定。", "錯誤" | |
| except Exception as e: | |
| return f"發生錯誤: {e}\n請檢查日誌獲取更多資訊。", "錯誤" | |
| # Gradio介面定義 | |
| with gr.Blocks(theme=gr.themes.Base()) as demo: | |
| gr.Markdown("# 噶哈巫語法小老師 (ChatGPT API)") | |
| gr.Markdown("歡迎提問,我會根據《噶哈巫語參考語法》的內容來回答你。") | |
| # 狀態文本框用於顯示載入進度 | |
| status_text = gr.Textbox(label="應用程式狀態", value="正在載入模型,請稍候...", interactive=False) | |
| with gr.Row(): | |
| question_input = gr.Textbox( | |
| label="請輸入你的問題", | |
| lines=2, | |
| placeholder="例如:請問什麼是去詞綴化?", | |
| interactive=False # 初始為不可互動 | |
| ) | |
| submit_button = gr.Button("提問", interactive=False) # 初始為不可互動 | |
| output_textbox = gr.Markdown(label="回答") | |
| # 在介面載入時執行後台載入,並在完成後啟用介面元素 | |
| demo.load( | |
| fn=load_retriever_components, # 呼叫這個函數來載入模型和知識庫 | |
| inputs=None, | |
| outputs=status_text # 載入完成後,更新狀態文本框 | |
| ).then( | |
| fn=lambda: (gr.update(interactive=True), gr.update(interactive=True)), # 啟用輸入框和按鈕 | |
| inputs=None, | |
| outputs=[question_input, submit_button] | |
| ) | |
| # 點擊按鈕時呼叫問答函數 | |
| submit_button.click( | |
| fn=ask_question_gradio, | |
| inputs=[question_input], # 不再需要直接傳遞 state,因為模型是全局的 | |
| outputs=[output_textbox, status_text] # 同時更新回答和狀態 | |
| ) | |
| demo.launch(server_name="0.0.0.0", server_port=7860) |