# Swin-Small Fine-tuned on CIFAR-10
This model is a fine-tuned version of microsoft/swin-small-patch4-window7-224 on the CIFAR-10 dataset, achieving 98.01% test accuracy.
## Model Description
Swin Transformer is a hierarchical vision transformer that uses shifted windows for efficient computation. This Small variant has been fine-tuned on CIFAR-10 for image classification tasks.
- Model type: Image Classification
- Base Model: microsoft/swin-small-patch4-window7-224
- Dataset: CIFAR-10 (10 classes)
- License: Apache 2.0
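The shifted-window attention mentioned above restricts self-attention to local 7×7 windows and shifts the window grid between consecutive blocks. A minimal sketch of the window partitioning step (illustrative only, not the model's actual code; shapes assume a 224×224 input and Swin-Small's stage-1 channel width of 96):

```python
import torch

def window_partition(x, window_size=7):
    """Split a (B, H, W, C) feature map into non-overlapping windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

# Stage-1 feature map for a 224x224 input: 56x56 tokens, 96 channels.
x = torch.randn(1, 56, 56, 96)
windows = window_partition(x)  # (64, 7, 7, 96): an 8x8 grid of 7x7 windows
# Alternating blocks shift the grid by window_size // 2 before partitioning:
shifted = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
print(windows.shape)
```

Attention is then computed independently within each window, which keeps the cost linear in image size rather than quadratic.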
## Performance
| Metric | Value |
|---|---|
| Test Accuracy | 98.01% |
| Final Training Accuracy | 92.60% |
| Final Validation Accuracy | 89.61% |
## Training Details

### Training Hyperparameters
- Epochs: 50
- Optimizer: AdamW
- Initial learning rate: 1e-4
- Minimum learning rate: 1e-6
- Weight decay: 0.05
- Learning rate scheduler: CosineAnnealingLR
- Batch size:
  - Training: 128
  - Validation: 512
  - Test: 512
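The optimizer and scheduler settings above can be reproduced with a short sketch (the `torch.nn.Linear` stand-in and the per-epoch `scheduler.step()` call are assumptions; the actual training script is not included in this card):

```python
import torch

model = torch.nn.Linear(10, 10)  # stand-in for the Swin-Small backbone
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)

for epoch in range(50):
    # ... train one epoch at batch size 128, then ...
    optimizer.step()
    scheduler.step()

# After 50 epochs the cosine schedule has annealed the lr down to eta_min.
print(f"final lr: {scheduler.get_last_lr()[0]:.1e}")  # 1.0e-06
```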
### Hardware
- GPU: NVIDIA RTX PRO 5000 (48GB VRAM)
- CUDA Version: 13.0
- Driver Version: 580.126.09
## Usage

### Quick Start

```python
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch

# Load model and processor
processor = AutoImageProcessor.from_pretrained("clr4takeoff/swin-small-cifar10")
model = AutoModelForImageClassification.from_pretrained("clr4takeoff/swin-small-cifar10")

# Load and preprocess the image
image = Image.open("path/to/your/image.jpg")
inputs = processor(images=image, return_tensors="pt")

# Inference
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

predicted_class = logits.argmax(-1).item()
print(f"Predicted class: {predicted_class}")
```
### CIFAR-10 Class Labels

```python
class_labels = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]
predicted_label = class_labels[predicted_class]
print(f"Predicted label: {predicted_label}")
```
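The hard-coded list mirrors how Transformers checkpoints usually store labels in their config; if the fine-tuning run populated the mapping, `model.config.id2label` can be read instead (whether this checkpoint's config includes it is an assumption). A self-contained sketch of the two conventional dictionaries:

```python
class_labels = ["airplane", "automobile", "bird", "cat", "deer",
                "dog", "frog", "horse", "ship", "truck"]

# The same mappings a Transformers config stores as `id2label` / `label2id`.
id2label = {i: name for i, name in enumerate(class_labels)}
label2id = {name: i for i, name in enumerate(class_labels)}

print(id2label[3])       # cat
print(label2id["ship"])  # 8
```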
### Batch Inference

Note that `transforms.ToTensor()` already scales pixel values to [0, 1], so passing the tensors back through the processor would rescale them a second time. Instead, apply the processor's normalization statistics in the transform and feed the tensors to the model directly:

```python
# Reuses the `model` and `processor` loaded in Quick Start.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Resize and normalize with the processor's statistics so the
# preprocessing matches what the model expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])
dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=32)

# Inference loop
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in dataloader:
        outputs = model(pixel_values=images)  # tensors are already preprocessed
        predictions = outputs.logits.argmax(-1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Accuracy: {accuracy:.2%}")
```
## Citation

If you use this model, please cite:

```bibtex
@misc{swin-small-cifar10,
  author    = {clr4takeoff},
  title     = {Swin-Small Fine-tuned on CIFAR-10},
  year      = {2026},
  publisher = {Hugging Face},
  url       = {https://huggingface.co/clr4takeoff/swin-small-cifar10}
}
```
### Base Model Citation

```bibtex
@inproceedings{liu2021swin,
  title     = {Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author    = {Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year      = {2021}
}
```