# Local_OCR_Demo / test_medgemma.py
# Initial commit: DeepSeek-OCR-2 & MedGemma-1.5 multimodal analysis app with ZeroGPU support
# (commit b752d16, by DocUA)
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch
from PIL import Image
import os
model_id = "google/medgemma-1.5-4b-it"


def test_medgemma(image_path="sample_test.png",
                  prompt="Extract all text from this image.",
                  max_new_tokens=100):
    """Smoke-test the MedGemma image-text-to-text model on a local image.

    Loads the processor and model for ``model_id``, builds a chat-template
    prompt around *image_path*, generates a reply, and prints it.

    Args:
        image_path: Path to the test image; the test is skipped if missing.
        prompt: User instruction sent alongside the image.
        max_new_tokens: Generation budget for the model's reply.

    Returns:
        The decoded model output (str), or ``None`` when the image is
        missing or loading/inference raised.
    """
    print(f"Loading {model_id}...")
    try:
        processor = AutoProcessor.from_pretrained(model_id)
        # We try to load without device_map="auto" for MPS or manual device
        # control; CPU is usually stable with float32.
        model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        ).eval()
        print("Model loaded.")

        if not os.path.exists(image_path):
            print("No test image found.")
            return None
        image = Image.open(image_path).convert("RGB")

        # Use the chat template as suggested — it handles the image+text
        # interleaving and special tokens for us.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )

        print("Running inference...")
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Decode only the newly generated tokens (drop the echoed prompt).
        input_len = inputs["input_ids"].shape[-1]
        result = processor.decode(output[0][input_len:], skip_special_tokens=True)
        print("\n--- MedGemma Result ---")
        print(result)
        print("-----------------------")
        return result
    except Exception as e:
        # Broad catch is deliberate in a smoke test: report and keep going
        # rather than crash the caller.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return None
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    test_medgemma()