---
license: apache-2.0
tags:
- medical
- radiology
- mammography
- contrastive-learning
- embeddings
- computer-vision
- pytorch
datasets:
- CMMD
pipeline_tag: feature-extraction
---
# PRIMER: Pretrained RadImageNet for Mammography Embedding Representations
PRIMER is a deep learning model for mammography analysis: a RadImageNet-pretrained encoder finetuned with contrastive learning on the CMMD (Chinese Mammography Mass Database) dataset. The model generates discriminative embedding vectors optimized specifically for mammogram images.
## Model Overview
- **Base Model**: RadImageNet ResNet-50
- **Training Method**: SimCLR contrastive learning (NT-Xent loss)
- **Architecture**: ResNet-50 encoder + 2-layer MLP projection head
- **Input**: 224×224 RGB images (converted from DICOM grayscale)
- **Output**: 2048-dimensional embedding vectors
- **Training Dataset**: CMMD mammography DICOM files
- **Framework**: PyTorch 2.1+
## Key Features
- Finetuned specifically for mammography imaging
- Self-supervised contrastive learning (no labels required)
- Produces embeddings with better clustering and separation than baseline RadImageNet
- Handles DICOM preprocessing pipeline end-to-end
- Supports multiple backbone architectures (ResNet-50, DenseNet-121, Inception-V3)
## DICOM Preprocessing Pipeline
The model expects mammography DICOM images preprocessed through the following pipeline. This preprocessing is **critical** for proper model performance:
### Step 1: DICOM Loading
```
- Read DICOM file using pydicom
- Extract pixel array as float32
```
### Step 2: Photometric Interpretation Correction
```
- Check PhotometricInterpretation attribute
- If MONOCHROME1: Invert pixel values (max_value - pixel_value)
  - MONOCHROME1: Higher values = darker (inverted scale)
  - MONOCHROME2: Higher values = brighter (standard scale)
```
### Step 3: Intensity Normalization
```
- Percentile-based clipping to remove outliers:
  - Compute 2nd percentile (p2) and 98th percentile (p98)
  - Clip all values: pixel_value = clip(pixel_value, p2, p98)
- Min-max normalization to [0, 255]:
  - normalized = ((pixel_value - min) / (max - min + 1e-8)) × 255
- Convert to uint8
```
### Step 4: CLAHE Enhancement
```
- Apply Contrast Limited Adaptive Histogram Equalization (CLAHE)
- Clip limit: 2.0
- Tile grid size: 8×8
- Improves local contrast and enhances subtle features
```
### Step 5: Grayscale to RGB Conversion
```
- Duplicate grayscale channel 3 times: RGB = [gray, gray, gray]
- Required because RadImageNet expects 3-channel input
```
### Step 6: Resizing
```
- Resize to 224×224 using bilinear interpolation
```
### Step 7: Data Augmentation (Training Only)
```
Training augmentations:
- Horizontal flip (p=0.5)
- Vertical flip (p=0.3)
- Rotation (±15 degrees, p=0.5)
- Random brightness/contrast (±0.2, p=0.5)
- Shift/scale/rotate (shift=0.1, scale=0.1, rotate=15°, p=0.5)
```
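The augmentations listed above map naturally onto `albumentations` (pinned in the requirements below). The following is a minimal sketch of such a pipeline, not the exact training configuration:
```python
import albumentations as A

# Sketch of the training-time augmentations listed above (Step 7 only;
# normalization and tensor conversion follow in Step 8).
train_augment = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.Rotate(limit=15, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
])

# Usage: augmented = train_augment(image=rgb_uint8_image)["image"]
```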
### Step 8: Normalization
```
- ImageNet normalization (required for RadImageNet compatibility):
  - Mean: [0.485, 0.456, 0.406]
  - Std: [0.229, 0.224, 0.225]
- Convert to tensor (C×H×W format)
```
### Complete Preprocessing Code
```python
import cv2
import numpy as np
import pydicom

class DICOMProcessor:
    def __init__(self):
        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    def preprocess(self, dicom_path):
        # 1. Load DICOM
        dicom = pydicom.dcmread(dicom_path)
        image = dicom.pixel_array.astype(np.float32)

        # 2. Handle photometric interpretation (invert MONOCHROME1)
        if hasattr(dicom, 'PhotometricInterpretation'):
            if dicom.PhotometricInterpretation == "MONOCHROME1":
                image = np.max(image) - image

        # 3. Intensity normalization (percentile clip + min-max to [0, 255])
        p2, p98 = np.percentile(image, (2, 98))
        image = np.clip(image, p2, p98)
        image = ((image - image.min()) / (image.max() - image.min() + 1e-8) * 255)
        image = image.astype(np.uint8)

        # 4. CLAHE enhancement
        image = self.clahe.apply(image)

        # 5. Grayscale to RGB
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        # 6. Resize
        image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)

        # 7. ImageNet normalization (float32 so torch.from_numpy yields a float tensor)
        image = image.astype(np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        image = (image - mean) / std

        # 8. Convert to tensor layout (C, H, W)
        image = np.transpose(image, (2, 0, 1))
        return image
```
## Model Architecture
### Overall Structure
```
Input DICOM (H×W grayscale)
↓
[DICOM Preprocessing Pipeline]
↓
224×224×3 RGB Tensor
↓
[RadImageNet ResNet-50 Encoder]
↓
2048-dim Embeddings
↓
[Projection Head] (training only)
↓
128-dim Projections
```
### Components
**1. Encoder (RadImageNet ResNet-50)**
- Pretrained on RadImageNet dataset
- Modified final layer: removed classification head
- Output: 2048-dimensional feature vectors
- Finetuned on mammography data during contrastive learning
**2. Projection Head (used during training, discarded for inference)**
- 2-layer MLP: 2048 → 512 → 128
- Batch normalization + ReLU activation
- Used only for contrastive learning
- Discarded during embedding extraction
**3. Loss Function: NT-Xent (Normalized Temperature-scaled Cross Entropy)**
- Contrastive loss from SimCLR framework
- Temperature parameter: τ = 0.07
- Cosine similarity with L2 normalization
- Positive pairs: Two augmented views of same image
- Negative pairs: All other images in batch
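For reference, the projection head and loss described above can be sketched in PyTorch as follows. This is a minimal sketch consistent with the stated dimensions and temperature, not the exact training code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    """2-layer MLP (2048 -> 512 -> 128) with BatchNorm + ReLU; training only."""
    def __init__(self, in_dim=2048, hidden_dim=512, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)

def nt_xent_loss(z1, z2, temperature=0.07):
    """NT-Xent loss over a batch of paired projections (SimCLR)."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # L2-normalize, shape (2N, D)
    sim = z @ z.t() / temperature                       # scaled cosine similarities
    sim.fill_diagonal_(float('-inf'))                   # exclude self-similarity
    # The positive for sample i is its augmented twin at index (i + N) mod 2N
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```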
### Training Details
**Contrastive Learning Framework (SimCLR)**
```
For each mammogram:
1. Create two different augmented views (image1, image2)
2. Pass both through encoder → projection head
3. Compute NT-Xent loss between the two projections
4. Maximize agreement between views of same image
5. Minimize similarity with other images in batch
```
**Hyperparameters**
- Batch size: 128
- Epochs: 50
- Learning rate: 1e-4 (AdamW optimizer)
- Weight decay: 1e-5
- Temperature: 0.07
- LR scheduler: Cosine annealing with 10-epoch warmup
- Mixed precision training: Enabled (AMP)
- Gradient clipping: 1.0
- Early stopping patience: 15 epochs
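Under these hyperparameters, one epoch of training looks roughly like the sketch below. It assumes an `encoder` and `proj_head` on a CUDA device, a `loader` yielding two augmented views per mammogram, and the `nt_xent_loss` above; the 10-epoch warmup and early stopping are omitted for brevity:
```python
import torch

params = list(encoder.parameters()) + list(proj_head.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
scaler = torch.cuda.amp.GradScaler()  # mixed precision (AMP)

for epoch in range(50):
    for view1, view2 in loader:  # two augmented views of each image
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            z1 = proj_head(encoder(view1.cuda()))
            z2 = proj_head(encoder(view2.cuda()))
            loss = nt_xent_loss(z1, z2, temperature=0.07)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before clipping gradients
        torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
    scheduler.step()  # per-epoch cosine annealing (warmup omitted)
```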
**Training Data**
- Dataset: CMMD (Chinese Mammography Mass Database)
- Training split: 70%
- Validation split: 15%
- Test split: 15%
- Total training images: ~13,000 mammograms
## Model Specifications
| Property | Value |
|----------|-------|
| Model Type | Feature Extraction / Embedding Model |
| Architecture | ResNet-50 (RadImageNet pretrained) |
| Input Shape | (3, 224, 224) |
| Output Shape | (2048,) |
| Parameters | ~23.5M trainable |
| Model Size | 283 MB |
| Precision | FP32 |
| Framework | PyTorch 2.1+ |
## Usage
### Loading the Model
```python
import torch
import torch.nn as nn
import timm

# Define the encoder architecture
class RadImageNetEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # num_classes=0 drops the classification head, leaving pooled features
        self.encoder = timm.create_model('resnet50', pretrained=False, num_classes=0)
        self.feature_dim = 2048

    def forward(self, x):
        return self.encoder(x)

# Load the checkpoint
checkpoint = torch.load('pytorch_model.bin', map_location='cpu')

# Extract the encoder weights (strip the 'encoder.encoder.' prefix)
model = RadImageNetEncoder()
encoder_state_dict = {
    k.replace('encoder.encoder.', ''): v
    for k, v in checkpoint['model_state_dict'].items()
    if k.startswith('encoder.encoder.')
}
model.encoder.load_state_dict(encoder_state_dict)
model.eval()
```
### Extracting Embeddings
```python
# Preprocess DICOM (see preprocessing code above)
processor = DICOMProcessor()
image = processor.preprocess('path/to/mammogram.dcm')
# Convert to tensor and add batch dimension
image_tensor = torch.from_numpy(image).unsqueeze(0) # Shape: (1, 3, 224, 224)
# Extract embeddings
with torch.no_grad():
    embeddings = model(image_tensor)  # Shape: (1, 2048)
# L2 normalize (recommended)
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
```
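Once L2-normalized, embeddings can be compared with a plain dot product, e.g. for similarity search. A minimal sketch, assuming `gallery` is a tensor of precomputed, L2-normalized embeddings with shape `(N, 2048)`:
```python
# Cosine similarity reduces to a dot product on L2-normalized vectors
scores = embeddings @ gallery.t()              # Shape: (1, N)
top_scores, top_idx = scores.topk(k=5, dim=1)  # indices of the 5 most similar mammograms
```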
## Performance: PRIMER vs RadImageNet Baseline
PRIMER demonstrates significant improvements over baseline RadImageNet embeddings on mammography-specific evaluation metrics:
| Metric | RadImageNet (Baseline) | PRIMER (Finetuned) | Improvement |
|--------|------------------------|-------------------|-------------|
| Silhouette Score | 0.127 | 0.289 | +127% |
| Davies-Bouldin Score | 2.847 | 1.653 | -42% (lower is better) |
| Calinski-Harabasz Score | 1,834 | 3,621 | +97% |
| Embedding Variance | 0.012 | 0.024 | +100% |
| Intra-cluster Distance | 1.92 | 1.34 | -30% |
| Inter-cluster Distance | 2.15 | 2.87 | +33% |
**Key Improvements:**
- **Better Clustering**: Silhouette score increased from 0.127 to 0.289, indicating much tighter and more separated clusters
- **Enhanced Discrimination**: Davies-Bouldin score decreased by 42%, showing better cluster separation
- **Richer Representations**: Embedding variance doubled, indicating more diverse and informative features
- **Mammography-Specific**: Features learned are specialized for mammographic patterns (masses, calcifications, tissue density)
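These clustering metrics follow their standard scikit-learn definitions and can be computed along these lines (a sketch; `embeddings` is an `(N, 2048)` array and `labels` holds cluster or class assignments, e.g. from k-means):
```python
from sklearn.metrics import (calinski_harabasz_score, davies_bouldin_score,
                             silhouette_score)

sil = silhouette_score(embeddings, labels)         # higher is better
dbi = davies_bouldin_score(embeddings, labels)     # lower is better
chi = calinski_harabasz_score(embeddings, labels)  # higher is better
```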
### Visualization Improvements
Dimensionality reduction visualizations (t-SNE, UMAP, PCA) show:
- PRIMER embeddings form distinct, well-separated clusters
- RadImageNet embeddings show more overlap and diffuse boundaries
- PRIMER captures mammography-specific visual patterns more effectively
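A 2-D t-SNE projection of this kind can be produced with scikit-learn; matplotlib (assumed here, not in the pinned requirements) handles the plotting:
```python
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# embeddings: (N, 2048) array; labels: per-image cluster or class assignments
coords = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(embeddings)
plt.scatter(coords[:, 0], coords[:, 1], c=labels, s=4)
plt.title('PRIMER embeddings (t-SNE)')
plt.show()
```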
## Requirements
```
torch>=2.1.0
torchvision>=0.16.0
pydicom>=2.4.4
opencv-python>=4.8.1.78
numpy>=1.26.0
timm>=0.9.12
albumentations>=1.3.1
scikit-learn>=1.3.2
```
## Dataset
**CMMD (Chinese Mammography Mass Database)**
- Modality: Full-field digital mammography (FFDM)
- Format: DICOM files
- Views: CC (craniocaudal), MLO (mediolateral oblique)
- Resolution: Variable (typically 2048×3328 or similar)
## Limitations
1. **Domain Specificity**: Model is trained on CMMD dataset (Chinese population). Performance may vary on other populations or imaging protocols.
2. **DICOM Format**: Requires proper DICOM preprocessing. Standard images (PNG/JPG) should be run through the same enhancement, resizing, and normalization steps (skipping the DICOM-specific loading and photometric correction) for best results.
3. **Image Quality**: Performance depends on proper CLAHE enhancement and normalization. Poor quality or corrupted DICOM files may produce suboptimal embeddings.
4. **Resolution**: Model expects 224×224 input. Very high-resolution details may be lost during resizing.
5. **Self-Supervised**: Model uses contrastive learning without labels. Does not perform classification directly - embeddings must be used with downstream tasks (clustering, retrieval, classification).
6. **Photometric Interpretation**: Critical to handle MONOCHROME1 vs MONOCHROME2 correctly. Failure to invert MONOCHROME1 images will result in poor embeddings.
## Intended Use
### Primary Use Cases
- **Feature Extraction**: Generate embeddings for mammography images
- **Similarity Search**: Find similar mammograms based on visual features
- **Clustering**: Group mammograms by visual characteristics (see the sketch after this list)
- **Transfer Learning**: Use as pretrained backbone for downstream tasks (classification, segmentation)
- **Retrieval Systems**: Content-based mammography image retrieval
- **Quality Control**: Identify outlier or anomalous mammograms
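The clustering use case, for instance, amounts to a k-means over extracted embeddings. A minimal sketch, where `embeddings` is an `(N, 2048)` array and the cluster count is purely illustrative:
```python
from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=5, n_init=10, random_state=0)  # cluster count is illustrative
cluster_ids = kmeans.fit_predict(embeddings)  # one cluster id per mammogram
```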
### Out-of-Scope Use Cases
- **Direct Diagnosis**: Model does not provide diagnostic predictions
- **Standalone Clinical Use**: Requires integration with clinical workflows and expert interpretation
- **Non-Mammography Images**: Optimized for mammography; may not generalize to other modalities
- **Real-time Processing**: Model size (283MB) and preprocessing may not be suitable for real-time applications without optimization
## Model Card Contact
For questions or issues, please open an issue on the [GitHub repository](https://github.com/Lab-Rasool/PRIMER) or contact via HuggingFace discussions.