from transformers import AutoProcessor, AutoModelForImageTextToText
import torch
from PIL import Image
import os

model_id = "google/medgemma-1.5-4b-it"


def test_medgemma():
    """Smoke-test the MedGemma image-text-to-text model on a local image.

    Loads the processor and model, reads ``sample_test.png`` from the current
    working directory, asks the model to extract all text from the image, and
    prints the decoded result. All failures are caught and their traceback
    printed so the script never raises to the caller.
    """
    print(f"Loading {model_id}...")
    try:
        processor = AutoProcessor.from_pretrained(model_id)
        # We try to load without device_map="auto" for MPS or manual device
        # control; float32 is the stable dtype on CPU.
        model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=torch.float32,  # CPU usually stable with float32
            trust_remote_code=True,
        ).eval()
        print("Model loaded.")

        image_path = "sample_test.png"
        if not os.path.exists(image_path):
            print("No test image found.")
            return

        # Context manager closes the underlying file handle promptly;
        # convert("RGB") forces a full load of the pixel data first.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # Use chat template as suggested to build multimodal model inputs.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Extract all text from this image."},
                ],
            }
        ]
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        # Keep the tensors on the same device as the model: a no-op on CPU,
        # but required once the model is moved to MPS/CUDA (the stated goal
        # of skipping device_map="auto" above).
        inputs = inputs.to(model.device)

        print("Running inference...")
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=100)

        # Slice off the prompt tokens so only newly generated text is decoded.
        input_len = inputs["input_ids"].shape[-1]
        result = processor.decode(output[0][input_len:], skip_special_tokens=True)

        print("\n--- MedGemma Result ---")
        print(result)
        print("-----------------------")
    except Exception as e:
        # Broad catch is deliberate in this standalone smoke test: print the
        # full traceback instead of crashing.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    test_medgemma()