"""FastAPI service exposing a RAG question-answering endpoint and a feedback endpoint."""
import hmac
import json
import os
from typing import List

from dotenv import load_dotenv
from fastapi import FastAPI, Response, Body, Security
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, model_validator

from conversation.conversation_store import ConversationStore
from rag_langchain import LangChainRAG

load_dotenv()

# Accepted API keys, read from the environment at import time
# (raises KeyError if API_API_KEY is unset).
api_keys = [os.environ["API_API_KEY"]]

api = FastAPI()
conversation_store = ConversationStore()
# The key is read from the "Authorization" header; auto_error=True makes
# FastAPI reject requests that omit the header before the handler runs.
api_key_header = APIKeyHeader(name="Authorization", auto_error=True)

prompt_id = "summarize_rag_1"
check_prompt_id = "check_control_challenge_step_back"
rewrite_prompt_id = "first"
default_llm = "gpt-4o 128k"


class QModel(BaseModel):
    """Request body for ``/qa``."""

    q: str
    retrieval_count: int = 10
    temperature: str = "0.2"
    llm: str = default_llm

    @model_validator(mode='before')
    @classmethod
    def validate_to_json(cls, value):
        """Accept the body as a JSON-encoded string as well as an object."""
        # `model_validator` must be the outermost decorator; with `@classmethod`
        # on top, pydantic never registers the validator. A 'before' validator
        # hands raw data (here a dict) on to field validation.
        if isinstance(value, str):
            return json.loads(value)
        return value


class AModel(BaseModel):
    """Response body for ``/qa``."""

    q: str
    a: str
    sources: List[str]
    oid: str


class EmoModel(BaseModel):
    """Request body for ``/emo``: a stored conversation id and its grading."""

    qid: str
    helpfulness: str

    @model_validator(mode='before')
    @classmethod
    def validate_to_json(cls, value):
        """Accept the body as a JSON-encoded string as well as an object."""
        # Same decorator-order and return-type fix as QModel.validate_to_json.
        if isinstance(value, str):
            return json.loads(value)
        return value


@api.get("/")
async def read_root():
    """Placeholder root endpoint."""
    return "Empty"


@api.post("/qa", response_model=AModel)
async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
    """Answer ``data.q`` with the RAG chain and persist the exchange.

    Returns a bare 401 ``Response`` when the API key is not recognised.
    Otherwise returns an ``AModel`` carrying the question, the answer, the
    source document texts and the id returned by ``save_content``.
    """
    if not valid_api_key(api_key):
        return Response(status_code=401)

    rag = LangChainRAG(
        config={
            "retrieve_documents": data.retrieval_count,
            "temperature": data.temperature,
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id,
        }
    )
    answer, check_result, sources = await rag.rag_chain(data.q, data.llm)
    # Extract the document texts once; both the store and the response use them.
    source_texts = [doc.page_content for doc in sources]

    oid = conversation_store.save_content(
        q=data.q,
        a=answer,
        sources=source_texts,
        params={
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id,
            "check_result": check_result,
            "temperature": data.temperature,
            "retrieve_document_count": str(data.retrieval_count),
        },
    )
    return AModel(a=answer, q=data.q, sources=source_texts, oid=oid)


@api.post("/emo")
async def emo(api_key: str = Security(api_key_header), json_body: EmoModel = Body(...)):
    """Record the user's helpfulness grading on a stored conversation.

    Returns a bare 401 ``Response`` for an unrecognised API key.
    """
    if not valid_api_key(api_key):
        return Response(status_code=401)

    conversation = conversation_store.get(json_body.qid)
    # NOTE(review): assumes `get` returns None for an unknown id; confirm
    # against ConversationStore. Without this check, a None result would fail
    # with AttributeError (500).
    if conversation is None:
        return Response(status_code=404)

    # Work on a copy so the fetched object's params are not mutated in place.
    new_params = dict(conversation.params)
    new_params["user_grading"] = str(json_body.helpfulness)
    first_turn = conversation.conversation[0]
    conversation_store.update(
        # Attribute access: pydantic models are not subscriptable
        # (json_body["qid"] raised TypeError).
        oid=json_body.qid,
        q=first_turn.q,
        a=first_turn.a,
        sources=first_turn.sources,
        params=new_params,
    )


def valid_api_key(api_key: str) -> bool:
    """Return True if ``api_key`` matches one of the configured keys.

    Uses a constant-time comparison so response timing does not leak how many
    leading characters of a guess were correct. Strings are encoded first
    because ``hmac.compare_digest`` rejects non-ASCII ``str`` arguments.
    """
    candidate = api_key.encode()
    return any(hmac.compare_digest(candidate, key.encode()) for key in api_keys)