---
tags:
- image_classification
- computer_vision
license: mit
datasets:
- p2pfl/CIFAR10
language:
- en
pipeline_tag: image-classification
metrics:
- f1
---
# SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers
### Model Description
Implementation of the ***SAG-ViT*** model as proposed in the [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420) paper.
It is a novel transformer framework designed to enhance Vision Transformers (ViT) with scale-awareness and refined patch-level feature embeddings. It extracts multiscale features using EfficientNetV2, organizes patches into a graph based on spatial relationships, and refines them with a Graph Attention Network (GAT). A Transformer encoder then integrates these embeddings globally, capturing long-range dependencies for comprehensive image understanding.
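To make the "graph based on spatial relationships" step more concrete, here is a minimal sketch of one way to connect patches on a grid. It assumes 8-connectivity (each patch is linked to its horizontal, vertical and diagonal neighbours) and row-major patch indices; the function name `build_patch_graph` is hypothetical, and the graph construction used in the paper and repository may differ.

```python
from typing import Dict, List, Tuple


def build_patch_graph(rows: int, cols: int) -> Dict[int, List[int]]:
    """Connect each patch in a rows x cols grid to its spatial neighbours.

    Assumption: 8-connectivity (horizontal, vertical and diagonal neighbours).
    """
    offsets: List[Tuple[int, int]] = [
        (dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)
    ]
    graph: Dict[int, List[int]] = {}
    for r in range(rows):
        for c in range(cols):
            node = r * cols + c  # row-major patch index
            graph[node] = [
                (r + dr) * cols + (c + dc)
                for dr, dc in offsets
                if 0 <= r + dr < rows and 0 <= c + dc < cols
            ]
    return graph


graph = build_patch_graph(3, 3)
print(graph[4])  # centre patch: [0, 1, 2, 3, 5, 6, 7, 8]
print(graph[0])  # corner patch: [1, 3, 4]
```

An adjacency structure like this is what a graph attention layer would consume alongside the per-patch feature vectors: attention is computed only between patches that share an edge.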
### Model Architecture

_Image source: [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420)_
### Usage
SAG-ViT expects input as mini-batches of 3-channel RGB images of shape `(N, 3, H, W)`, where `N` is the number of images and `H` and `W` are each at least `49` pixels.
Load the images into the range `[0, 1]`, then normalize them using `mean = [0.485, 0.456, 0.406]`
and `std = [0.229, 0.224, 0.225]`.
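The arithmetic behind these requirements can be sketched in plain Python. The example below scales one 8-bit RGB pixel to `[0, 1]`, standardizes each channel with the mean and standard deviation quoted above, and checks a batch shape against the `(N, 3, H, W)` and minimum-size constraints. The helper names `normalize_pixel` and `is_valid_batch_shape` are hypothetical and are not part of the SAG-ViT code.

```python
from typing import Sequence, Tuple

MEAN = (0.485, 0.456, 0.406)  # per-channel mean quoted in the text
STD = (0.229, 0.224, 0.225)   # per-channel std quoted in the text
MIN_SIDE = 49                 # minimum height and width quoted in the text


def normalize_pixel(rgb: Sequence[int]) -> Tuple[float, ...]:
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


def is_valid_batch_shape(shape: Sequence[int]) -> bool:
    """Check for an (N, 3, H, W) batch with H and W of at least MIN_SIDE."""
    if len(shape) != 4:
        return False
    n, c, h, w = shape
    return n >= 1 and c == 3 and h >= MIN_SIDE and w >= MIN_SIDE


print(normalize_pixel((255, 255, 255)))        # white pixel after normalization
print(is_valid_batch_shape((8, 3, 224, 224)))  # True
print(is_valid_batch_shape((8, 3, 32, 32)))    # False: 32 < 49
```

CIFAR-10 images are natively 32x32, which is below the minimum size, so they need to be resized before being passed to the model.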
To train or run inference on our model, refer to the following steps:
Clone our repository and load the model pretrained on the CIFAR-10 dataset.
```bash
git clone https://huggingface.co/shravvvv/SAG-ViT
cd SAG-ViT
```
Install required dependencies.
```bash
pip install -r requirements.txt
```
Use `from_pretrained` to load the model from the Hugging Face Hub and run inference on a sample input image.
```python
from transformers import AutoModel, AutoConfig
from PIL import Image
from torchvision import transforms
import torch

# Step 1: Load the model and configuration directly from Hugging Face Hub
repo_name = "shravvvv/SAG-ViT"
config = AutoConfig.from_pretrained(repo_name)  # Load config from hub
model = AutoModel.from_pretrained(repo_name, config=config)  # Load model from hub

# Step 2: Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match the expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Example normalization
])

# Step 3: Load and preprocess the input image
input_image_path = "path/to/your/image.jpg"
img = Image.open(input_image_path).convert("RGB")
img = transform(img).unsqueeze(0)  # Add batch dimension

# Step 4: Ensure the model is in evaluation mode
model.eval()

# Step 5: Run inference
with torch.no_grad():
    outputs = model(img)
    logits = outputs.logits  # Accessing logits from ModelOutput

# Step 6: Post-process the predictions
predicted_class_index = torch.argmax(logits, dim=1)  # Get the predicted class index

# CIFAR-10 label mapping
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Get the predicted class name from the class index
predicted_class_name = class_names[predicted_class_index.item()]
print(f"Predicted class: {predicted_class_name}")
```
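If class probabilities or a top-k ranking are more useful than a single label, the logits can be passed through a softmax. The sketch below is dependency-free Python for illustration; the logits are made-up values rather than model output, and the helper names `softmax` and `top_k` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

CLASS_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits: Sequence[float], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k most probable (class name, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(zip(CLASS_NAMES, probs), key=lambda p: p[1], reverse=True)
    return ranked[:k]


# Made-up logits for illustration only; real values come from `outputs.logits`.
example_logits = [0.1, 0.3, 2.5, 4.0, 0.2, 3.1, 0.0, 0.4, 0.1, 0.2]
for name, prob in top_k(example_logits):
    print(f"{name}: {prob:.3f}")
```

On tensors, `torch.softmax(logits, dim=1)` followed by `torch.topk` performs the same computation.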
### Running Tests
If you clone our [repository](https://github.com/shravan-18/SAG-ViT), the `tests` folder contains unit tests for each of the model's modules. Make sure you have a proper Python environment with the required dependencies installed. Then run:
```bash
python -m unittest discover -s tests
```
or, if you are using `pytest`, you can run:
```bash
pytest tests
```
### Results
We evaluated SAG-ViT on diverse datasets:
- **CIFAR-10** (natural images)
- **GTSRB** (traffic sign recognition)
- **NCT-CRC-HE-100K** (histopathological images)
- **NWPU-RESISC45** (remote sensing imagery)
- **PlantVillage** (agricultural imagery)
SAG-ViT achieves the highest F1 score among the compared backbones on all five datasets, as shown in the table below:
<center>
| Backbone | CIFAR-10 | GTSRB | NCT-CRC-HE-100K | NWPU-RESISC45 | PlantVillage |
|--------------------|----------|--------|-----------------|---------------|--------------|
| DenseNet201 | 0.5427 | 0.9862 | 0.9214 | 0.4493 | 0.8725 |
| Vgg16 | 0.5345 | 0.8180 | 0.8234 | 0.4114 | 0.7064 |
| Vgg19 | 0.5307 | 0.7551 | 0.8178 | 0.3844 | 0.6811 |
| DenseNet121 | 0.5290 | 0.9813 | 0.9247 | 0.4381 | 0.8321 |
| AlexNet | 0.6126 | 0.9059 | 0.8743 | 0.4397 | 0.7684 |
| Inception | 0.7734 | 0.8934 | 0.8707 | 0.8707 | 0.8216 |
| ResNet | 0.9172 | 0.9134 | 0.9478 | 0.9103 | 0.8905 |
| MobileNet | 0.9169 | 0.3006 | 0.4965 | 0.1667 | 0.2213 |
| ViT - S | 0.8465 | 0.8542 | 0.8234 | 0.6116 | 0.8654 |
| ViT - L | 0.8637 | 0.8613 | 0.8345 | 0.8358 | 0.8842 |
| MNASNet1_0 | 0.1032 | 0.0024 | 0.0212 | 0.0011 | 0.0049 |
| ShuffleNet_V2_x1_0 | 0.3523 | 0.4244 | 0.4598 | 0.1808 | 0.3190 |
| SqueezeNet1_0 | 0.4328 | 0.8392 | 0.7843 | 0.3913 | 0.6638 |
| GoogLeNet | 0.4954 | 0.9455 | 0.8631 | 0.3720 | 0.7726 |
| **Proposed (SAG-ViT)** | **0.9574** | **0.9958** | **0.9861** | **0.9549** | **0.9772** |
</center>
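
For readers who want to compute a comparable metric on their own predictions, here is a minimal sketch of an F1 computation. This card does not state which averaging scheme the reported scores use, so the macro average (unweighted mean of per-class F1) below is an assumption, and `macro_f1` is a hypothetical helper rather than the evaluation code behind the table.

```python
from typing import Sequence


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> float:
    """Unweighted mean of per-class F1 scores (macro averaging is an assumption)."""
    scores = []
    for cls in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
        denom = 2 * tp + fp + fn
        # Convention: a class with no true or predicted samples scores 0.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / num_classes


# Toy labels for illustration only.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(round(macro_f1(y_true, y_pred, num_classes=3), 4))  # 0.6556
```

In the toy example the per-class F1 scores are 0.5, 0.8 and 2/3, which average to roughly 0.6556.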
## Citation
If you find our [paper](https://arxiv.org/abs/2411.09420) and [code](https://github.com/shravan-18/SAG-ViT) helpful for your research, please consider citing our work and giving the repository a star:
```bibtex
@misc{SAGViT,
title={SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers},
author={Shravan Venkatraman and Jaskaran Singh Walia and Joe Dhanith P R},
year={2024},
eprint={2411.09420},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2411.09420},
}
```