---
license: apache-2.0
base_model: google/medgemma-4b-it
tags:
- vision-language
- medical-imaging
- radiology
- medgemma
- gemma
- flare2025
- peft
- lora
- multimodal
- medical-ai
datasets:
- FLARE-MedFM/FLARE-Task5-MLLM-2D
pipeline_tag: image-text-to-text
library_name: transformers
---

# MedGemma Fine-tuned for FLARE 2025 Medical Image Analysis

This model is a fine-tuned version of [google/medgemma-4b-it](https://huggingface.co/google/medgemma-4b-it) optimized for medical image analysis tasks in the FLARE 2025 2D Medical Multimodal Dataset challenge.

## Model Description

- **Base Model**: MedGemma-4B-IT (Google's medical-specialized Gemma model)
- **Fine-tuning Method**: QLoRA (quantized low-rank adaptation)
- **Target Domain**: Medical imaging across 7 modalities (CT, MRI, X-ray, Ultrasound, Fundus, Pathology, Endoscopy)
- **Tasks**: Medical image captioning, visual question answering, report generation, diagnosis support
- **Training Data**: 19 FLARE 2025 datasets with comprehensive medical annotations

## Training Details

### Training Data

The model was fine-tuned on 19 diverse medical imaging datasets from FLARE 2025, covering:

- **Classification**: Disease diagnosis with balanced-accuracy optimization
- **Multi-label Classification**: Multi-pathology identification
- **Detection**: Anatomical structure and pathology detection
- **Instance Detection**: Identity-aware detection (e.g., chromosome analysis)
- **Counting**: Cell counting and quantitative analysis
- **Regression**: Continuous medical measurements
- **Report Generation**: Comprehensive medical report writing

Details are available at: https://huggingface.co/datasets/FLARE-MedFM/FLARE-Task5-MLLM-2D

### Training Configuration

```yaml
# LoRA Configuration
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
target_modules: ['gate_proj', 'up_proj', 'o_proj', 'down_proj', 'v_proj', 'q_proj', 'k_proj']
task_type: CAUSAL_LM
bias: none
```

### Training Procedure

- **Base Architecture**: MedGemma-4B with medical domain pre-training
- **Optimization**: 4-bit quantization with BitsAndBytesConfig
- **LoRA Configuration**: r, alpha, dropout, and target modules as listed in the configuration above (all attention and MLP projection layers)
- **Memory Optimization**: Gradient checkpointing, flash attention
- **Batch Size**: Dynamic, based on image resolution and GPU memory
- **Learning Rate**: 1e-4 with cosine scheduling
- **Training Steps**: 4,000 steps with evaluation every 500 steps
- **Chat Template**: Gemma-style chat formatting for medical conversations
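The training script is not part of this repository, but the settings above map onto `transformers` and `peft` roughly as follows. This is a minimal sketch rather than the exact training code: the NF4 quantization type, double quantization, and the `prepare_model_for_kbit_training` call are assumptions beyond what is stated above.

```python
import torch
from transformers import AutoModelForImageTextToText, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit quantization (QLoRA-style); NF4 and double quantization are assumptions
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForImageTextToText.from_pretrained(
    "google/medgemma-4b-it",
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="eager",
)

# Prepares the quantized model for training (enables gradient checkpointing by default)
model = prepare_model_for_kbit_training(model)

# LoRA adapters mirroring the YAML configuration above
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```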
## Model Performance

This model has been evaluated across multiple medical imaging tasks using the FLARE 2025 evaluation metrics.

### Evaluation Metrics by Task Type

**Classification Tasks (Disease Diagnosis):**
- **Balanced Accuracy** (PRIMARY): Handles class imbalance in medical diagnosis
- **Accuracy**: Standard classification accuracy
- **F1 Score**: Weighted F1 for multi-class scenarios

**Multi-label Classification (Multi-pathology):**
- **F1 Score** (PRIMARY): Sample-wise F1 across multiple medical conditions
- **Precision**: Label prediction precision
- **Recall**: Medical condition coverage recall

**Detection Tasks (Anatomical/Pathological):**
- **F1 Score @ IoU > 0.5** (PRIMARY): Standard computer vision detection metric (see the sketch below)
- **Precision**: Detection precision at the IoU threshold
- **Recall**: Detection recall at the IoU threshold

**Instance Detection (Identity-Aware Detection):**
- **F1 Score @ IoU > 0.3** (PRIMARY): Medical imaging standard (e.g., chromosome detection)
- **F1 Score @ IoU > 0.5**: Computer vision standard
- **Average F1**: COCO-style average across IoU thresholds (0.3-0.7)
- **Per-instance metrics**: Detailed breakdown by object identity
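For reference, the detection-style F1 metrics above can be illustrated with a small, self-contained sketch that greedily matches predicted and ground-truth boxes at a fixed IoU threshold. This is an illustration of the metric family, not the official FLARE evaluation code.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)

def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def detection_f1(preds: List[Box], gts: List[Box], iou_thr: float = 0.5) -> float:
    """Greedy one-to-one matching: each ground-truth box matches at most one prediction."""
    matched = set()
    tp = 0
    for p in preds:
        best_j, best_iou = -1, iou_thr
        for j, g in enumerate(gts):
            if j in matched:
                continue
            score = iou(p, g)
            if score >= best_iou:
                best_j, best_iou = j, score
        if best_j >= 0:
            matched.add(best_j)
            tp += 1
    fp = len(preds) - tp
    fn = len(gts) - tp
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

# Example: one correct detection, one missed ground-truth box, at IoU >= 0.5
print(detection_f1([(10, 10, 50, 50)], [(12, 12, 48, 52), (100, 100, 140, 140)], 0.5))
```

The COCO-style Average F1 used for instance detection is then simply the mean of `detection_f1` evaluated at several thresholds (e.g., 0.3-0.7).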
**Counting Tasks (Cell/Structure Counting):**
- **Mean Absolute Error** (PRIMARY): Cell counting accuracy
- **Root Mean Squared Error**: Additional counting precision metric

**Regression Tasks (Medical Measurements):**
- **Mean Absolute Error** (PRIMARY): Continuous value prediction accuracy
- **Root Mean Squared Error**: Regression precision metric

**Report Generation (Medical Reports):**
- **GREEN Score** (PRIMARY): Comprehensive medical report evaluation with 7 weighted components (combined as in the sketch below):
  - Entity matching with severity assessment (30%)
  - Location accuracy with laterality (20%)
  - Negation and uncertainty handling (15%)
  - Temporal accuracy (10%)
  - Size/measurement accuracy (10%)
  - Clinical significance weighting (10%)
  - Report structure completeness (5%)
- **BLEU Score**: Text generation quality
- **Clinical Efficacy**: Medical relevance scoring
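To make the weighting concrete, the sketch below shows how seven per-component scores could be combined into a single GREEN-style score. The component scorers themselves (entity matching, location accuracy, and so on) are the substantive part of the metric and are not reproduced here; the dictionary keys and example values are hypothetical.

```python
# Hypothetical per-component scores in [0, 1]; the real scorers compare a
# generated report against a reference report (entities, locations, negation, ...).
GREEN_WEIGHTS = {
    "entity_matching": 0.30,       # entity matching with severity assessment
    "location_accuracy": 0.20,     # location accuracy with laterality
    "negation_uncertainty": 0.15,  # negation and uncertainty handling
    "temporal_accuracy": 0.10,
    "size_measurement": 0.10,
    "clinical_significance": 0.10,
    "report_structure": 0.05,
}

def green_score(component_scores: dict) -> float:
    """Weighted sum of per-component scores; the weights sum to 1.0."""
    return sum(GREEN_WEIGHTS[k] * component_scores.get(k, 0.0) for k in GREEN_WEIGHTS)

# Example with made-up component scores
print(green_score({
    "entity_matching": 0.8, "location_accuracy": 0.7, "negation_uncertainty": 1.0,
    "temporal_accuracy": 0.5, "size_measurement": 0.6,
    "clinical_significance": 0.9, "report_structure": 1.0,
}))
```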
## Usage

### Installation

```bash
pip install transformers torch peft accelerate bitsandbytes
```

### Basic Usage

```python
import torch
from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
from peft import PeftModel
from PIL import Image

# Load the fine-tuned model
base_model_name = "google/medgemma-4b-it"
adapter_model_name = "leoyinn/flare25-medgemma"

# Load tokenizer and processor
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(base_model_name, trust_remote_code=True)

# Load base model
base_model = AutoModelForImageTextToText.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="eager"
)

# Load the fine-tuned adapter
model = PeftModel.from_pretrained(base_model, adapter_model_name)

# Prepare input with MedGemma chat format
image = Image.open("medical_image.jpg").convert("RGB")
image = image.resize((448, 448))  # MedGemma standard size

# Create proper message format
messages = [
    {
        "role": "system",
        "content": [{
            "type": "text",
            "text": "You are an expert medical AI assistant specialized in analyzing medical images and providing accurate diagnostic insights."
        }]
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe the medical findings in this image and provide a diagnostic assessment."}
        ]
    }
]

# Apply chat template
full_text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# Process and generate
inputs = processor(
    images=[image],
    text=full_text,
    return_tensors="pt",
    padding=True,
    truncation=False
).to(model.device, dtype=torch.bfloat16)

# Generate medical response
with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=False,  # Deterministic for medical applications
        use_cache=True,
        cache_implementation="dynamic"
    )

# Decode response
input_len = inputs["input_ids"].shape[-1]
response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
print(response)
```

## Limitations and Ethical Considerations

### Limitations

- Model outputs may contain inaccuracies and should be verified by medical professionals
- Performance may vary across different medical imaging modalities and populations
- Training data may contain biases present in medical literature and datasets
- The model has not been validated in clinical settings
- Designed for research and educational purposes, not clinical decision-making

### Intended Use

- Medical education and training
- Research in medical AI and computer vision
- Development of clinical decision support tools (with proper validation)
- Academic research in multimodal medical AI
- Medical image analysis prototyping

### Out-of-Scope Use

- Direct clinical diagnosis without physician oversight
- Treatment recommendations without medical professional validation
- Use in emergency medical situations
- Deployment in production clinical systems without extensive validation
- Patient-facing applications without proper medical supervision

## Citation

If you use this model in your research, please cite:

```bibtex
@misc{medgemma-flare2025,
  title={MedGemma Fine-tuned for FLARE 2025 Medical Image Analysis},
  author={Shuolin Yin},
  year={2025},
  publisher={Hugging Face},
  url={https://huggingface.co/leoyinn/flare25-medgemma}
}

@misc{medgemma-base,
  title={MedGemma: Medical Gemma Models for Healthcare},
  author={Google Research},
  year={2024},
  publisher={Hugging Face},
  url={https://huggingface.co/google/medgemma-4b-it}
}

@misc{flare2025,
  title={FLARE 2025: A Multi-Modal Foundation Model Challenge for Medical AI},
  year={2025},
  url={https://huggingface.co/datasets/FLARE-MedFM/FLARE-Task5-MLLM-2D}
}
```

## Model Details

- **Model Type**: Vision-Language Model (VLM) specialized for medical applications
- **Architecture**: MedGemma (Gemma-based) with LoRA adapters
- **Parameters**: ~4B base parameters + LoRA adapters
- **Precision**: bfloat16 base model + full-precision adapters
- **Framework**: PyTorch, Transformers, PEFT
- **Input Resolution**: 448x448 pixels (standard for MedGemma)
- **Context Length**: Supports long medical reports and conversations

## Technical Specifications

- **Base Model**: google/medgemma-4b-it
- **Adapter Type**: LoRA (Low-Rank Adaptation)
- **Target Modules**: All attention projection layers and MLP layers
- **Chat Template**: Gemma-style with medical system prompts
- **Attention Implementation**: Eager attention for stability
- **Cache Implementation**: Dynamic caching for efficient inference

## Contact

For questions or issues, please open an issue in the model repository or contact the authors.

---

**Disclaimer**: This model is for research and educational purposes only. Always consult qualified medical professionals for clinical decisions.