Spaces:
Runtime error
Runtime error
| import logging | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed | |
| import torch | |
| import spaces | |
##########################
# CONFIGURATION
##########################
# Configure the root logger for the whole app: INFO and above, with a
# timestamp, logger name and level on every record.
# `logging.INFO` is used directly instead of `logging.getLevelName("INFO")`:
# the name -> number direction of getLevelName is a backward-compatibility
# quirk of the logging module, not its intended use.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Example images and texts.
# Each entry is an [image path, ingredient text] pair that pre-fills the
# image and ingredients inputs of the demo. The texts are kept verbatim,
# OCR typos included (e.g. "Aprigine", "natutel", "léci - thine"), because
# they are exactly the kind of input the spellcheck is meant to correct.
# NOTE(review): image paths are relative — presumably resolved against the
# app's working directory; confirm.
EXAMPLES = [
    ["images/ingredients_1.jpg", "24.36% chocolat noir 63% origine non UE (cacao, sucre, beurre de cacao, émulsifiant léci - thine de colza, vanille bourbon gousse), œuf, farine de blé, beurre, sucre, miel, sucre perlé, levure chimique, zeste de citron."],
    ["images/ingredients_2.jpg", "farine de froment, œufs, lait entier pasteurisé Aprigine: France), sucre, sel, extrait de vanille naturelle Conditi( 35."],
    # Disabled example (the reason is not recorded here).
    # ["images/ingredients_3.jpg", "tural basmati rice - cooked (98%), rice bran oil, salt"],
    ["images/ingredients_4.jpg", "Eau de noix de coco 93.9%, Arôme natutel de fruit"],
    ["images/ingredients_5.jpg", "Sucre, pâte de cacao, beurre de cacao, émulsifiant: léci - thines (soja). Peut contenir des traces de lait. Chocolat noir: cacao: 50% minimum. À conserver à l'abri de la chaleur et de l'humidité. Élaboré en France."],
]
# Hugging Face Hub id of the fine-tuned spellcheck model.
MODEL_ID = "openfoodfacts/spellcheck-mistral-7b"

# Markdown rendered at the top of the demo page.
# Blank lines separate paragraphs; without them markdown merges the
# consecutive lines into a single paragraph.
PRESENTATION = """# 🍊 Ingredients Spellcheck - Open Food Facts

Open Food Facts is a non-profit organization building the largest open food database in the world. 🌎

When a product is added to the database, all its details, such as allergens, additives, or nutritional values, are either written down by the contributor,
or automatically extracted from the product pictures using OCR.

However, the information extracted by OCR often contains typos and errors due to bad-quality pictures: low definition, curved products, light reflections, etc.

To solve this problem, we developed an 🍊 **Ingredient Spellcheck** 🍊, a model capable of correcting typos in a list of ingredients following a defined guideline.

The model, based on Mistral-7B-v0.3, was fine-tuned on thousands of corrected lists of ingredients extracted from the database. More information is available in the model card.

### *Project in progress* 🏗️

## 👇 Links

* Open Food Facts website: https://world.openfoodfacts.org/discover
* Open Food Facts GitHub: https://github.com/openfoodfacts
* Spellcheck project: https://github.com/openfoodfacts/openfoodfacts-ai/tree/develop/spellcheck
* Model card: https://huggingface.co/openfoodfacts/spellcheck-mistral-7b
"""
# Sentinel tensor created directly on CUDA. `process` only reads its
# `.device` to decide where to move the tokenized inputs.
# `.cuda()` has no CPU fallback: this line needs a CUDA device at import time.
# NOTE(review): on a ZeroGPU Space this is presumably satisfied by the CUDA
# emulation of the `spaces` package at startup — confirm before running elsewhere.
zero = torch.Tensor([0]).cuda()
# Seed the RNGs through transformers' `set_seed` so runs are as reproducible as
# possible (this does not guarantee 100% reproducibility). Generation in
# `process` is greedy (do_sample=False) in any case.
set_seed(42)
##########################
# LOADING
##########################
logging.info(f"Load tokenizer from {MODEL_ID}.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Padding reuses the end-of-sequence token; the token string and its id are
# assigned together (pad_token first, then pad_token_id).
tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id

logging.info(f"Load model from {MODEL_ID}.")
# attn_implementation="flash_attention_2" is not supported by ZeroGPU, so the
# default attention implementation is used.
# NOTE(review): no torch_dtype is passed (a bfloat16 option was previously
# left commented out), so weights load in transformers' default dtype —
# confirm there is enough memory for a 7B-parameter model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")
##########################
# FUNCTIONS
##########################
# On ZeroGPU Spaces a GPU is only attached inside @spaces.GPU-decorated
# functions (and startup fails if none is registered); elsewhere the
# decorator has no effect.
@spaces.GPU
def process(text: str) -> str:
    """Generate the spellcheck correction for a list of ingredients.

    The text is wrapped in the training-time instruction prompt
    (``prepare_instruction``), tokenized, and completed greedily by the model.

    Args:
        text (str): List of ingredients to correct, e.g. raw OCR output.

    Returns:
        str: Only the generated correction (the prompt is excluded), stripped
        of surrounding whitespace.
    """
    prompt = prepare_instruction(text)
    encoded = tokenizer(
        prompt,
        add_special_tokens=True,
        return_tensors="pt",
    )
    input_ids = encoded.input_ids.to(zero.device)  # GPU
    # pad_token is set to eos_token at load time, so generate() cannot infer
    # the attention mask on its own; pass the tokenizer's mask explicitly.
    attention_mask = encoded.attention_mask.to(zero.device)
    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=False,
        max_new_tokens=512,
    )
    # Drop the prompt at the token level. Decoding the full sequence and
    # slicing off len(prompt) characters is wrong whenever decode() does not
    # reproduce the prompt exactly (e.g. tokenizer whitespace clean-up).
    generated_ids = output[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
def prepare_instruction(text: str) -> str:
    """Build the instruction prompt used for fine-tuning and inference.

    The layout must stay identical to the instruction used during training.

    Args:
        text (str): List of ingredients.

    Returns:
        str: The prompt: header line, the ingredients, then the
        correction marker the model completes.
    """
    header = "###Correct the list of ingredients:\n"
    footer = "\n\n###Correction:\n"
    return "".join([header, text, footer])
##########################
# GRADIO SETUP
##########################
with gr.Blocks() as demo:
    gr.Markdown(PRESENTATION)
    with gr.Row():
        with gr.Column():
            # Read-only display of the product picture the example text came from.
            image = gr.Image(type="pil", label="image_input", interactive=False)
        with gr.Column():
            ingredients = gr.Textbox(label="List of ingredients")
            spellcheck_button = gr.Button(value='Run spellcheck')
            correction = gr.Textbox(label="Correction", interactive=False)
    with gr.Row():
        # Clicking an example only pre-fills the image and text inputs.
        # No `fn` is passed on purpose: `process` takes the text alone (not
        # image + text) and no `outputs` are declared, so it could never run
        # on an example and would break if example caching were enabled.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[
                image,
                ingredients,
            ],
        )
    spellcheck_button.click(
        fn=process,
        inputs=[ingredients],
        outputs=[correction],
    )

if __name__ == "__main__":
    # Launch the demo
    demo.launch()