# MoE-CNN

## Model Description
- Model type: Image Classification
- License: MIT
## How to Get Started with the Model
Use the code below to get started with the model.
```python
import torch

# MixtureOfExperts is defined in this repository's model code
model = MixtureOfExperts(num_experts=10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint_path = "FP_ML_MOE_SIMPLE_99_75.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device).eval()
print(f"Validation Accuracy: {checkpoint['val_accuracy']:.2f}")

input_data = torch.randn(1, 1, 28, 28)  # dummy MNIST-sized input (1x28x28 grayscale)
results = model.predict(input_data.to(device))
print("Results:", results)
```
## Training Details

### Training Data
https://huggingface.co/datasets/ylecun/mnist
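The dataset can be pulled directly from the Hub with the `datasets` library. A minimal sketch, assuming the standard MNIST split layout:

```python
from datasets import load_dataset

# Load MNIST from the Hugging Face Hub ("train"/"test" splits)
mnist = load_dataset("ylecun/mnist")
train_split = mnist["train"]  # 60,000 examples: {"image": PIL.Image, "label": int}
test_split = mnist["test"]    # 10,000 examples
```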
### Training Procedure

#### Data Augmentation
- RandomRotation(10)
- RandomAffine(0, shear=10)
- RandomAffine(0, translate=(0.1, 0.1))
- RandomResizedCrop(28, scale=(0.8, 1.0))
- RandomPerspective(distortion_scale=0.2, p=0.5)
- Resize((28, 28))
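As a rough sketch, the augmentations listed above map onto a `torchvision.transforms` pipeline along these lines; the exact ordering and composition used during training are not documented here, and the final `ToTensor()` is an assumption:

```python
from torchvision import transforms

# Hypothetical composition of the augmentations listed above
train_transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomAffine(0, shear=10),
    transforms.RandomAffine(0, translate=(0.1, 0.1)),
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])
```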
#### Training Hyperparameters
- Adam with a learning rate of 0.001 for fast initial convergence
- SGD with a learning rate of 0.01, decayed to 0.001
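A sketch of how this two-stage optimizer setup could be wired up in PyTorch; the switch-over point and the `StepLR` schedule (including `step_size`) are assumptions, not documented values:

```python
import torch

# Phase 1 (assumed ordering): Adam for fast initial convergence
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Phase 2 (assumed ordering): SGD starting at lr=0.01; a single StepLR decay
# step with gamma=0.1 brings the rate down to 0.001
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
```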
#### Size

2,247,151 total parameters, of which 674,145 are effective.
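The total count can be reproduced with a standard PyTorch parameter sum; how the 674,145 effective parameters are counted depends on the gating setup, so only the total is shown here:

```python
# Total trainable parameters (should match the 2,247,151 figure above)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
```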
## Evaluation

### Testing Data, Factors & Metrics

#### Testing Data
https://huggingface.co/datasets/ylecun/mnist
#### Metrics
- Accuracy: 99.75%
- Error rate: 0.25%
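A minimal sketch of how the reported accuracy could be reproduced on the MNIST test split. The preprocessing (plain `ToTensor`) is an assumption, and the sketch assumes the model's forward pass returns class logits:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

test_set = datasets.MNIST(root="data", train=False, download=True,
                          transform=transforms.ToTensor())
loader = DataLoader(test_set, batch_size=256)

correct = 0
model.eval()
with torch.no_grad():
    for images, labels in loader:
        logits = model(images.to(device))
        correct += (logits.argmax(dim=1).cpu() == labels).sum().item()

print(f"Test accuracy: {100.0 * correct / len(test_set):.2f}%")
```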
## Technical Specifications
### Model Architecture
Mixture-of-Experts (MoE) architecture with simple CNNs as the experts.
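The released layer configuration is not documented in this card, but the general shape is a gating network that weights the outputs of several small CNN experts. Below is a hypothetical sketch only: the layer sizes, the soft-gating style, and the `predict` helper are assumptions, not the released implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNNExpert(nn.Module):
    """One expert: a small CNN for 1x28x28 inputs (sizes are illustrative)."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))

class MixtureOfExperts(nn.Module):
    """Soft gating over CNN experts (illustrative, not the released code)."""
    def __init__(self, num_experts=10, num_classes=10):
        super().__init__()
        self.experts = nn.ModuleList(SimpleCNNExpert(num_classes) for _ in range(num_experts))
        self.gate = nn.Linear(28 * 28, num_experts)

    def forward(self, x):
        weights = F.softmax(self.gate(x.flatten(1)), dim=1)           # (B, E)
        expert_out = torch.stack([e(x) for e in self.experts], dim=1)  # (B, E, C)
        return (weights.unsqueeze(-1) * expert_out).sum(dim=1)         # (B, C)

    @torch.no_grad()
    def predict(self, x):
        return self.forward(x).argmax(dim=1)
```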