import gc
import os
import io
import time
import tempfile
import logging

import spaces
import torch
import gradio as gr
from transformers import Mistral3ForConditionalGeneration, AutoProcessor

from mistral_text_encoding_core import encode_prompt

# ------------------------------------------------------
# Logging
# ------------------------------------------------------
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("mistral-text-encoding-gradio")

# ------------------------------------------------------
# Config
# ------------------------------------------------------
TEXT_ENCODER_ID = os.getenv("TEXT_ENCODER_ID", "/repository")
TOKENIZER_ID = os.getenv(
    "TOKENIZER_ID", "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
)
DTYPE = torch.bfloat16

# ------------------------------------------------------
# Global model references
# ------------------------------------------------------
logger.info("Loading models...")
t0 = time.time()
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    TEXT_ENCODER_ID,
    dtype=DTYPE,
).to("cuda")
logger.info(
    "Loaded Mistral text encoder (%.2fs) dtype=%s device=%s",
    time.time() - t0,
    text_encoder.dtype,
    text_encoder.device,
)
t1 = time.time()
tokenizer = AutoProcessor.from_pretrained(TOKENIZER_ID)
logger.info("Loaded tokenizer in %.2fs", time.time() - t1)
torch.set_grad_enabled(False)


def get_vram_info():
    """Get current VRAM usage info."""
    if torch.cuda.is_available():
        return {
            "vram_allocated_mb": round(torch.cuda.memory_allocated() / 1024 / 1024, 2),
            "vram_reserved_mb": round(torch.cuda.memory_reserved() / 1024 / 1024, 2),
            "vram_max_allocated_mb": round(torch.cuda.max_memory_allocated() / 1024 / 1024, 2),
        }
    return {"vram": "CUDA not available"}


@spaces.GPU  # run this handler on a GPU when hosted on a ZeroGPU Space (no-op otherwise)
def encode_text(prompt: str):
    """Encode text and return a downloadable PyTorch tensor file."""
    global text_encoder, tokenizer
    if text_encoder is None or tokenizer is None:
        return None, "Model not loaded"
    t0 = time.time()
    # Handle multiple prompts (one per line)
    prompts = [p.strip() for p in prompt.strip().split("\n") if p.strip()]
    if not prompts:
        return None, "Please enter at least one prompt"
    num_prompts = len(prompts)
    prompt_input = prompts[0] if num_prompts == 1 else prompts
    logger.info("Encoding %d prompt(s)", num_prompts)
    prompt_embeds, text_ids = encode_prompt(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        prompt=prompt_input,
    )
    duration = (time.time() - t0) * 1000.0
    logger.info(
        "Encoded in %.2f ms | prompt_embeds.shape=%s | text_ids.shape=%s",
        duration,
        tuple(prompt_embeds.shape),
        tuple(text_ids.shape),
    )
    # Save the embeddings tensor to a temporary file so Gradio can serve it for download
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
    torch.save(prompt_embeds.cpu(), temp_file.name)
    temp_file.close()
    # Clean up GPU tensors
    del prompt_embeds, text_ids
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    vram = get_vram_info()
    status = (
        f"Encoded {num_prompts} prompt(s) in {duration:.2f}ms\n"
        f"VRAM: {vram.get('vram_allocated_mb', 'N/A')} MB allocated, "
        f"{vram.get('vram_max_allocated_mb', 'N/A')} MB peak"
    )
    return temp_file.name, status


# ------------------------------------------------------
# Gradio Interface
# ------------------------------------------------------
with gr.Blocks(title="Mistral Text Encoder") as demo:
    gr.Markdown("# Mistral Text Encoder")
    gr.Markdown("Enter text to encode. For multiple prompts, put each on a new line.")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt(s)",
                placeholder="Enter your prompt here...\nOr multiple prompts, one per line",
                lines=5,
            )
            encode_btn = gr.Button("Encode", variant="primary")
        with gr.Column():
            output_file = gr.File(label="Download Embeddings (.pt)")
            status_output = gr.Textbox(label="Status", interactive=False)
    encode_btn.click(
        fn=encode_text,
        inputs=[prompt_input],
        outputs=[output_file, status_output],
    )


if __name__ == "__main__":
    # Models are already loaded at import time above; only the UI needs to start here
    demo.launch(server_name="0.0.0.0", server_port=7860)
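

# ------------------------------------------------------
# Example: consuming the downloaded embeddings
# ------------------------------------------------------
# A minimal sketch of how the .pt file produced by encode_text() could be
# loaded downstream. The filename "prompt_embeds.pt" is only an assumed name
# for the downloaded file; the dtype and shape depend on the encoder and the
# prompt(s) that were submitted.
#
#   import torch
#
#   prompt_embeds = torch.load("prompt_embeds.pt", map_location="cpu")
#   print(prompt_embeds.dtype, tuple(prompt_embeds.shape))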