Skin Lesion Classification
Collection
Finetuned models for binary skin lesion classification using ISIC 2024 challenge and ISIC 2024 synthetic datasets
•
2 items
•
Updated
This model is a fine-tuned version of timm/vit_base_patch16_dinov3.lvd1689m for binary skin lesion classification.
Params (M): 85.6
This model is intended for dermoscopic skin lesion classification using a 224x224 image size.
This model was trained for only 1 epoch and has seen few malignant examples (due to the large class imbalance in the ISIC 2024 dataset).
class DinoSkinLesionClassifier(nn.Module, PyTorchModelHubMixin):
    """
    Skin lesion classifier: a DINOv3 ViT-B/16 backbone followed by one linear layer.

    Inheriting from PyTorchModelHubMixin adds push to Hugging Face Hub.
    See: https://huggingface.co/docs/hub/models-uploading#upload-a-pytorch-model-using-huggingfacehub

    Args:
        num_classes: Number of outputs produced by the linear head (1 by default).
        freeze_backbone: When True, gradients are disabled for every backbone
            parameter so that only the head is trained (much faster training).
    """

    def __init__(self, num_classes=1, freeze_backbone=True):
        super().__init__()

        # DINOv3 backbone created through timm, configured without a
        # classifier (num_classes=0) and with average pooling
        self.backbone = timm.create_model(
            'vit_base_patch16_dinov3',
            pretrained=True,
            num_classes=0,
            global_pool='avg',
        )

        if freeze_backbone:
            # Turns off requires_grad on all backbone parameters at once
            self.backbone.requires_grad_(False)

        # Linear head sized from the backbone's feature width
        # (should be 768 in, 1 out with the defaults)
        self.head = nn.Linear(self.backbone.num_features, num_classes)

    def forward(self, x):
        """Return the head's raw outputs (logits) for the image batch ``x``."""
        return self.head(self.backbone(x))
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model

# Fetch the fine-tuned weights file from the Hugging Face Hub
weights_path = hf_hub_download(
    repo_id="avanishd/vit-base-patch16-dinov3-finetuned-skin-lesion-classification",
    filename="model.safetensors",
)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint belongs to the DINOv3 classifier defined above, so that is
# the class that must be instantiated before loading the weights into it
model = DinoSkinLesionClassifier()
load_model(model, filename=weights_path, strict=True)

model.to(device)  # Don't forget to put on GPU
model.eval()  # Set model to evaluation mode
# Example with PH2 Dataset
class PH2Dataset(Dataset):
    """
    Dataset for PH2 images, which are in png format.

    PH2 contains skin lesion images classified as
    - Common Nevus (benign)
    - Atypical Nevus (benign)
    - Melanoma (malignant)

    No "is real" label is returned here, since this dataset is purely for testing.

    Args:
        dir_path: Directory holding the image files (.jpg, .jpeg or .png).
        metadata: Path to the CSV file read with polars; it must provide the
            ``image_name`` and ``diagnosis`` columns.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, dir_path, metadata, transform=None):
        super().__init__()
        self.dir_path = dir_path
        self.transform = transform

        # Sorted because os.listdir returns entries in arbitrary order, which
        # would make the sample order differ between runs/machines
        self.image_files = sorted(
            os.path.join(dir_path, f)
            for f in os.listdir(dir_path)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

        # Load metadata w/ polars (only 2 columns)
        self.metadata = pl.read_csv(metadata)

        # Both nevus types are benign (0); melanoma is malignant (1)
        self.diagnostic_mapping = {
            "Common Nevus": 0,
            "Atypical Nevus": 0,
            "Melanoma": 1,
        }

    def __len__(self):
        """Return the number of image files found in ``dir_path``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the image at position ``idx``.

        The label is 0 for benign and 1 for malignant, looked up from the
        metadata row whose ``image_name`` matches the file name.
        """
        image_path = self.image_files[idx]

        # The image names in the metadata csv are like IMD003.
        # os.path.basename handles whichever separator os.path.join produced,
        # unlike splitting on a hard-coded '/'.
        image_id = os.path.basename(image_path).split('.')[0]

        # Still need the entire path to open the image; the context manager
        # closes the underlying file once the RGB copy has been made
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:  # Apply transform if it exists
            image = self.transform(image)

        # .item() requires exactly one matching metadata row
        diagnosis = (
            self.metadata
            .filter(pl.col("image_name") == image_id)
            .select("diagnosis")
            .item()
        )
        label = torch.tensor(self.diagnostic_mapping[diagnosis], dtype=torch.int16)

        return image, label
# Preprocessing applied to every PH2 image before it is fed to the model.
# NOTE(review): the step order is kept exactly as published; confirm it
# matches the preprocessing used during training before reordering.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet mean and std
    transforms.Resize((224, 224)),  # Image size this model is intended for
])

ph_2_images = "/content/data/ph2_data/images"
ph_2_metadata = "/content/data/ph2_data/ph_2_dataset.csv"

ex_dataset = PH2Dataset(ph_2_images, ph_2_metadata, transform)
ex_loader = DataLoader(ex_dataset, batch_size=64, shuffle=False)

# Per-batch results are collected so predictions cover the whole dataset
# instead of only the last batch
batch_preds = []
batch_labels = []

# Gradients are not needed for inference
with torch.no_grad():
    for images, labels in ex_loader:
        images = images.to(device)

        output = model(images)

        # Sigmoid turns the single logit into a probability; threshold at 0.5
        y_pred_prob = torch.sigmoid(output).cpu().numpy().ravel()
        batch_preds.append(np.where(y_pred_prob < 0.5, 0, 1))
        batch_labels.append(labels.numpy().ravel())

# Predicted (y_pred) and ground-truth (y_true) labels for every image;
# empty arrays when the dataset contains no images
y_pred = np.concatenate(batch_preds) if batch_preds else np.empty(0, dtype=int)
y_true = np.concatenate(batch_labels) if batch_labels else np.empty(0, dtype=int)
This model was trained with the ISIC 2024 challenge and ISIC 2024 synthetic datasets.
For the ISIC 2024 Challenge data, an 80-20 train test split was applied, and the test split was used to evaluate the model.
| Training Loss | Epoch | Step |
|---|---|---|
| 0.5027 | 1 | 100 |
| 0.5672 | 1 | 200 |
| 0.5373 | 1 | 300 |
| 0.4693 | 1 | 400 |
| 5.3829 | 1 | 500 |
| 0.4872 | 1 | 600 |
| 0.4717 | 1 | 700 |
| 0.4550 | 1 | 800 |
| 0.4185 | 1 | 900 |
| 0.4142 | 1 | 1000 |
| 0.3570 | 1 | 1100 |
| 0.3877 | 1 | 1200 |
| 0.4282 | 1 | 1300 |
| 8.8676 | 1 | 1400 |
| 0.3732 | 1 | 1500 |
| 0.3522 | 1 | 1600 |
| 0.3065 | 1 | 1700 |
| 0.3732 | 1 | 1800 |
| 0.3965 | 1 | 1900 |
| 0.4727 | 1 | 2000 |
| 0.3407 | 1 | 2100 |
| 0.3421 | 1 | 2200 |
| 0.3847 | 1 | 2300 |
| 0.3911 | 1 | 2400 |
| 0.4006 | 1 | 2500 |
| 0.2836 | 1 | 2600 |
| 0.3968 | 1 | 2700 |
| 0.3796 | 1 | 2800 |
| 0.3317 | 1 | 2900 |
| 0.2762 | 1 | 3000 |
| 0.3027 | 1 | 3100 |
| 0.3002 | 1 | 3200 |
| 0.3672 | 1 | 3300 |
| 0.2660 | 1 | 3400 |
| 0.3145 | 1 | 3500 |
| 0.4098 | 1 | 3600 |
| 0.3156 | 1 | 3700 |
| 0.2762 | 1 | 3800 |
| 0.2557 | 1 | 3900 |
| 0.3204 | 1 | 4000 |
| 0.3097 | 1 | 4100 |
| 0.2790 | 1 | 4200 |
| 0.3395 | 1 | 4300 |
| 0.2888 | 1 | 4400 |
| 0.3002 | 1 | 4500 |
| 0.3388 | 1 | 4600 |
| 0.3744 | 1 | 4700 |
| 0.3143 | 1 | 4800 |
| 0.3501 | 1 | 4900 |
| 0.2923 | 1 | 5000 |
| 0.3152 | 1 | 5100 |
| 0.3380 | 1 | 5200 |
Base model
timm/vit_base_patch16_dinov3.lvd1689m