Seeay commited on
Commit
0a1bcf1
·
verified ·
1 Parent(s): 7ae165c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +28 -0
README.md CHANGED
@@ -11,6 +11,34 @@ A simple CNN for handwritten digit classification, trained on the MNIST dataset.
11
  - Accuracy: 99.4% on test set
12
  - Pytorch
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Showcase
16
  ![image](https://cdn-uploads.huggingface.co/production/uploads/6945d6be622680b0eee91373/vNpvjnEHBU8KGbA3Wr7RB.png)
 
11
  - Accuracy: 99.4% on test set
12
  - Pytorch
13
 
14
+ # Usage
15
+ ```python
16
+ import torch
17
+ from torch import nn
18
+
19
+ #Define the architecture
20
+ class CNN(nn.Module):
21
+ def __init__(self):
22
+ super().__init__()
23
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
24
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
25
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
26
+ self.fc1 = nn.Linear(64 * 7 * 7, 128)
27
+ self.fc2 = nn.Linear(128, 10)
28
+ self.relu = nn.ReLU()
29
+
30
+ def forward(self, x):
31
+ x = self.pool(self.relu(self.conv1(x)))
32
+ x = self.pool(self.relu(self.conv2(x)))
33
+ x = x.view(-1, 64 * 7 * 7)
34
+ x = self.relu(self.fc1(x))
35
+ return self.fc2(x)
36
+
37
+ #Load model
38
+ model = CNN()
39
+ model.load_state_dict(torch.load("mnist_cnn.pth"))
40
+ model.eval()
41
+ ```
42
 
43
  # Showcase
44
  ![image](https://cdn-uploads.huggingface.co/production/uploads/6945d6be622680b0eee91373/vNpvjnEHBU8KGbA3Wr7RB.png)