atuanlausu's picture
Update app.py
957dbc1 verified
import os
import gradio as gr
import faiss
import pickle
from sentence_transformers import SentenceTransformer
import torch
import openai
from data_setup import setup_knowledge_base
# 檔案路徑
PDF_FILE_PART_1 = '噶哈巫語參考語法(上).pdf'
PDF_FILE_PART_2 = '噶哈巫語參考語法(下).pdf'
CHUNKS_PKL = 'chunks.pkl'
FAISS_INDEX_BIN = 'faiss_index.bin'
# 模型設定
RETRIEVER_MODEL_NAME = 'paraphrase-multilingual-mpnet-base-v2'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 確保API金鑰是從環境變數中安全讀取
# 這個環境變數需要在 Hugging Face Space 的 Settings -> Repository secrets 中設定為 OPENAI_API_KEY
openai.api_key = os.environ.get("OPENAI_API_KEY")
# 全局變數,用於儲存載入的檢索模型和知識庫資料
retriever_model_global = None
index_global = None
chunks_global = None
def load_retriever_components():
"""在應用程式啟動時,一次性載入檢索器和知識庫。"""
global retriever_model_global, index_global, chunks_global
print("--- 檢查知識庫檔案 ---")
if not os.path.exists(CHUNKS_PKL) or not os.path.exists(FAISS_INDEX_BIN):
print("知識庫檔案不存在,正在重新建立...")
setup_knowledge_base(PDF_FILE_PART_1, PDF_FILE_PART_2)
print("--- 載入 Sentence Transformer 模型 ---")
retriever_model_global = SentenceTransformer(RETRIEVER_MODEL_NAME, device=DEVICE)
print("--- 載入 FAISS 索引和語塊 ---")
index_global = faiss.read_index(FAISS_INDEX_BIN)
with open(CHUNKS_PKL, 'rb') as f:
chunks_global = pickle.load(f)
print("知識庫載入成功!")
# 這個函數現在只會更新全局變數,並返回一個狀態字串
return "檢索模型及知識庫載入完畢,請開始提問。"
def ask_question_gradio(question):
"""Gradio介面呼叫的問答函數。"""
if not openai.api_key:
return "API 金鑰未設定。請聯繫開發者確保金鑰已在 Hugging Face Secrets 中設定。", "錯誤"
# 確保全局變數已載入
global retriever_model_global, index_global, chunks_global
if retriever_model_global is None:
# 如果模型未載入 (不應該發生,因為 demo.load 會先執行), 則嘗試載入
load_retriever_components()
if retriever_model_global is None: # 如果仍然為空,表示載入失敗
return "檢索模型或知識庫載入失敗。請檢查日誌。", "錯誤"
print(f"\n使用者提問: {question}")
# 1. 檢索相關語塊
question_embedding = retriever_model_global.encode([question], convert_to_tensor=True).cpu().numpy()
distances, indices = index_global.search(question_embedding, 5)
retrieved_chunks = [chunks_global[i] for i in indices[0]]
# 2. 格式化提示詞給ChatGPT API
context = ""
for i, chunk in enumerate(retrieved_chunks):
context += f"參考資料 {i+1}: {chunk['content']}\n"
prompt = f"""
你是一個專業的噶哈巫語法老師,請根據以下提供的參考資料,精確地、簡潔地回答問題。
你的回答必須:
1. 嚴格根據提供的參考資料,如果資料中沒有足夠的資訊,請禮貌地表示不知道,不要捏造任何內容。
2. 以條列式或段落方式,清晰地呈現答案。
3. 在回答中,用 [頁碼 x] 標籤標註資訊來自哪一份參考資料的頁碼。
以下是參考資料:
{context}
問題: {question}
答案:
"""
# 3. 呼叫ChatGPT API
try:
# 從 OpenAI 1.0.0 版本開始,API 呼叫方式有所改變
client = openai.OpenAI(api_key=openai.api_key) # 使用新的客戶端模式
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "你是一個專業的噶哈巫語法老師,請根據提供的參考資料回答問題。"},
{"role": "user", "content": prompt}
]
)
answer = response.choices[0].message.content
# 4. 格式化最終輸出
formatted_output = f"**問題**: {question}\n\n"
formatted_output += f"**答案**: {answer}\n\n"
formatted_output += "**引用來源**:\n"
# 確保引用來源的頁碼是清晰的
for i, chunk in enumerate(retrieved_chunks):
formatted_output += f"- [頁碼 {', '.join(map(str, chunk['page_numbers']))}]\n" # 顯示所有頁碼
return formatted_output, "提問完成。"
except openai.AuthenticationError:
return "API 金鑰無效,請檢查 Hugging Face Secrets 中的設定。", "錯誤"
except Exception as e:
return f"發生錯誤: {e}\n請檢查日誌獲取更多資訊。", "錯誤"
# Gradio介面定義
with gr.Blocks(theme=gr.themes.Base()) as demo:
gr.Markdown("# 噶哈巫語法小老師 (ChatGPT API)")
gr.Markdown("歡迎提問,我會根據《噶哈巫語參考語法》的內容來回答你。")
# 狀態文本框用於顯示載入進度
status_text = gr.Textbox(label="應用程式狀態", value="正在載入模型,請稍候...", interactive=False)
with gr.Row():
question_input = gr.Textbox(
label="請輸入你的問題",
lines=2,
placeholder="例如:請問什麼是去詞綴化?",
interactive=False # 初始為不可互動
)
submit_button = gr.Button("提問", interactive=False) # 初始為不可互動
output_textbox = gr.Markdown(label="回答")
# 在介面載入時執行後台載入,並在完成後啟用介面元素
demo.load(
fn=load_retriever_components, # 呼叫這個函數來載入模型和知識庫
inputs=None,
outputs=status_text # 載入完成後,更新狀態文本框
).then(
fn=lambda: (gr.update(interactive=True), gr.update(interactive=True)), # 啟用輸入框和按鈕
inputs=None,
outputs=[question_input, submit_button]
)
# 點擊按鈕時呼叫問答函數
submit_button.click(
fn=ask_question_gradio,
inputs=[question_input], # 不再需要直接傳遞 state,因為模型是全局的
outputs=[output_textbox, status_text] # 同時更新回答和狀態
)
demo.launch(server_name="0.0.0.0", server_port=7860)