---
license: mit
tags:
  - vqvae
  - image-generation
  - unsupervised-learning
  - pytorch
  - imagenet
  - generative-model
datasets:
  - imagenet-200
library_name: pytorch
model-index:
  - name: VQ-VAE-ImageNet200
    results:
      - task:
          type: image-generation
          name: Image Generation
        dataset:
          name: Tiny ImageNet (ImageNet-200)
          type: image-classification
        metrics:
          - name: FID
            type: frechet-inception-distance
            value: 102.87
---

# VQ-VAE for Tiny ImageNet (ImageNet-200)

This repository contains a **Vector Quantized Variational Autoencoder (VQ-VAE)** trained on the Tiny ImageNet-200 dataset using PyTorch. It is part of an image augmentation and representation learning pipeline for generative modeling and unsupervised learning tasks.

---

## 🧠 Model Details

- **Model Type**: Vector Quantized Variational Autoencoder (VQ-VAE)
- **Dataset**: Tiny ImageNet (ImageNet-200)
- **Epochs**: 35  
- **Latent Space**: Discrete codebook (vector quantization)  
- **Input Size**: 64×64 RGB  
- **Loss Function**: Mean Squared Error (MSE) + VQ commitment loss  
- **Final Training Loss**: ~0.0292  
- **FID Score**: ~102.87 (1000 VQ-VAE reconstructions vs. 1000 real Tiny ImageNet images)  
- **Architecture**: 3-layer CNN Encoder & Decoder with quantization bottleneck
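The quantization bottleneck and the "MSE + VQ commitment loss" objective listed above can be sketched as follows. This is a minimal NumPy illustration of the standard VQ-VAE quantization step, not the code in `models.vqvae.model`. The helper names `vector_quantize` and `vqvae_loss`, the codebook shape, and the commitment weight `beta=0.25` are assumptions, because the card does not state them. NumPy has no autograd, so the codebook and commitment terms have the same value in a forward pass. In PyTorch the stop-gradients are applied with `.detach()`, and the straight-through estimator is commonly written `z_q = z_e + (z_q - z_e).detach()`.

```python
import numpy as np


def vector_quantize(z_e, codebook, beta=0.25):
    """Map each encoder vector to its nearest codebook entry (hypothetical helper).

    z_e:      (N, D) encoder outputs, one row per spatial position.
    codebook: (K, D) embedding vectors (K and D are assumptions).
    beta:     commitment weight; 0.25 is an assumed value, not from the card.
    Returns (z_q, indices, vq_loss).
    """
    # Squared Euclidean distance from every z_e row to every codebook row: (N, K)
    dists = (
        np.sum(z_e ** 2, axis=1, keepdims=True)
        - 2.0 * z_e @ codebook.T
        + np.sum(codebook ** 2, axis=1)
    )
    indices = np.argmin(dists, axis=1)  # nearest code per position
    z_q = codebook[indices]             # quantized vectors, shape (N, D)

    # Codebook loss ||sg[z_e] - e||^2 and commitment loss ||z_e - sg[e]||^2.
    # Without autograd, both evaluate to the same number here.
    codebook_loss = np.mean((z_q - z_e) ** 2)
    commitment_loss = np.mean((z_e - z_q) ** 2)
    vq_loss = codebook_loss + beta * commitment_loss
    return z_q, indices, vq_loss


def vqvae_loss(x, x_hat, vq_loss):
    """Total objective: reconstruction MSE plus the VQ terms (hypothetical helper)."""
    return np.mean((x - x_hat) ** 2) + vq_loss


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    codebook = rng.normal(size=(512, 64))  # assumed K=512 codes of dimension 64
    z_e = rng.normal(size=(16 * 16, 64))   # assumed 16x16 latent grid
    z_q, idx, vq = vector_quantize(z_e, codebook)
    print(z_q.shape, idx[:5], vq)
```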

---

## 📦 Files

- `generator.pt` — Trained VQ-VAE model weights
- `loss_curve.png` — Plot of training loss across 35 epochs
- `fid_score.json` — FID score computed between the 1000 real samples and the 1000 VQ-VAE reconstructions
- `fid_real/` — 1000 real Tiny ImageNet samples used for FID
- `fid_fake/` — 1000 VQ-VAE reconstructions used for FID
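For reference, FID compares Gaussian fits to Inception-v3 feature statistics of the two image sets. The card does not say which FID implementation produced `fid_score.json`, so the standard definition below is given only as background. Here $\mu_r, \Sigma_r$ are the feature mean and covariance of the real images in `fid_real/`, and $\mu_g, \Sigma_g$ those of the reconstructions in `fid_fake/`:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)
$$

Lower is better. Because the "fake" set consists of reconstructions rather than samples drawn from a learned prior, this score measures reconstruction fidelity. FID estimated from only 1000 images per set also tends to be biased upward, so it is most meaningful when compared with scores computed under the same sample size.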

---

## 🔧 Usage

```python
import torch
from models.vqvae.model import VQVAE

# Build the architecture, then load the trained weights onto the CPU
model = VQVAE()
model.load_state_dict(torch.load("generator.pt", map_location="cpu"))
model.eval()  # switch to inference mode