A0lgk commited on
Commit
667ecf6
·
verified ·
1 Parent(s): 0f7544c

Upload 2 files

Browse files
Files changed (2) hide show
  1. README.md +342 -0
  2. mnist_exploration.ipynb +383 -0
README.md ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ library_name: sklearn
6
+ tags:
7
+ - mnist
8
+ - image-classification
9
+ - digits
10
+ - handwritten
11
+ - computer-vision
12
+ - logistic-regression
13
+ - machine-learning
14
+ datasets:
15
+ - ylecun/mnist
16
+ metrics:
17
+ - accuracy
18
+ - f1
19
+ - precision
20
+ - recall
21
+ pipeline_tag: image-classification
22
+ ---
23
+
24
+ # MNIST Handwritten Digit Classifier
25
+
26
+ A classical machine learning approach to handwritten digit recognition using Logistic Regression on the MNIST dataset.
27
+
28
+ ## Model Description
29
+
30
+ This model classifies 28x28 grayscale images of handwritten digits (0-9) using a simple yet effective Logistic Regression classifier. The project serves as an introduction to image classification and the MNIST dataset.
31
+
32
+ ### Intended Uses
33
+
34
+ - **Educational**: Learning image classification fundamentals
35
+ - **Benchmarking**: Baseline for comparing more complex models
36
+ - **Research**: Exploring classical ML on image data
37
+ - **Prototyping**: Quick digit recognition experiments
38
+
39
+ ## Training Data
40
+
41
+ **Dataset**: [ylecun/mnist](https://huggingface.co/datasets/ylecun/mnist)
42
+
43
+ | Split | Images |
44
+ |-------|--------|
45
+ | Train | 60,000 |
46
+ | Test | 10,000 |
47
+ | **Total** | **70,000** |
48
+
49
+ ### Data Characteristics
50
+
51
+ | Property | Value |
52
+ |----------|-------|
53
+ | Image Size | 28 x 28 pixels |
54
+ | Channels | 1 (Grayscale) |
55
+ | Classes | 10 (digits 0-9) |
56
+ | Pixel Range | 0-255 (raw), 0-1 (normalized) |
57
+ | Format | PNG/NumPy arrays |
58
+
59
+ ### Class Distribution
60
+
61
+ The dataset is relatively balanced across all 10 digit classes.
62
+
63
+ ## Model Architecture
64
+
65
+ ### Preprocessing Pipeline
66
+
67
+ ```
68
+ Raw Image (28x28, uint8)
69
+
70
+ Normalize to [0, 1] (divide by 255)
71
+
72
+ Flatten to vector (784 dimensions)
73
+
74
+ Logistic Regression Classifier
75
+
76
+ Softmax Probabilities (10 classes)
77
+ ```
78
+
79
+ ### Classifier Configuration
80
+
81
+ ```python
82
+ LogisticRegression(
83
+ max_iter=100,
84
+ solver='lbfgs',
85
+ multi_class='multinomial',
86
+ n_jobs=-1
87
+ )
88
+ ```
89
+
90
+ | Parameter | Value | Description |
91
+ |-----------|-------|-------------|
92
+ | max_iter | 100 | Maximum iterations for convergence |
93
+ | solver | lbfgs | L-BFGS optimization algorithm |
94
+ | multi_class | multinomial | True multiclass (not OvR) |
95
+ | n_jobs | -1 | Use all CPU cores |
96
+
97
+ ## Performance
98
+
99
+ ### Test Set Results
100
+
101
+ | Metric | Score |
102
+ |--------|-------|
103
+ | Accuracy | ~92% |
104
+ | Macro F1 | ~92% |
105
+ | Macro Precision | ~92% |
106
+ | Macro Recall | ~92% |
107
+
108
+ ### Per-Class Performance
109
+
110
+ | Digit | Precision | Recall | F1-Score |
111
+ |-------|-----------|--------|----------|
112
+ | 0 | ~0.95 | ~0.97 | ~0.96 |
113
+ | 1 | ~0.95 | ~0.97 | ~0.96 |
114
+ | 2 | ~0.91 | ~0.89 | ~0.90 |
115
+ | 3 | ~0.89 | ~0.90 | ~0.90 |
116
+ | 4 | ~0.92 | ~0.92 | ~0.92 |
117
+ | 5 | ~0.88 | ~0.87 | ~0.87 |
118
+ | 6 | ~0.94 | ~0.95 | ~0.94 |
119
+ | 7 | ~0.93 | ~0.91 | ~0.92 |
120
+ | 8 | ~0.88 | ~0.87 | ~0.88 |
121
+ | 9 | ~0.89 | ~0.90 | ~0.90 |
122
+
123
+ *Note: Performance varies slightly between runs*
124
+
125
+ ### Common Confusion Pairs
126
+
127
+ - 4 ↔ 9 (similar upper loops)
128
+ - 3 ↔ 8 (curved shapes)
129
+ - 5 ↔ 3 (similar strokes)
130
+ - 7 ↔ 1 (vertical strokes)
131
+
132
+ ## Usage
133
+
134
+ ### Installation
135
+
136
+ ```bash
137
+ pip install scikit-learn pandas numpy matplotlib seaborn pillow
138
+ ```
139
+
140
+ ### Load and Preprocess Data
141
+
142
+ ```python
143
+ import pandas as pd
144
+ import numpy as np
145
+ from PIL import Image
146
+
147
+ # Load from Hugging Face
148
+ df_train = pd.read_parquet("hf://datasets/ylecun/mnist/mnist/train-00000-of-00001.parquet")
149
+ df_test = pd.read_parquet("hf://datasets/ylecun/mnist/mnist/test-00000-of-00001.parquet")
150
+
151
+ def extract_image(row):
152
+ """Extract image as numpy array"""
153
+ img_data = row['image']
154
+ if isinstance(img_data, dict) and 'bytes' in img_data:
155
+ from io import BytesIO
156
+ img = Image.open(BytesIO(img_data['bytes']))
157
+ return np.array(img)
158
+ elif isinstance(img_data, Image.Image):
159
+ return np.array(img_data)
160
+ return np.array(img_data)
161
+
162
+ # Prepare data
163
+ X_train = np.array([extract_image(row) for _, row in df_train.iterrows()])
164
+ y_train = df_train['label'].values
165
+
166
+ # Normalize and flatten
167
+ X_train_flat = X_train.astype('float32').reshape(-1, 784) / 255.0
168
+ ```
169
+
170
+ ### Train Model
171
+
172
+ ```python
173
+ from sklearn.linear_model import LogisticRegression
174
+
175
+ model = LogisticRegression(
176
+ max_iter=100,
177
+ solver='lbfgs',
178
+ multi_class='multinomial',
179
+ n_jobs=-1
180
+ )
181
+ model.fit(X_train_flat, y_train)
182
+ ```
183
+
184
+ ### Inference
185
+
186
+ ```python
187
+ import joblib
188
+
189
+ # Load model
190
+ model = joblib.load('mnist_model.pkl')
191
+
192
+ # Predict single image
193
+ def predict_digit(image):
194
+ """
195
+ image: 28x28 numpy array or PIL Image
196
+ returns: predicted digit (0-9)
197
+ """
198
+ if isinstance(image, Image.Image):
199
+ image = np.array(image)
200
+
201
+ # Preprocess
202
+ image_flat = image.astype('float32').reshape(1, 784) / 255.0
203
+
204
+ # Predict
205
+ prediction = model.predict(image_flat)[0]
206
+ probabilities = model.predict_proba(image_flat)[0]
207
+
208
+ return prediction, probabilities
209
+
210
+ # Example
211
+ digit, probs = predict_digit(test_image)
212
+ print(f"Predicted: {digit} (confidence: {probs[digit]:.2%})")
213
+ ```
214
+
215
+ ### Visualization
216
+
217
+ ```python
218
+ import matplotlib.pyplot as plt
219
+ from sklearn.metrics import confusion_matrix
220
+ import seaborn as sns
221
+
222
+ # Confusion Matrix
223
+ y_pred = model.predict(X_test_flat)
224
+ cm = confusion_matrix(y_test, y_pred)
225
+
226
+ plt.figure(figsize=(10, 8))
227
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
228
+ xticklabels=range(10), yticklabels=range(10))
229
+ plt.xlabel('Predicted')
230
+ plt.ylabel('True')
231
+ plt.title('Confusion Matrix - MNIST')
232
+ plt.show()
233
+ ```
234
+
235
+ ### Average Digit Visualization
236
+
237
+ ```python
238
+ # Compute mean image per digit
239
+ fig, axes = plt.subplots(2, 5, figsize=(12, 5))
240
+ for digit in range(10):
241
+ ax = axes[digit // 5, digit % 5]
242
+ mask = y_train == digit
243
+ mean_img = X_train[mask].mean(axis=0)
244
+ ax.imshow(mean_img, cmap='hot')
245
+ ax.set_title(f'Digit: {digit}')
246
+ ax.axis('off')
247
+ plt.tight_layout()
248
+ plt.show()
249
+ ```
250
+
251
+ ## Limitations
252
+
253
+ - **Simple Model**: Logistic Regression doesn't capture spatial relationships
254
+ - **No Data Augmentation**: Sensitive to rotation, scaling, translation
255
+ - **Grayscale Only**: Won't work with color images
256
+ - **Fixed Size**: Requires exactly 28x28 input
257
+ - **Clean Data**: Struggles with noisy or poorly centered digits
258
+
259
+ ## Comparison with Other Approaches
260
+
261
+ | Model | MNIST Accuracy |
262
+ |-------|----------------|
263
+ | **Logistic Regression** | **~92%** |
264
+ | Random Forest | ~97% |
265
+ | SVM (RBF kernel) | ~98% |
266
+ | MLP (2 hidden layers) | ~98% |
267
+ | CNN (LeNet-5) | ~99% |
268
+ | Modern CNNs | ~99.7% |
269
+
270
+ ## Technical Specifications
271
+
272
+ ### Dependencies
273
+
274
+ ```
275
+ scikit-learn>=1.0.0
276
+ pandas>=1.3.0
277
+ numpy>=1.20.0
278
+ matplotlib>=3.4.0
279
+ seaborn>=0.11.0
280
+ pillow>=8.0.0
281
+ ```
282
+
283
+ ### Hardware Requirements
284
+
285
+ | Task | Hardware | Time |
286
+ |------|----------|------|
287
+ | Training | CPU | ~2-5 min |
288
+ | Inference | CPU | < 1ms per image |
289
+ | Memory | RAM | ~500MB |
290
+
291
+ ## Files
292
+
293
+ ```
294
+ MNIST/
295
+ ├── README_HF.md # This model card
296
+ ├── mnist_exploration.ipynb # Full exploration notebook
297
+ ├── mnist_model.pkl # Trained model (generated)
298
+ └── figures/ # Visualizations (generated)
299
+ ```
300
+
301
+ ## Citation
302
+
303
+ ```bibtex
304
+ @article{lecun1998mnist,
305
+ title={Gradient-based learning applied to document recognition},
306
+ author={LeCun, Yann and Bottou, L{\'e}on and Bengio, Yoshua and Haffner, Patrick},
307
+ journal={Proceedings of the IEEE},
308
+ volume={86},
309
+ number={11},
310
+ pages={2278--2324},
311
+ year={1998}
312
+ }
313
+
314
+ @misc{mnist_hf,
315
+ title={MNIST Dataset},
316
+ author={LeCun, Yann and Cortes, Corinna and Burges, Christopher J.C.},
317
+ howpublished={Hugging Face Datasets},
318
+ url={https://huggingface.co/datasets/ylecun/mnist}
319
+ }
320
+ ```
321
+
322
+ ## License
323
+
324
+ MIT License
325
+
326
+ ## Acknowledgments
327
+
328
+ - Yann LeCun for creating MNIST
329
+ - Scikit-learn team for the ML library
330
+ - Hugging Face for dataset hosting
331
+
332
+ ---
333
+
334
+ ## Next Steps
335
+
336
+ For better performance, consider:
337
+
338
+ 1. **More Complex Models**: SVM, Random Forest, Neural Networks
339
+ 2. **Deep Learning**: CNNs with PyTorch/TensorFlow
340
+ 3. **Data Augmentation**: Rotation, scaling, elastic deformations
341
+ 4. **Feature Engineering**: HOG, SIFT features
342
+ 5. **Ensemble Methods**: Combine multiple classifiers
mnist_exploration.ipynb ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🔢 Exploration du Dataset MNIST\n",
8
+ "\n",
9
+ "Ce notebook explore le célèbre dataset MNIST de chiffres manuscrits."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "markdown",
14
+ "metadata": {},
15
+ "source": [
16
+ "## 1. Chargement des données"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "import pandas as pd\n",
26
+ "import numpy as np\n",
27
+ "import matplotlib.pyplot as plt\n",
28
+ "from PIL import Image\n",
29
+ "\n",
30
+ "# Chargement du dataset depuis Hugging Face\n",
31
+ "splits = {\n",
32
+ " 'train': 'mnist/train-00000-of-00001.parquet',\n",
33
+ " 'test': 'mnist/test-00000-of-00001.parquet'\n",
34
+ "}\n",
35
+ "\n",
36
+ "df_train = pd.read_parquet(\"hf://datasets/ylecun/mnist/\" + splits[\"train\"])\n",
37
+ "df_test = pd.read_parquet(\"hf://datasets/ylecun/mnist/\" + splits[\"test\"])\n",
38
+ "\n",
39
+ "print(f\"✅ Données chargées avec succès!\")\n",
40
+ "print(f\"📊 Taille du set d'entraînement: {len(df_train)} images\")\n",
41
+ "print(f\"📊 Taille du set de test: {len(df_test)} images\")"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "markdown",
46
+ "metadata": {},
47
+ "source": [
48
+ "## 2. Exploration des données"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "# Structure du DataFrame\n",
58
+ "print(\"Colonnes du dataset:\")\n",
59
+ "print(df_train.columns.tolist())\n",
60
+ "print(\"\\nAperçu des premières lignes:\")\n",
61
+ "df_train.head()"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "# Distribution des labels\n",
71
+ "print(\"Distribution des chiffres dans le set d'entraînement:\")\n",
72
+ "label_counts = df_train['label'].value_counts().sort_index()\n",
73
+ "\n",
74
+ "plt.figure(figsize=(10, 5))\n",
75
+ "plt.bar(label_counts.index, label_counts.values, color='steelblue', edgecolor='black')\n",
76
+ "plt.xlabel('Chiffre', fontsize=12)\n",
77
+ "plt.ylabel('Nombre d\\'images', fontsize=12)\n",
78
+ "plt.title('Distribution des chiffres dans MNIST (train)', fontsize=14)\n",
79
+ "plt.xticks(range(10))\n",
80
+ "for i, v in enumerate(label_counts.values):\n",
81
+ " plt.text(i, v + 100, str(v), ha='center', fontsize=9)\n",
82
+ "plt.tight_layout()\n",
83
+ "plt.show()"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "metadata": {},
89
+ "source": [
90
+ "## 3. Visualisation des images"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "def extract_image(row):\n",
100
+ " \"\"\"Extrait l'image depuis la colonne 'image' du DataFrame.\"\"\"\n",
101
+ " img_data = row['image']\n",
102
+ " if isinstance(img_data, dict) and 'bytes' in img_data:\n",
103
+ " # Format avec bytes\n",
104
+ " from io import BytesIO\n",
105
+ " img = Image.open(BytesIO(img_data['bytes']))\n",
106
+ " return np.array(img)\n",
107
+ " elif isinstance(img_data, Image.Image):\n",
108
+ " return np.array(img_data)\n",
109
+ " elif isinstance(img_data, np.ndarray):\n",
110
+ " return img_data\n",
111
+ " else:\n",
112
+ " # Essayer de convertir directement\n",
113
+ " return np.array(img_data)\n",
114
+ "\n",
115
+ "# Afficher quelques exemples\n",
116
+ "fig, axes = plt.subplots(2, 5, figsize=(12, 5))\n",
117
+ "fig.suptitle('Exemples d\\'images MNIST', fontsize=14)\n",
118
+ "\n",
119
+ "for idx, ax in enumerate(axes.flat):\n",
120
+ " img = extract_image(df_train.iloc[idx])\n",
121
+ " label = df_train.iloc[idx]['label']\n",
122
+ " ax.imshow(img, cmap='gray')\n",
123
+ " ax.set_title(f'Label: {label}', fontsize=11)\n",
124
+ " ax.axis('off')\n",
125
+ "\n",
126
+ "plt.tight_layout()\n",
127
+ "plt.show()"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": [
136
+ "# Afficher un exemple de chaque chiffre\n",
137
+ "fig, axes = plt.subplots(2, 5, figsize=(12, 5))\n",
138
+ "fig.suptitle('Un exemple de chaque chiffre (0-9)', fontsize=14)\n",
139
+ "\n",
140
+ "for digit in range(10):\n",
141
+ " ax = axes[digit // 5, digit % 5]\n",
142
+ " sample = df_train[df_train['label'] == digit].iloc[0]\n",
143
+ " img = extract_image(sample)\n",
144
+ " ax.imshow(img, cmap='gray')\n",
145
+ " ax.set_title(f'Chiffre: {digit}', fontsize=11)\n",
146
+ " ax.axis('off')\n",
147
+ "\n",
148
+ "plt.tight_layout()\n",
149
+ "plt.show()"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "markdown",
154
+ "metadata": {},
155
+ "source": [
156
+ "## 4. Préparation des données pour le Machine Learning"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "# Convertir toutes les images en arrays numpy\n",
166
+ "print(\"Conversion des images en arrays numpy...\")\n",
167
+ "\n",
168
+ "X_train = np.array([extract_image(row) for _, row in df_train.iterrows()])\n",
169
+ "y_train = df_train['label'].values\n",
170
+ "\n",
171
+ "X_test = np.array([extract_image(row) for _, row in df_test.iterrows()])\n",
172
+ "y_test = df_test['label'].values\n",
173
+ "\n",
174
+ "print(f\"\\n✅ Conversion terminée!\")\n",
175
+ "print(f\"X_train shape: {X_train.shape}\")\n",
176
+ "print(f\"y_train shape: {y_train.shape}\")\n",
177
+ "print(f\"X_test shape: {X_test.shape}\")\n",
178
+ "print(f\"y_test shape: {y_test.shape}\")"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": null,
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": [
187
+ "# Normalisation des données (0-1)\n",
188
+ "X_train_norm = X_train.astype('float32') / 255.0\n",
189
+ "X_test_norm = X_test.astype('float32') / 255.0\n",
190
+ "\n",
191
+ "# Aplatir les images pour les modèles classiques (28x28 -> 784)\n",
192
+ "X_train_flat = X_train_norm.reshape(X_train_norm.shape[0], -1)\n",
193
+ "X_test_flat = X_test_norm.reshape(X_test_norm.shape[0], -1)\n",
194
+ "\n",
195
+ "print(f\"Données normalisées et aplaties:\")\n",
196
+ "print(f\"X_train_flat shape: {X_train_flat.shape}\")\n",
197
+ "print(f\"X_test_flat shape: {X_test_flat.shape}\")"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "markdown",
202
+ "metadata": {},
203
+ "source": [
204
+ "## 5. Modèle simple de classification"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": null,
210
+ "metadata": {},
211
+ "outputs": [],
212
+ "source": [
213
+ "from sklearn.linear_model import LogisticRegression\n",
214
+ "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n",
215
+ "import seaborn as sns\n",
216
+ "\n",
217
+ "# Entraînement d'un modèle de régression logistique\n",
218
+ "print(\"🔄 Entraînement du modèle de régression logistique...\")\n",
219
+ "print(\"(Cela peut prendre quelques minutes)\\n\")\n",
220
+ "\n",
221
+ "model = LogisticRegression(max_iter=100, solver='lbfgs', multi_class='multinomial', n_jobs=-1)\n",
222
+ "model.fit(X_train_flat, y_train)\n",
223
+ "\n",
224
+ "print(\"✅ Entraînement terminé!\")"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "# Évaluation du modèle\n",
234
+ "y_pred = model.predict(X_test_flat)\n",
235
+ "accuracy = accuracy_score(y_test, y_pred)\n",
236
+ "\n",
237
+ "print(f\"🎯 Précision sur le set de test: {accuracy:.4f} ({accuracy*100:.2f}%)\\n\")\n",
238
+ "print(\"Rapport de classification:\")\n",
239
+ "print(classification_report(y_test, y_pred))"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": null,
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": [
248
+ "# Matrice de confusion\n",
249
+ "cm = confusion_matrix(y_test, y_pred)\n",
250
+ "\n",
251
+ "plt.figure(figsize=(10, 8))\n",
252
+ "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', \n",
253
+ " xticklabels=range(10), yticklabels=range(10))\n",
254
+ "plt.xlabel('Prédiction', fontsize=12)\n",
255
+ "plt.ylabel('Vraie valeur', fontsize=12)\n",
256
+ "plt.title('Matrice de confusion - MNIST', fontsize=14)\n",
257
+ "plt.tight_layout()\n",
258
+ "plt.show()"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "markdown",
263
+ "metadata": {},
264
+ "source": [
265
+ "## 6. Visualisation des prédictions"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "metadata": {},
272
+ "outputs": [],
273
+ "source": [
274
+ "# Afficher quelques prédictions\n",
275
+ "fig, axes = plt.subplots(3, 5, figsize=(14, 8))\n",
276
+ "fig.suptitle('Exemples de prédictions', fontsize=14)\n",
277
+ "\n",
278
+ "indices = np.random.choice(len(X_test), 15, replace=False)\n",
279
+ "\n",
280
+ "for i, (ax, idx) in enumerate(zip(axes.flat, indices)):\n",
281
+ " ax.imshow(X_test[idx], cmap='gray')\n",
282
+ " pred = y_pred[idx]\n",
283
+ " true = y_test[idx]\n",
284
+ " color = 'green' if pred == true else 'red'\n",
285
+ " ax.set_title(f'Préd: {pred} | Vrai: {true}', color=color, fontsize=10)\n",
286
+ " ax.axis('off')\n",
287
+ "\n",
288
+ "plt.tight_layout()\n",
289
+ "plt.show()"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": null,
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "# Afficher les erreurs\n",
299
+ "errors = np.where(y_pred != y_test)[0]\n",
300
+ "print(f\"Nombre d'erreurs: {len(errors)} sur {len(y_test)} ({len(errors)/len(y_test)*100:.2f}%)\\n\")\n",
301
+ "\n",
302
+ "# Afficher quelques erreurs\n",
303
+ "fig, axes = plt.subplots(2, 5, figsize=(14, 6))\n",
304
+ "fig.suptitle('Exemples d\\'erreurs de classification', fontsize=14)\n",
305
+ "\n",
306
+ "for i, ax in enumerate(axes.flat):\n",
307
+ " if i < len(errors):\n",
308
+ " idx = errors[i]\n",
309
+ " ax.imshow(X_test[idx], cmap='gray')\n",
310
+ " ax.set_title(f'Préd: {y_pred[idx]} | Vrai: {y_test[idx]}', color='red', fontsize=10)\n",
311
+ " ax.axis('off')\n",
312
+ "\n",
313
+ "plt.tight_layout()\n",
314
+ "plt.show()"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "markdown",
319
+ "metadata": {},
320
+ "source": [
321
+ "## 7. Analyse des pixels moyens par chiffre"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": null,
327
+ "metadata": {},
328
+ "outputs": [],
329
+ "source": [
330
+ "# Calculer l'image moyenne pour chaque chiffre\n",
331
+ "fig, axes = plt.subplots(2, 5, figsize=(12, 5))\n",
332
+ "fig.suptitle('Image moyenne pour chaque chiffre', fontsize=14)\n",
333
+ "\n",
334
+ "for digit in range(10):\n",
335
+ " ax = axes[digit // 5, digit % 5]\n",
336
+ " mask = y_train == digit\n",
337
+ " mean_img = X_train[mask].mean(axis=0)\n",
338
+ " ax.imshow(mean_img, cmap='hot')\n",
339
+ " ax.set_title(f'Chiffre: {digit}', fontsize=11)\n",
340
+ " ax.axis('off')\n",
341
+ "\n",
342
+ "plt.tight_layout()\n",
343
+ "plt.show()"
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "markdown",
348
+ "metadata": {},
349
+ "source": [
350
+ "---\n",
351
+ "## 📝 Résumé\n",
352
+ "\n",
353
+ "Dans ce notebook, nous avons:\n",
354
+ "1. Chargé le dataset MNIST depuis Hugging Face\n",
355
+ "2. Exploré la structure et la distribution des données\n",
356
+ "3. Visualisé des exemples d'images\n",
357
+ "4. Préparé les données pour le machine learning\n",
358
+ "5. Entraîné un modèle de régression logistique simple\n",
359
+ "6. Évalué les performances du modèle\n",
360
+ "7. Analysé les images moyennes par chiffre\n",
361
+ "\n",
362
+ "**Prochaines étapes possibles:**\n",
363
+ "- Essayer d'autres modèles (SVM, Random Forest, KNN)\n",
364
+ "- Implémenter un réseau de neurones avec TensorFlow/PyTorch\n",
365
+ "- Appliquer des techniques d'augmentation de données\n",
366
+ "- Explorer la réduction de dimensionnalité (PCA, t-SNE)"
367
+ ]
368
+ }
369
+ ],
370
+ "metadata": {
371
+ "kernelspec": {
372
+ "display_name": "Python 3",
373
+ "language": "python",
374
+ "name": "python3"
375
+ },
376
+ "language_info": {
377
+ "name": "python",
378
+ "version": "3.10.0"
379
+ }
380
+ },
381
+ "nbformat": 4,
382
+ "nbformat_minor": 4
383
+ }