---
tags:
- image_classification
- computer_vision
license: mit
datasets:
- p2pfl/CIFAR10
language:
- en
pipeline_tag: image-classification
metrics:
- f1
---

# SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers

### Model Description

Implementation of the ***SAG-ViT*** model as proposed in the [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420) paper. 

It is a novel transformer framework designed to enhance Vision Transformers (ViT) with scale-awareness and refined patch-level feature embeddings. It extracts multiscale features with EfficientNetV2, organizes patches into a graph based on their spatial relationships, and refines them with a Graph Attention Network (GAT). A Transformer encoder then integrates these embeddings globally, capturing long-range dependencies for comprehensive image understanding.

### Model Architecture

![SAGViTArchitecture](images/SAG-ViT.png) 

_Image source: [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420)_
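
For intuition, the sketch below mirrors this pipeline in plain PyTorch. It is only a rough conceptual sketch, not the released implementation: the `timm` backbone name, the token projection, and the use of standard multi-head attention in place of the spatially restricted GAT are all simplifying assumptions made for illustration.

```python
import torch
import torch.nn as nn
import timm  # assumed dependency, used here only for an EfficientNetV2 backbone


class SAGViTSketch(nn.Module):
    """Simplified stand-in: CNN features -> patch tokens -> graph-style attention
    -> Transformer encoder -> classifier. Not the authors' implementation."""

    def __init__(self, num_classes=10, embed_dim=256, patch=2):
        super().__init__()
        # Multiscale feature extractor (feature maps only, no classifier head)
        self.backbone = timm.create_model(
            "tf_efficientnetv2_s", pretrained=False, features_only=True)
        channels = self.backbone.feature_info.channels()[-1]
        self.patch = patch
        self.proj = nn.Linear(channels * patch * patch, embed_dim)
        # Attention over patch tokens as a stand-in for the Graph Attention Network
        self.gat = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
        # Transformer encoder integrates the refined embeddings globally
        layer = nn.TransformerEncoderLayer(embed_dim, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        fmap = self.backbone(x)[-1]                      # (N, C, h, w) high-level features
        p = self.patch
        patches = fmap.unfold(2, p, p).unfold(3, p, p)   # (N, C, h//p, w//p, p, p)
        N, C, H, W = patches.shape[:4]
        tokens = patches.permute(0, 2, 3, 1, 4, 5).reshape(N, H * W, C * p * p)
        tokens = self.proj(tokens)                       # patch embeddings
        # The real model restricts attention to spatially adjacent patches via a
        # graph; this sketch simply attends over all patches.
        refined, _ = self.gat(tokens, tokens, tokens)
        return self.head(self.encoder(refined).mean(dim=1))
```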

### Usage

SAG-ViT expects input images normalized in the standard ImageNet way,
i.e. mini-batches of 3-channel RGB images of shape `(N, 3, H, W)`, where `N` is the number of images and `H` and `W` are at least `49` pixels.
The images have to be loaded into the range `[0, 1]` and then normalized using `mean = [0.485, 0.456, 0.406]`
and `std = [0.229, 0.224, 0.225]`.
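
As a minimal sketch (assuming `torchvision` is installed), this preprocessing can be written as the transform below; the full inference example further down uses the same steps.

```python
from torchvision import transforms

# Preprocessing matching the description above: RGB, [0, 1] range, ImageNet statistics.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # any size of at least 49 x 49 works; 224 is used below
    transforms.ToTensor(),          # converts a PIL image to a [0, 1] float tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```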

To train or run inference with our model, follow these steps:

Clone our repository and load the model pretrained on the CIFAR-10 dataset.
```bash
git clone https://huggingface.co/shravvvv/SAG-ViT
cd SAG-ViT
```

Install required dependencies.
```bash
pip install -r requirements.txt
```

Use `from_pretrained` to load the model from Hugging Face Hub and run inference on a sample input image.
```python
from transformers import AutoModel, AutoConfig
from PIL import Image
from torchvision import transforms
import torch

# Step 1: Load the model and configuration directly from Hugging Face Hub
repo_name = "shravvvv/SAG-ViT" 
config = AutoConfig.from_pretrained(repo_name)  # Load config from hub
model = AutoModel.from_pretrained(repo_name, config=config)  # Load model from hub

# Step 2: Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match the expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Example normalization
])

# Step 3: Load and preprocess the input image
input_image_path = "path/to/your/image.jpg"
img = Image.open(input_image_path).convert("RGB")
img = transform(img).unsqueeze(0)  # Add batch dimension

# Step 4: Ensure the model is in evaluation mode
model.eval()

# Step 5: Run inference
with torch.no_grad():
    outputs = model(img)
    logits = outputs.logits  # Accessing logits from ModelOutput

# Step 6: Post-process the predictions
predicted_class_index = torch.argmax(logits, dim=1)  # Get the predicted class index

# CIFAR-10 label mapping
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Get the predicted class name from the class index
predicted_class_name = class_names[predicted_class_index.item()]
print(f"Predicted class: {predicted_class_name}")
```
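
If you also want per-class probabilities rather than only the top prediction, a softmax over the logits gives them. This is a small optional addition, not part of the original example:

```python
import torch.nn.functional as F

# Optional: convert logits to class probabilities
probs = F.softmax(logits, dim=1).squeeze(0)
for name, p in zip(class_names, probs.tolist()):
    print(f"{name}: {p:.3f}")
```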

### Running Tests

If you clone our [repository](https://github.com/shravan-18/SAG-ViT), the *tests* folder contains unit tests for each of the model's modules. Make sure you have a Python environment with the required dependencies installed, then run:
```bash
python -m unittest discover -s tests
```

or, if you are using `pytest`, you can run:
```bash
pytest tests
```

### Results

We evaluated SAG-ViT on diverse datasets:
- **CIFAR-10** (natural images)
- **GTSRB** (traffic sign recognition)
- **NCT-CRC-HE-100K** (histopathological images)
- **NWPU-RESISC45** (remote sensing imagery)
- **PlantVillage** (agricultural imagery)

SAG-ViT achieves state-of-the-art results across all benchmarks, as shown in the table below (F1 scores):

<center>

| Backbone           | CIFAR-10 | GTSRB  | NCT-CRC-HE-100K | NWPU-RESISC45 | PlantVillage |
|--------------------|----------|--------|-----------------|---------------|--------------|
| DenseNet201        | 0.5427   | 0.9862 | 0.9214          | 0.4493        | 0.8725       |
| Vgg16              | 0.5345   | 0.8180 | 0.8234          | 0.4114        | 0.7064       |
| Vgg19              | 0.5307   | 0.7551 | 0.8178          | 0.3844        | 0.6811       |
| DenseNet121        | 0.5290   | 0.9813 | 0.9247          | 0.4381        | 0.8321       |
| AlexNet            | 0.6126   | 0.9059 | 0.8743          | 0.4397        | 0.7684       |
| Inception          | 0.7734   | 0.8934 | 0.8707          | 0.8707        | 0.8216       |
| ResNet             | 0.9172   | 0.9134 | 0.9478          | 0.9103        | 0.8905       |
| MobileNet          | 0.9169   | 0.3006 | 0.4965          | 0.1667        | 0.2213       |
| ViT - S            | 0.8465   | 0.8542 | 0.8234          | 0.6116        | 0.8654       |
| ViT - L            | 0.8637   | 0.8613 | 0.8345          | 0.8358        | 0.8842       |
| MNASNet1_0         | 0.1032   | 0.0024 | 0.0212          | 0.0011        | 0.0049       |
| ShuffleNet_V2_x1_0 | 0.3523   | 0.4244 | 0.4598          | 0.1808        | 0.3190       |
| SqueezeNet1_0      | 0.4328   | 0.8392 | 0.7843          | 0.3913        | 0.6638       |
| GoogLeNet          | 0.4954   | 0.9455 | 0.8631          | 0.3720        | 0.7726       |
| **Proposed (SAG-ViT)** | **0.9574** | **0.9958** | **0.9861** | **0.9549** | **0.9772** |

</center>

## Citation

If you find our [paper](https://arxiv.org/abs/2411.09420) and [code](https://github.com/shravan-18/SAG-ViT) helpful for your research, please consider citing our work and giving the repository a star:

```bibtex
@misc{SAGViT,
      title={SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers}, 
      author={Shravan Venkatraman and Jaskaran Singh Walia and Joe Dhanith P R},
      year={2024},
      eprint={2411.09420},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2411.09420}, 
}
```