multimodalart's picture
Update app.py
f3b481c verified
raw
history blame
4.51 kB
import gc
import os
import io
import time
import tempfile
import logging
import spaces
import torch
import gradio as gr
from transformers import Mistral3ForConditionalGeneration, AutoProcessor
from mistral_text_encoding_core import encode_prompt
# ------------------------------------------------------
# Logging
# ------------------------------------------------------
# LOG_LEVEL is upper-cased because logging only accepts level names such as
# "INFO"/"DEBUG"; a lowercase value like "info" would make basicConfig raise
# ValueError at import time.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("mistral-text-encoding-gradio")
# ------------------------------------------------------
# Config
# ------------------------------------------------------
# Both model sources may be overridden through environment variables.
_DEFAULT_TOKENIZER = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
TEXT_ENCODER_ID = os.environ.get("TEXT_ENCODER_ID", "/repository")
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", _DEFAULT_TOKENIZER)
# Precision used when loading the text encoder weights.
DTYPE = torch.bfloat16
# ------------------------------------------------------
# Global model references
# ------------------------------------------------------
# The text encoder and tokenizer are loaded once, when the module is imported,
# and reused by every request handled by encode_text.
logger.info("Loading models...")
t0 = time.time()
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    TEXT_ENCODER_ID,
    dtype=DTYPE,
).to("cuda")
logger.info(
    "Loaded Mistral text encoder (%.2fs) dtype=%s device=%s",
    time.time() - t0,
    text_encoder.dtype,
    # Report the device the weights were actually moved to (no separate
    # device-map setting exists in this module).
    text_encoder.device,
)
t1 = time.time()
tokenizer = AutoProcessor.from_pretrained(TOKENIZER_ID)
logger.info("Loaded tokenizer in %.2fs", time.time() - t1)
# Inference only: disable autograd globally so no graphs are recorded.
torch.set_grad_enabled(False)
def get_vram_info():
    """Return a snapshot of CUDA memory usage.

    When CUDA is available the dict has the keys ``vram_allocated_mb``,
    ``vram_reserved_mb`` and ``vram_max_allocated_mb``. Each value is in
    megabytes, rounded to two decimals. Otherwise the function returns
    ``{"vram": "CUDA not available"}``.
    """
    if not torch.cuda.is_available():
        return {"vram": "CUDA not available"}

    def to_mb(num_bytes):
        # Bytes -> megabytes (1 MB = 1024 * 1024 bytes), two decimal places.
        return round(num_bytes / 1024 / 1024, 2)

    return {
        "vram_allocated_mb": to_mb(torch.cuda.memory_allocated()),
        "vram_reserved_mb": to_mb(torch.cuda.memory_reserved()),
        "vram_max_allocated_mb": to_mb(torch.cuda.max_memory_allocated()),
    }
@spaces.GPU()
def encode_text(prompt: str):
    """Encode one or more prompts and save the embeddings to a ``.pt`` file.

    Each non-blank line of *prompt* is treated as a separate prompt. A single
    prompt is passed to ``encode_prompt`` as a plain string; several prompts
    are passed as a list.

    Args:
        prompt: Text from the UI textbox, one prompt per line. ``None`` is
            treated like an empty string.

    Returns:
        A ``(file_path, status)`` tuple matching the Gradio ``File`` and
        status ``Textbox`` outputs. ``file_path`` is ``None`` when nothing was
        encoded (model not loaded, or no usable prompt). Otherwise it is the
        path of a temporary file holding ``prompt_embeds`` moved to CPU.
    """
    # text_encoder / tokenizer are module-level globals that are only read
    # here, so no `global` declaration is needed.
    if text_encoder is None or tokenizer is None:
        return None, "Model not loaded"
    t0 = time.time()

    # Handle multiple prompts (one per line); blank lines are ignored.
    prompts = [line.strip() for line in (prompt or "").split("\n") if line.strip()]
    if not prompts:
        return None, "Please enter at least one prompt"
    num_prompts = len(prompts)
    prompt_input = prompts[0] if num_prompts == 1 else prompts

    logger.info("Encoding %d prompt(s)", num_prompts)
    prompt_embeds, text_ids = encode_prompt(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        prompt=prompt_input,
    )
    duration = (time.time() - t0) * 1000.0
    logger.info(
        "Encoded in %.2f ms | prompt_embeds.shape=%s | text_ids.shape=%s",
        duration,
        tuple(prompt_embeds.shape),
        tuple(text_ids.shape),
    )

    # Save the tensor to a temporary file. delete=False keeps the file on disk
    # so Gradio can serve it. The `with` closes the handle, which would
    # otherwise leak one file descriptor per request.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as temp_file:
        torch.save(prompt_embeds.cpu(), temp_file)
        output_path = temp_file.name

    # Clean up GPU tensors
    del prompt_embeds, text_ids
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    vram = get_vram_info()
    status = (
        f"Encoded {num_prompts} prompt(s) in {duration:.2f}ms\n"
        f"VRAM: {vram.get('vram_allocated_mb', 'N/A')} MB allocated, "
        f"{vram.get('vram_max_allocated_mb', 'N/A')} MB peak"
    )
    return output_path, status
# ------------------------------------------------------
# Gradio Interface
# ------------------------------------------------------
# Two-column UI. The left column holds the prompt textbox and the Encode
# button. The right column holds the downloadable embeddings file and a
# read-only status box.
with gr.Blocks(title="Mistral Text Encoder") as demo:
    gr.Markdown("# Mistral Text Encoder")
    gr.Markdown("Enter text to encode. For multiple prompts, put each on a new line.")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt(s)",
                placeholder="Enter your prompt here...\nOr multiple prompts, one per line",
                lines=5,
            )
            encode_btn = gr.Button("Encode", variant="primary")
        with gr.Column():
            output_file = gr.File(label="Download Embeddings (.pt)")
            status_output = gr.Textbox(label="Status", interactive=False)
    # encode_text returns (file_path, status), mapped in order onto the two
    # output components.
    encode_btn.click(
        fn=encode_text,
        inputs=[prompt_input],
        outputs=[output_file, status_output],
    )
if __name__ == "__main__":
    # text_encoder and tokenizer are created when this module is imported,
    # so there is no separate load step before launching the UI.
    # Bind on all interfaces at port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)