adrien.aribaut-gaudin
fix: gitignore for the database folder + prompt for requirements + 3 blocks max for best_sources
8e58322
| import asyncio | |
| import os | |
| import shutil | |
| import json | |
| from typing import Dict | |
| import random | |
| import datetime | |
| import string | |
| import docx | |
| import pandas as pd | |
| from src.domain.block import Block | |
| from src.tools.doc_tools import get_title | |
| from src.domain.doc import Doc | |
| from src.domain.wikidoc import WikiPage | |
| from src.view.log_msg import create_msg_from | |
| import src.tools.semantic_db as semantic_db | |
| from src.tools.wiki import Wiki | |
| from src.llm.llm_tools import generate_response_to_exigence | |
| from src.llm.llm_tools import get_wikilist, get_public_paragraph, get_private_paragraph | |
| from src.tools.semantic_db import add_texts_to_collection, query_collection | |
| from src.tools.excel_tools import excel_to_dict | |
| import gradio as gr | |
| from src.retriever.retriever import Retriever | |
class Controller:
    """Coordinates document copying, template application, content generation
    and requirement-compliance answering for the UI layer."""

    def __init__(self, config: Dict, client_db, retriever):
        """Read the working paths from *config*, load the default template and
        keep the database client plus the retriever used for similarity search."""
        self.templates_path = config['templates_path']
        self.generated_docs_path = config['generated_docs_path']
        self.styled_docs_path = config['styled_docs_path']
        self.excel_doc_path = config['excel_doc_path']
        self.new_docs = []
        self.gen_docs = []
        self.input_csv = ""
        # The default template is picked by index from the configured list.
        default_name = config['templates'][config['default_template_index']]
        self.default_template = Doc(f"{config['templates_path']}/{default_name}")
        self.template = self.default_template
        self.log = []
        self.differences = []
        self.list_differences = []
        self.client_db = client_db
        self.retriever = retriever
def copy_docs(self, temp_docs: list):
    """Copy the incoming documents into the working area and index them.

    Also creates the requirements-retrieval collection from the copies.
    TODO: Rename or refactor the function -> 1 mission / function
    TODO: To be tested on several documents
    TODO: Rename create_collection in create_requirement_collection
    """
    doc_names = []
    for doc in temp_docs:
        name = doc.name
        # Uploads may carry either '/' or '\\' separators; keep the basename.
        if '/' in name:
            name = name.split('/')[-1]
        elif '\\' in name:
            name = name.split('\\')[-1]
        # Keep only the part before the first dot (drops the extension).
        doc_names.append(name.split('.')[0])
    docs = [Doc(path=doc.name) for doc in temp_docs]
    self.create_collection(docs)
    # The original also built a list of "<name>_e.docx" paths here but never
    # used it; that dead code is removed.
    style_paths = [f"{self.generated_docs_path}/{dn}_.docx" for dn in doc_names]
    for doc, style_path in zip(docs, style_paths):
        self.new_docs.append(doc.copy(style_path))
def clear_docs(self):
    """Clear every working/generated document and reset the controller state.

    Also removes any leftover files from the generated-docs and excel
    output folders.
    """
    for new_doc in self.new_docs:
        if os.path.exists(new_doc.path):
            new_doc.clear()
    for gen_doc in self.gen_docs:
        if os.path.exists(gen_doc.path):
            gen_doc.clear()
    self.new_docs = []
    self.gen_docs = []
    self.log = []
    # Plain loops instead of the original throwaway list comprehensions:
    # os.remove is called purely for its side effect.
    for folder in (os.path.abspath(self.generated_docs_path),
                   os.path.abspath(self.excel_doc_path)):
        for leftover in os.listdir(folder):
            os.remove(f"{folder}/{leftover}")
def set_template(self, template_name: str = ""):
    """Select the active template; an empty name restores the default one."""
    if template_name:
        self.template = Doc(f"{self.templates_path}/{template_name}")
    else:
        self.template = self.default_template
def add_template(self, template_path: str):
    """Register an uploaded .docx file as a new template.

    Non-.docx uploads trigger a UI warning and are ignored.
    TODO: message to be put in config
    """
    if not template_path:
        return
    if not template_path.name.endswith(".docx"):
        gr.Warning("Seuls les fichiers .docx sont acceptés")
        return
    uploaded = docx.Document(template_path.name)
    uploaded.save(f"{self.templates_path}/{get_title(template_path.name)}")
def delete_curr_template(self, template_name: str):
    """Delete the named template file; a falsy name is a no-op."""
    if template_name:
        os.remove(f"{self.templates_path}/{template_name}")
def retrieve_number_of_misapplied_styles(self):
    """Map each working document to its count of misapplied styles.

    NOTE(review): flagged "not used: buggy !!" in the original — kept for
    interface compatibility, still unused.
    """
    return {doc: doc.retrieve_number_of_misapplied_styles()
            for doc in self.new_docs}
def get_difference_with_template(self):
    """Collect, per working document, the style names that differ from the
    active template.

    Returns a pair ``(differences, template_styles)``: *differences* is a
    list of ``{'doc': ..., 'style': ...}`` dicts, and *template_styles*
    holds the template style objects whose names the template declares.
    """
    self.differences = []
    for new_doc in self.new_docs:
        for style_name in new_doc.get_different_styles_with_template(template=self.template):
            self.differences.append({'doc': new_doc, 'style': style_name})
    declared_names = self.template.styles.names
    template_styles = [s for s in self.template.xdoc.styles if s.name in declared_names]
    return self.differences, template_styles
def get_list_styles(self):
    """Collect the list (bullet/numbering) styles of every working document."""
    self.list_differences = []
    for new_doc in self.new_docs:
        self.list_differences.extend(
            {'doc': new_doc, 'list_style': style}
            for style in new_doc.get_list_styles()
        )
    return self.list_differences
def map_style(self, this_style_index: int, template_style_name: str):
    """Map the style at *this_style_index* in ``self.differences`` onto the
    template style *template_style_name*, recording any change in the log."""
    entry = self.differences[this_style_index]
    target_doc = entry['doc']
    result = target_doc.copy_one_style(entry['style'], template_style_name, self.template)
    # Only record an entry when the copy produced a log message.
    if result:
        self.log.append({target_doc.name: result})
def update_list_style(self, this_style_index: int, template_style_name: str):
    """Replace the bullet/list style at *this_style_index* in
    ``self.list_differences`` with the template style, logging any change."""
    entry = self.list_differences[this_style_index]
    target_doc = entry['doc']
    result = target_doc.change_bullet_style(entry['list_style'], template_style_name, self.template)
    # Only record an entry when the change produced a log message.
    if result:
        self.log.append({target_doc.name: result})
def update_style(self, index, style_to_modify):
    """Delegate to :meth:`map_style` when a target style is given; otherwise
    do nothing and return ``None``."""
    if not style_to_modify:
        return None
    return self.map_style(index, style_to_modify)
def apply_template(self, options_list):
    """Apply the current template to every working document, logging the
    changes each application reports."""
    for current in self.new_docs:
        result = current.apply_template(template=self.template, options_list=options_list)
        if result:
            self.log.append({current.name: result})
def reset(self):
    """Delete every working and generated document, then empty both lists."""
    # new_docs first, then gen_docs — same order as before.
    for doc in (*self.new_docs, *self.gen_docs):
        doc.delete()
    self.new_docs = []
    self.gen_docs = []
def get_log(self):
    """Build the user-facing log message from the accumulated per-doc logs."""
    return create_msg_from(self.log, self.new_docs)

"""
Source Control
"""
def get_or_create_collection(self, id_: str) -> str:
    """Return *id_* unchanged, or mint a fresh collection id when it is '-1'.

    TODO: rename into get_or_create_generation_collection
    TODO: have a single DB with separate collections, one for requirements, one for generation
    """
    if id_ == '-1':
        # timestamp prefix + 10 random alphanumerics, e.g. "06181435-a1b2c3d4e5"
        stamp = datetime.datetime.now().strftime("%m%d%H%M")
        alphabet = string.ascii_lowercase + string.digits
        suffix = ''.join(random.choice(alphabet) for _ in range(10))
        id_ = f"{stamp}-{suffix}"
        semantic_db.get_or_create_collection(id_)
    return id_
async def wiki_fetch(self) -> [str]:
    """Return the titles of the wikipages matching the tasks found in all
    working documents (duplicates removed)."""
    all_tasks = [task for doc in self.new_docs for task in doc.tasks]
    pending = [asyncio.create_task(get_wikilist(task)) for task in all_tasks]
    wiki_lists = await asyncio.gather(*pending)
    # Union of every returned list, as a flat deduplicated list.
    return list(set().union(*map(set, wiki_lists)))
async def wiki_upload_and_store(self, wiki_title: str, collection_name: str):
    """Fetch one wikipage and store its paragraphs into the collection.

    ``Wiki().fetch`` returns a plain string on failure; that error string is
    printed instead of being stored.
    """
    wikipage = Wiki().fetch(wiki_title)
    # Guard clause on the error case (was `type(wikipage) != str`); the dead
    # self-assignment `wiki_title = wiki_title` has been removed.
    if isinstance(wikipage, str):
        print(wikipage)
        return
    texts = WikiPage(wikipage.page_content).get_paragraphs()
    add_texts_to_collection(coll_name=collection_name, texts=texts, file=wiki_title, source='wiki')

"""
Generate Control
"""
async def generate_doc_from_db(self, collection_name: str, from_files: [str]) -> [str]:
    """Generate one new document per working doc by answering each of its
    tasks from the semantic database, then saving the result as .docx.

    Returns the list of generated file paths, deduplicated.
    """
    def query_from_task(task):
        # Rephrase the task as a public-knowledge query for retrieval.
        return get_public_paragraph(task)

    async def retrieve_text_and_generate(t, collection_name: str, from_files: [str]):
        """
        retrieves the texts from the database and generates the documents
        """
        # retrieve the texts from the database
        task_query = query_from_task(t)
        texts = query_collection(coll_name=collection_name, query=task_query, from_files=from_files)
        task_resolutions = get_private_paragraph(task=t, texts=texts)
        return task_resolutions

    async def real_doc_generation(new_doc):
        # One generation coroutine per task of this document.
        async_task_resolutions = [asyncio.create_task(retrieve_text_and_generate(t=task, collection_name=collection_name, from_files=from_files))
                                  for task in new_doc.tasks]
        tasks_resolutions = await asyncio.gather(*async_task_resolutions)  # TO REVIEW ("A VOIR")
        gen_path = f"{self.generated_docs_path}/{new_doc.name}e.docx"
        gen_doc = new_doc.copy(gen_path)
        gen_doc.replace_tasks(tasks_resolutions)
        gen_doc.save_as_docx()
        # NOTE(review): `gen_paths` is the shared outer list — every coroutine
        # appends to (and returns) the SAME object, so the gather result holds
        # N references to it and the flatten below produces repeats that the
        # set() dedup then removes. Works, but fragile; confirm before refactoring.
        gen_paths.append(gen_doc.path)
        self.gen_docs.append(gen_doc)
        return gen_paths

    gen_paths = []
    gen_paths = await asyncio.gather(*[asyncio.create_task(real_doc_generation(new_doc)) for new_doc in self.new_docs])
    gen_paths = [path for sublist in gen_paths for path in sublist]
    gen_paths = list(set(gen_paths))
    return gen_paths

"""
Requirements
"""
def clear_input_csv(self):
    """Forget the stored requirements file and empty the excel output folder."""
    self.input_csv = ""
    # Plain loop instead of the original throwaway list comprehension.
    for leftover in os.listdir(self.excel_doc_path):
        os.remove(f"{self.excel_doc_path}/{leftover}")
def set_input_csv(self, csv_path: str):
    """Remember the path of the uploaded requirements file.

    TODO: rename to set_requirements_file
    """
    self.input_csv = csv_path
def create_collection(self, docs: "list[Doc]"):
    """Get or create the requirements collection, fill it on first creation,
    and hand it to the retriever.

    The annotation is quoted: the original ``docs: [Doc]`` was a list literal
    evaluated at def time, not a valid type hint.
    TODO: rename to create_requirements_collection
    TODO: merge with semantic tool to have only one DB Object
    """
    coll_name = "collection_for_docs"
    collection = self.client_db.get_or_create_collection(coll_name)
    # Only index the documents the first time the collection is created.
    if collection.count() == 0:
        for doc in docs:
            self.fill_collection(doc, collection)
    self.retriever.collection = collection
def fill_collection(self, doc: Doc, collection: str):
    """
    Fill *collection* with the blocks of *doc*.

    The Retriever constructor performs the indexing as a side effect; the
    instance itself is deliberately discarded.
    NOTE(review): the ``collection: str`` annotation looks wrong — the caller
    (create_collection) passes a collection object, not its name; confirm.
    """
    Retriever(doc=doc, collection=collection)
| def _select_best_sources(sources: [Block], delta_1_2=0.15, delta_1_n=0.3, absolute=1.2, alpha=0.9, max_blocks=3) -> [Block]: | |
| """ | |
| Select the best sources: not far from the very best, not far from the last selected, and not too bad per se | |
| """ | |
| best_sources = [] | |
| for idx, s in enumerate(sources): | |
| if idx == 0 \ | |
| or (s.distance - sources[idx - 1].distance < delta_1_2 | |
| and s.distance - sources[0].distance < delta_1_n) \ | |
| or s.distance < absolute: | |
| best_sources.append(s) | |
| delta_1_2 *= alpha | |
| delta_1_n *= alpha | |
| absolute *= alpha | |
| else: | |
| break | |
| best_sources = sorted(best_sources, key=lambda x: x.distance)[:max_blocks] | |
| return best_sources | |
def generate_response_to_requirements(self):
    """Answer every requirement of the uploaded excel file from the indexed
    documents and write the enriched table back as an .xlsx file.

    Returns the path of the generated excel file.
    """
    requirements = self.get_requirements_from_csv()
    for row in requirements:
        blocks_sources = self.retriever.similarity_search(queries=row["Exigence"])
        best_sources = self._select_best_sources(blocks_sources)
        sources_contents = [
            f"Paragraph title : {s.title}\n-----\n{s.content}" if s.title
            else f"Paragraph {s.index}\n-----\n{s.content}"
            for s in best_sources
        ]
        context = '\n'.join(sources_contents)
        # Trim sources from the end until the context fits the prompt budget.
        i = 1
        while (len(context) > 15000) and i < len(sources_contents):
            context = "\n".join(sources_contents[:-i])
            i += 1
        reponse_exigence = generate_response_to_exigence(
            exigence=row["Exigence"], titre_exigence=row["Titre"], content=context)
        # Mutate the row in place; the original re-located it on every
        # assignment with list.index(), which is O(n) per call and picks the
        # wrong row when two requirements are identical.
        row["Conformité"] = reponse_exigence
        row["Document"] = best_sources[0].doc
        row["Paragraphes"] = "; ".join(block.index for block in best_sources)
    excel_name = self.input_csv
    if '/' in excel_name:
        excel_name = excel_name.split('/')[-1]
    elif '\\' in excel_name:
        excel_name = excel_name.split('\\')[-1]
    out_path = f"{self.excel_doc_path}/{excel_name}"
    pd.DataFrame(data=requirements).to_excel(out_path, index=False)
    return out_path
def get_requirements_from_csv(self):
    """Parse the stored requirements excel file into row dictionaries."""
    return excel_to_dict(self.input_csv)