import gc
import os
import time
import tempfile
import logging

import spaces
import torch
import gradio as gr
from transformers import Mistral3ForConditionalGeneration, AutoProcessor

from mistral_text_encoding_core import encode_prompt

# ------------------------------------------------------
# Logging
# ------------------------------------------------------
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("mistral-text-encoding-gradio")

# ------------------------------------------------------
# Config
# ------------------------------------------------------
TEXT_ENCODER_ID = os.getenv("TEXT_ENCODER_ID", "/repository")
TOKENIZER_ID = os.getenv(
    "TOKENIZER_ID", "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
)
DTYPE = torch.bfloat16

# ------------------------------------------------------
# Global model references (loaded once at import time)
# ------------------------------------------------------
logger.info("Loading models...")

t0 = time.time()
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    TEXT_ENCODER_ID,
    dtype=DTYPE,
).to("cuda")
logger.info(
    "Loaded Mistral text encoder (%.2fs) dtype=%s device=%s",
    time.time() - t0,
    text_encoder.dtype,
    text_encoder.device,
)

t1 = time.time()
tokenizer = AutoProcessor.from_pretrained(TOKENIZER_ID)
logger.info("Loaded tokenizer in %.2fs", time.time() - t1)

torch.set_grad_enabled(False)


def get_vram_info():
    """Get current VRAM usage info."""
    if torch.cuda.is_available():
        return {
            "vram_allocated_mb": round(torch.cuda.memory_allocated() / 1024 / 1024, 2),
            "vram_reserved_mb": round(torch.cuda.memory_reserved() / 1024 / 1024, 2),
            "vram_max_allocated_mb": round(torch.cuda.max_memory_allocated() / 1024 / 1024, 2),
        }
    return {"vram": "CUDA not available"}


@spaces.GPU()
def encode_text(prompt: str):
    """Encode text and return a downloadable PyTorch file."""
    global text_encoder, tokenizer

    if text_encoder is None or tokenizer is None:
        return None, "Model not loaded"

    t0 = time.time()

    # Handle multiple prompts (one per line)
    prompts = [p.strip() for p in prompt.strip().split("\n") if p.strip()]
    if not prompts:
        return None, "Please enter at least one prompt"

    num_prompts = len(prompts)
    prompt_input = prompts[0] if num_prompts == 1 else prompts
    logger.info("Encoding %d prompt(s)", num_prompts)

    prompt_embeds, text_ids = encode_prompt(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        prompt=prompt_input,
    )

    duration = (time.time() - t0) * 1000.0
    logger.info(
        "Encoded in %.2f ms | prompt_embeds.shape=%s | text_ids.shape=%s",
        duration,
        tuple(prompt_embeds.shape),
        tuple(text_ids.shape),
    )

    # Save tensor to a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
    torch.save(prompt_embeds.cpu(), temp_file.name)

    # Clean up GPU tensors
    del prompt_embeds, text_ids
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    vram = get_vram_info()
    status = (
        f"Encoded {num_prompts} prompt(s) in {duration:.2f}ms\n"
        f"VRAM: {vram.get('vram_allocated_mb', 'N/A')} MB allocated, "
        f"{vram.get('vram_max_allocated_mb', 'N/A')} MB peak"
    )

    return temp_file.name, status


# ------------------------------------------------------
# Gradio Interface
# ------------------------------------------------------
with gr.Blocks(title="Mistral Text Encoder") as demo:
    gr.Markdown("# Mistral Text Encoder")
    gr.Markdown("Enter text to encode. For multiple prompts, put each on a new line.")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt(s)",
                placeholder="Enter your prompt here...\nOr multiple prompts, one per line",
                lines=5,
            )
            encode_btn = gr.Button("Encode", variant="primary")
        with gr.Column():
            output_file = gr.File(label="Download Embeddings (.pt)")
            status_output = gr.Textbox(label="Status", interactive=False)

    encode_btn.click(
        fn=encode_text,
        inputs=[prompt_input],
        outputs=[output_file, status_output],
    )


if __name__ == "__main__":
    # Models are already loaded at module import time; just launch the UI.
    demo.launch(server_name="0.0.0.0", server_port=7860)
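# ------------------------------------------------------
# Example: consuming the downloaded embeddings on the client side
# (illustrative sketch only; "embeddings.pt" is a placeholder for whatever
# name you give the downloaded temp file, and the exact shape depends on
# the encoder and the number of prompts submitted)
# ------------------------------------------------------
#
#   import torch
#   prompt_embeds = torch.load("embeddings.pt", map_location="cpu")
#   print(prompt_embeds.shape, prompt_embeds.dtype)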