bakhili commited on
Commit
1ca6b73
Β·
verified Β·
1 Parent(s): 93c1900

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +386 -479
src/streamlit_app.py CHANGED
@@ -3,6 +3,9 @@ import numpy as np
3
  import os
4
  import sys
5
  from PIL import Image
 
 
 
6
 
7
  # Set environment variables to fix permission issues
8
  os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
@@ -19,7 +22,6 @@ except ImportError:
19
  try:
20
  import matplotlib
21
  matplotlib.use('Agg') # Use non-interactive backend
22
- import matplotlib.pyplot as plt
23
  import matplotlib.cm as cm
24
  MPL_AVAILABLE = True
25
  except ImportError:
@@ -119,244 +121,214 @@ def load_stroke_model():
119
  except Exception as e:
120
  return None, f"❌ Model loading failed: {str(e)}"
121
 
122
- def analyze_model_architecture(model):
123
- """Comprehensive analysis of model architecture."""
124
- if model is None:
125
- return {"error": "No model loaded"}
126
-
127
- layer_analysis = {
128
- 'total_layers': len(model.layers),
129
- 'conv_layers': [],
130
- 'dense_layers': [],
131
- 'other_layers': [],
132
- 'all_layers_detailed': [],
133
- 'model_type': 'Unknown'
134
- }
135
 
136
- for i, layer in enumerate(model.layers):
137
- layer_type = type(layer).__name__
138
-
139
- # Get more detailed layer information
140
- layer_info = {
141
- 'index': i,
142
- 'name': layer.name,
143
- 'type': layer_type,
144
- 'output_shape': getattr(layer, 'output_shape', 'Unknown'),
145
- 'trainable': getattr(layer, 'trainable', 'Unknown'),
146
- 'activation': getattr(layer, 'activation', None)
 
 
 
 
 
 
 
 
 
 
 
147
  }
148
-
149
- # Try to get activation function name
150
- if hasattr(layer, 'activation') and layer.activation:
151
- try:
152
- layer_info['activation'] = layer.activation.__name__
153
- except:
154
- layer_info['activation'] = str(layer.activation)
155
-
156
- layer_analysis['all_layers_detailed'].append(layer_info)
157
-
158
- # Categorize layers with more comprehensive detection
159
- if any(conv_type in layer_type for conv_type in [
160
- 'Conv1D', 'Conv2D', 'Conv3D', 'SeparableConv2D', 'DepthwiseConv2D',
161
- 'Convolution1D', 'Convolution2D', 'Convolution3D'
162
- ]) or 'conv' in layer.name.lower():
163
- layer_analysis['conv_layers'].append(layer_info)
164
-
165
- elif 'Dense' in layer_type or 'Linear' in layer_type:
166
- layer_analysis['dense_layers'].append(layer_info)
167
-
168
- else:
169
- layer_analysis['other_layers'].append(layer_info)
170
 
171
- # Determine model type
172
- if layer_analysis['conv_layers']:
173
- layer_analysis['model_type'] = 'CNN (Convolutional Neural Network)'
174
- elif layer_analysis['dense_layers']:
175
- layer_analysis['model_type'] = 'MLP (Multi-Layer Perceptron)'
 
 
176
  else:
177
- layer_analysis['model_type'] = 'Custom Architecture'
178
 
179
- return layer_analysis
180
 
181
- def debug_gradcam_step_by_step(img_array, model, layer_name, pred_index):
182
- """Debug Grad-CAM computation step by step."""
183
- debug_info = {
184
- 'step': 'Starting',
185
- 'error': None,
186
- 'layer_output_shape': None,
187
- 'gradients_shape': None,
188
- 'gradients_stats': None,
189
- 'heatmap_stats': None
190
- }
191
 
192
- try:
193
- debug_info['step'] = 'Getting target layer'
194
- target_layer = model.get_layer(layer_name)
195
- debug_info['target_layer_type'] = type(target_layer).__name__
 
 
 
 
 
196
 
197
- debug_info['step'] = 'Creating grad model'
198
- grad_model = tf.keras.Model(
199
- inputs=[model.inputs],
200
- outputs=[target_layer.output, model.output]
201
- )
202
 
203
- debug_info['step'] = 'Computing forward pass'
204
- with tf.GradientTape() as tape:
205
- layer_output, preds = grad_model(img_array)
206
- debug_info['layer_output_shape'] = layer_output.shape.as_list()
207
- debug_info['predictions_shape'] = preds.shape.as_list()
208
-
209
- if pred_index is None:
210
- pred_index = tf.argmax(preds[0])
211
- debug_info['pred_index'] = int(pred_index)
212
- debug_info['pred_confidence'] = float(preds[0][pred_index])
213
-
214
- class_channel = preds[:, pred_index]
215
- debug_info['class_channel_shape'] = class_channel.shape.as_list()
216
 
217
- debug_info['step'] = 'Computing gradients'
218
- grads = tape.gradient(class_channel, layer_output)
219
 
220
- if grads is None:
221
- debug_info['error'] = "Gradients are None - no backpropagation path"
222
- return None, debug_info
223
 
224
- debug_info['gradients_shape'] = grads.shape.as_list()
225
- debug_info['gradients_stats'] = {
226
- 'min': float(tf.reduce_min(grads)),
227
- 'max': float(tf.reduce_max(grads)),
228
- 'mean': float(tf.reduce_mean(grads)),
229
- 'std': float(tf.math.reduce_std(grads))
230
- }
231
 
232
- debug_info['step'] = 'Processing gradients based on layer type'
 
 
 
233
 
234
- if len(layer_output.shape) == 4: # Conv layer
235
- debug_info['processing_type'] = 'Convolutional layer (4D)'
236
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
237
- layer_output = layer_output[0]
238
- heatmap = layer_output @ pooled_grads[..., tf.newaxis]
239
- heatmap = tf.squeeze(heatmap)
240
-
241
- elif len(layer_output.shape) == 2: # Dense layer
242
- debug_info['processing_type'] = 'Dense layer (2D)'
243
- # For dense layers, create spatial heatmap from gradient magnitude
244
- grads_magnitude = tf.reduce_mean(tf.abs(grads))
245
- # Create a simple spatial pattern
246
- heatmap = tf.ones((14, 14)) * grads_magnitude
247
-
248
- else:
249
- debug_info['error'] = f"Unsupported layer shape: {layer_output.shape}"
250
- return None, debug_info
251
 
252
- debug_info['step'] = 'Normalizing heatmap'
253
- debug_info['raw_heatmap_stats'] = {
254
- 'min': float(tf.reduce_min(heatmap)),
255
- 'max': float(tf.reduce_max(heatmap)),
256
- 'mean': float(tf.reduce_mean(heatmap)),
257
- 'std': float(tf.math.reduce_std(heatmap))
258
- }
259
 
260
- # Apply ReLU (remove negative values)
261
- heatmap = tf.maximum(heatmap, 0)
 
 
262
 
263
  # Normalize
264
- heatmap_max = tf.reduce_max(heatmap)
265
- if heatmap_max > 0:
266
- heatmap = heatmap / heatmap_max
267
- else:
268
- debug_info['error'] = "All heatmap values are zero or negative"
269
- return None, debug_info
270
 
271
- debug_info['final_heatmap_stats'] = {
272
- 'min': float(tf.reduce_min(heatmap)),
273
- 'max': float(tf.reduce_max(heatmap)),
274
- 'mean': float(tf.reduce_mean(heatmap)),
275
- 'std': float(tf.math.reduce_std(heatmap))
276
- }
277
-
278
- debug_info['step'] = 'Complete'
279
- return heatmap.numpy(), debug_info
280
-
281
- except Exception as e:
282
- debug_info['error'] = f"Exception in step '{debug_info['step']}': {str(e)}"
283
- return None, debug_info
284
 
285
- def create_robust_gradcam_heatmap(img, model, predictions):
286
- """Create Grad-CAM with comprehensive debugging."""
287
- try:
288
- # Preprocess image
289
- img_resized = img.resize((224, 224))
290
- img_array = np.array(img_resized, dtype=np.float32)
291
-
292
- # Handle grayscale
293
- if len(img_array.shape) == 2:
294
- img_array = np.stack([img_array] * 3, axis=-1)
295
-
296
- # Normalize and add batch dimension
297
- img_array = np.expand_dims(img_array, axis=0) / 255.0
298
-
299
- # Get model analysis
300
- analysis = analyze_model_architecture(model)
301
-
302
- # Try different layers in order of preference
303
- layer_candidates = []
304
-
305
- # Add conv layers first
306
- for layer in analysis['conv_layers']:
307
- layer_candidates.append((layer['name'], f"Conv layer: {layer['name']}"))
308
-
309
- # Add other potentially suitable layers
310
- for layer in analysis['all_layers_detailed']:
311
- if (layer['type'] in ['Activation', 'BatchNormalization'] and
312
- isinstance(layer['output_shape'], (list, tuple)) and
313
- len(layer['output_shape']) == 4):
314
- layer_candidates.append((layer['name'], f"4D layer: {layer['name']} ({layer['type']})"))
315
-
316
- # Try dense layers as last resort
317
- if not layer_candidates:
318
- for layer in analysis['dense_layers']:
319
- layer_candidates.append((layer['name'], f"Dense layer: {layer['name']} (experimental)"))
320
-
321
- if not layer_candidates:
322
- return None, "❌ No suitable layers found", None
323
-
324
- # Try each candidate layer
325
- for layer_name, layer_desc in layer_candidates:
326
- pred_index = np.argmax(predictions)
327
-
328
- heatmap, debug_info = debug_gradcam_step_by_step(
329
- img_array, model, layer_name, pred_index
330
- )
331
-
332
- if heatmap is not None:
333
- # Resize heatmap to match input image size
334
- if heatmap.shape[0] != 224 or heatmap.shape[1] != 224:
335
- heatmap_resized = tf.image.resize(
336
- heatmap[..., tf.newaxis],
337
- (224, 224)
338
- ).numpy()[:, :, 0]
339
- else:
340
- heatmap_resized = heatmap
341
-
342
- # Final statistics
343
- stats = {
344
- 'min': float(np.min(heatmap_resized)),
345
- 'max': float(np.max(heatmap_resized)),
346
- 'mean': float(np.mean(heatmap_resized)),
347
- 'std': float(np.std(heatmap_resized))
348
- }
349
-
350
- return heatmap_resized, f"βœ… Grad-CAM successful using {layer_desc}", stats, debug_info
351
- else:
352
- # Continue to next layer if this one failed
353
- continue
354
-
355
- # If all layers failed, return debug info from the last attempt
356
- return None, f"❌ All layers failed. Last error: {debug_info.get('error', 'Unknown')}", None, debug_info
357
-
358
- except Exception as e:
359
- return None, f"❌ Grad-CAM error: {str(e)}", None, {'error': str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
  def predict_stroke(img, model):
362
  """Predict stroke type from image."""
@@ -383,134 +355,40 @@ def predict_stroke(img, model):
383
  except Exception as e:
384
  return None, f"Prediction error: {str(e)}"
385
 
386
- def create_enhanced_simulated_heatmap(img, predictions):
387
- """Create a more realistic simulated heatmap."""
388
- try:
389
- confidence = np.max(predictions)
390
- predicted_class = np.argmax(predictions)
391
-
392
- # Create different patterns based on predicted class
393
- if predicted_class == 0: # Hemorrhagic
394
- # Focus on center-left region
395
- center_x, center_y = 80, 112
396
- elif predicted_class == 1: # Ischemic
397
- # Focus on right side
398
- center_x, center_y = 150, 112
399
- else: # No stroke
400
- # Diffuse, low-intensity pattern
401
- center_x, center_y = 112, 112
402
-
403
- # Create base pattern
 
404
  y, x = np.ogrid[:224, :224]
405
-
406
- # Primary focus area
407
- mask1 = np.exp(-((x - center_x)**2 + (y - center_y)**2) / (2 * (40**2)))
408
-
409
- # Secondary areas
410
- mask2 = np.exp(-((x - center_x + 30)**2 + (y - center_y + 20)**2) / (2 * (25**2)))
411
- mask3 = np.exp(-((x - center_x - 20)**2 + (y - center_y - 30)**2) / (2 * (30**2)))
412
-
413
- # Combine patterns
414
- heatmap = (mask1 * 0.8 + mask2 * 0.4 + mask3 * 0.3) * confidence
415
-
416
- # Add some noise for realism
417
- np.random.seed(42)
418
- noise = np.random.normal(0, 0.05, heatmap.shape)
419
- heatmap = np.maximum(heatmap + noise, 0)
420
-
421
- # Normalize
422
- if np.max(heatmap) > 0:
423
- heatmap = heatmap / np.max(heatmap)
424
-
425
- stats = {
426
- 'min': float(np.min(heatmap)),
427
- 'max': float(np.max(heatmap)),
428
- 'mean': float(np.mean(heatmap)),
429
- 'std': float(np.std(heatmap))
430
- }
431
-
432
- return heatmap, "⚠️ Using enhanced simulated heatmap", stats
433
- except Exception as e:
434
- return None, f"❌ Simulated heatmap error: {str(e)}", None
435
-
436
- def create_comprehensive_visualization(img, predictions, model, force_gradcam=True, colormap='hot'):
437
- """Create comprehensive visualization with debugging."""
438
- if not MPL_AVAILABLE:
439
- return None, "❌ Matplotlib not available"
440
 
441
- try:
442
- # Resize image to 224x224
443
- img_resized = img.resize((224, 224))
444
- img_array = np.array(img_resized)
445
-
446
- heatmap = None
447
- status_message = ""
448
- stats = None
449
- debug_info = None
450
-
451
- # Try Grad-CAM first
452
- if force_gradcam and model is not None:
453
- result = create_robust_gradcam_heatmap(img, model, predictions)
454
- if result and len(result) >= 3:
455
- heatmap, gradcam_status, stats = result[0], result[1], result[2]
456
- if len(result) > 3:
457
- debug_info = result[3]
458
- status_message = gradcam_status
459
-
460
- # Fallback to enhanced simulated if Grad-CAM failed
461
- if heatmap is None:
462
- result = create_enhanced_simulated_heatmap(img, predictions)
463
- if result and len(result) == 3:
464
- heatmap, sim_status, stats = result
465
- if status_message:
466
- status_message += f" | {sim_status}"
467
- else:
468
- status_message = sim_status
469
-
470
- if heatmap is None:
471
- return None, "❌ Could not generate any heatmap", None, None
472
-
473
- # Create visualization
474
- fig, axes = plt.subplots(1, 3, figsize=(15, 5))
475
-
476
- # 1. Original image
477
- axes[0].imshow(img_array)
478
- axes[0].set_title("Original Image", fontsize=12, fontweight='bold')
479
- axes[0].axis('off')
480
-
481
- # 2. Heatmap only
482
- im1 = axes[1].imshow(heatmap, cmap=colormap, vmin=0, vmax=1)
483
- axes[1].set_title(f"Attention Heatmap ({colormap})", fontsize=12, fontweight='bold')
484
- axes[1].axis('off')
485
- plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
486
-
487
- # 3. Overlay
488
- axes[2].imshow(img_array)
489
- im2 = axes[2].imshow(heatmap, cmap=colormap, alpha=0.6, vmin=0, vmax=1, interpolation='bilinear')
490
-
491
- # Determine title based on success
492
- if "βœ… Grad-CAM successful" in status_message:
493
- title = "🎯 Real AI Attention Overlay"
494
- title_color = 'green'
495
- else:
496
- title = "🎨 Simulated Attention Overlay"
497
- title_color = 'orange'
498
-
499
- axes[2].set_title(title, fontsize=12, fontweight='bold', color=title_color)
500
- axes[2].axis('off')
501
- plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)
502
-
503
- plt.tight_layout()
504
-
505
- return fig, status_message, stats, debug_info
506
 
507
- except Exception as e:
508
- return None, f"❌ Visualization error: {str(e)}", None, None
509
 
510
  # Main App
511
  def main():
512
  # Header
513
- st.markdown('<h1 class="main-header">🧠 AI-Powered Stroke Classification System</h1>', unsafe_allow_html=True)
514
 
515
  # Auto-load model on startup
516
  if not st.session_state.model_loaded:
@@ -541,31 +419,40 @@ def main():
541
  else:
542
  st.markdown('<div class="status-box error">❌ Model Error</div>', unsafe_allow_html=True)
543
 
544
- # Model status details
545
- st.markdown(f'<div class="status-box info"><strong>Model Status:</strong> {st.session_state.model_status}</div>', unsafe_allow_html=True)
546
-
547
- # Enhanced model architecture analysis
548
- if st.session_state.model is not None:
549
- with st.expander("πŸ” Detailed Model Architecture Analysis"):
550
- analysis = analyze_model_architecture(st.session_state.model)
 
 
 
 
 
 
 
 
 
 
551
 
552
- st.write("**πŸ“Š Model Summary:**")
553
- st.write(f"- **Model Type:** {analysis['model_type']}")
554
- st.write(f"- **Total Layers:** {analysis['total_layers']}")
555
- st.write(f"- **Convolutional Layers:** {len(analysis['conv_layers'])}")
556
- st.write(f"- **Dense Layers:** {len(analysis['dense_layers'])}")
557
- st.write(f"- **Other Layers:** {len(analysis['other_layers'])}")
 
 
 
 
558
 
559
- # Show detailed layer information
560
- st.write("**πŸ” All Layers (Detailed):**")
561
- for layer in analysis['all_layers_detailed']:
562
- activation_info = f" | Activation: {layer['activation']}" if layer['activation'] else ""
563
- st.code(f"{layer['index']:2d}: {layer['name']} ({layer['type']}) | Shape: {layer['output_shape']}{activation_info}")
564
-
565
- # Manual reload button
566
- if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
567
- st.session_state.model_loaded = False
568
- st.rerun()
569
 
570
  # Sidebar
571
  with st.sidebar:
@@ -577,144 +464,164 @@ def main():
577
  )
578
 
579
  st.markdown("---")
580
- st.header("🎨 Visualization Options")
581
 
582
- force_gradcam = st.checkbox(
583
- "Attempt Grad-CAM",
584
- value=True,
585
- help="Try Grad-CAM with comprehensive debugging"
 
586
  )
587
 
588
- colormap = st.selectbox(
589
- "Color Scheme",
590
- ['hot', 'jet', 'viridis', 'plasma', 'inferno', 'magma', 'coolwarm'],
591
- index=0,
592
- help="Choose color scheme for heatmap visualization"
593
- )
594
-
595
- show_probabilities = st.checkbox("Show All Probabilities", value=True)
596
- show_debug = st.checkbox("Show Debug Info", value=True)
597
- show_stats = st.checkbox("Show Heatmap Statistics", value=True)
598
- show_detailed_debug = st.checkbox("Show Detailed Debug Info", value=False)
599
 
600
  if uploaded_file is not None:
601
  # Load image
602
  image = Image.open(uploaded_file)
603
 
604
- # Main content area
605
- col1, col2 = st.columns([1, 2])
606
 
607
- with col1:
608
- st.subheader("πŸ“‹ Classification Results")
 
 
609
 
610
- if st.session_state.model is not None:
611
- # Predict
612
- with st.spinner("πŸ” Analyzing brain scan..."):
613
- predictions, error = predict_stroke(image, st.session_state.model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
614
 
615
- if error:
616
- st.error(error)
617
- else:
618
- # Get top prediction
619
- class_idx = np.argmax(predictions)
620
- confidence = predictions[class_idx] * 100
621
- predicted_class = STROKE_LABELS[class_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
 
623
- # Display main result
624
- st.markdown(f"""
625
- <div class="prediction-box">
626
- <h2>{predicted_class}</h2>
627
- <h3>Confidence: {confidence:.1f}%</h3>
628
- </div>
629
- """, unsafe_allow_html=True)
630
 
631
- # Show all probabilities
632
- if show_probabilities:
633
- st.write("**πŸ“Š All Probabilities:**")
634
- for i, (label, prob) in enumerate(zip(STROKE_LABELS, predictions)):
635
- st.write(f"β€’ {label}: {prob*100:.1f}%")
636
- else:
637
- st.error("❌ Model not loaded. Check the debug information above to see available files.")
638
-
639
- with col2:
640
- st.subheader("🎯 Comprehensive AI Attention Visualization")
641
-
642
- if st.session_state.model is not None and 'predictions' in locals() and predictions is not None:
643
- # Create comprehensive visualization
644
- with st.spinner("🎨 Generating comprehensive attention visualization..."):
645
- result = create_comprehensive_visualization(
646
- image,
647
- predictions,
648
- st.session_state.model,
649
- force_gradcam,
650
- colormap
651
- )
652
-
653
- if result and len(result) >= 2:
654
- overlay_fig, status_message = result[0], result[1]
655
- stats = result[2] if len(result) > 2 else None
656
- debug_info = result[3] if len(result) > 3 else None
657
 
658
- if overlay_fig is not None:
659
- st.pyplot(overlay_fig)
660
- plt.close()
661
-
662
- # Show detailed status
663
- if show_debug:
664
- if "βœ… Grad-CAM successful" in status_message:
665
- st.success(f"βœ… {status_message}")
666
- elif "⚠️" in status_message:
667
- st.warning(f"⚠️ {status_message}")
668
- else:
669
- st.error(f"❌ {status_message}")
670
-
671
- # Show heatmap statistics
672
- if show_stats and stats:
673
- st.write("**πŸ“ˆ Heatmap Statistics:**")
674
- if any(np.isnan([stats['min'], stats['max'], stats['mean'], stats['std']])):
675
- st.error("⚠️ NaN values detected in heatmap - this indicates a computation error")
676
- else:
677
- col_stats1, col_stats2 = st.columns(2)
678
- with col_stats1:
679
- st.write(f"β€’ Min: {stats['min']:.3f}")
680
- st.write(f"β€’ Max: {stats['max']:.3f}")
681
- with col_stats2:
682
- st.write(f"β€’ Mean: {stats['mean']:.3f}")
683
- st.write(f"β€’ Std: {stats['std']:.3f}")
684
-
685
- # Show detailed debug information
686
- if show_detailed_debug and debug_info:
687
- with st.expander("πŸ”§ Detailed Debug Information"):
688
- st.json(debug_info)
689
- else:
690
- st.error(f"Could not generate visualization: {status_message}")
691
- if debug_info:
692
- st.error(f"Debug info: {debug_info.get('error', 'No additional info')}")
693
- else:
694
- st.error("Could not generate attention visualization")
695
- else:
696
- st.info("Upload an image and run classification to see AI attention visualization")
 
697
 
698
  else:
699
  # Welcome message
700
  st.markdown("""
701
- ## πŸ‘‹ Welcome to the Comprehensive Stroke Classification System
 
 
702
 
703
- This system now includes **step-by-step debugging** to identify why Grad-CAM might be failing.
 
 
 
 
704
 
705
- ### πŸ”§ New Debugging Features:
706
- - **Step-by-step Grad-CAM debugging** - See exactly where it fails
707
- - **Multiple layer attempts** - Tries different layers automatically
708
- - **Enhanced error messages** - Clear explanations of what went wrong
709
- - **NaN detection** - Identifies computation errors
710
 
711
- ### 🎯 What to Look For:
712
- - **Green success messages** - Grad-CAM is working
713
- - **Orange warnings** - Using fallback methods
714
- - **Red errors** - Something is broken
715
- - **NaN statistics** - Computation failure
716
 
717
- **Upload an image to see detailed debugging! πŸ‘ˆ**
718
  """)
719
 
720
  # Medical disclaimer
 
3
  import os
4
  import sys
5
  from PIL import Image
6
+ from scipy import ndimage
7
+ import matplotlib.pyplot as plt
8
+ from mpl_toolkits.mplot3d import Axes3D
9
 
10
  # Set environment variables to fix permission issues
11
  os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
 
22
  try:
23
  import matplotlib
24
  matplotlib.use('Agg') # Use non-interactive backend
 
25
  import matplotlib.cm as cm
26
  MPL_AVAILABLE = True
27
  except ImportError:
 
121
  except Exception as e:
122
  return None, f"❌ Model loading failed: {str(e)}"
123
 
124
+ def analyze_heatmap_distribution(heatmap, name="Heatmap"):
125
+ """Analyze the distribution of heatmap values."""
126
+ if heatmap is None:
127
+ return None
 
 
 
 
 
 
 
 
 
128
 
129
+ flat_values = heatmap.flatten()
130
+
131
+ analysis = {
132
+ 'name': name,
133
+ 'shape': heatmap.shape,
134
+ 'total_pixels': heatmap.size,
135
+ 'min': float(np.min(flat_values)),
136
+ 'max': float(np.max(flat_values)),
137
+ 'mean': float(np.mean(flat_values)),
138
+ 'median': float(np.median(flat_values)),
139
+ 'std': float(np.std(flat_values)),
140
+ 'range': float(np.max(flat_values) - np.min(flat_values)),
141
+ 'unique_values': len(np.unique(flat_values)),
142
+ 'zero_pixels': int(np.sum(flat_values == 0)),
143
+ 'non_zero_pixels': int(np.sum(flat_values > 0)),
144
+ 'percentiles': {
145
+ '1%': float(np.percentile(flat_values, 1)),
146
+ '5%': float(np.percentile(flat_values, 5)),
147
+ '25%': float(np.percentile(flat_values, 25)),
148
+ '75%': float(np.percentile(flat_values, 75)),
149
+ '95%': float(np.percentile(flat_values, 95)),
150
+ '99%': float(np.percentile(flat_values, 99))
151
  }
152
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ # Determine if heatmap has good contrast
155
+ if analysis['range'] < 0.1:
156
+ analysis['contrast_quality'] = 'Very Poor (range < 0.1)'
157
+ elif analysis['range'] < 0.3:
158
+ analysis['contrast_quality'] = 'Poor (range < 0.3)'
159
+ elif analysis['range'] < 0.7:
160
+ analysis['contrast_quality'] = 'Moderate (range < 0.7)'
161
  else:
162
+ analysis['contrast_quality'] = 'Good (range >= 0.7)'
163
 
164
+ return analysis
165
 
166
+ def force_contrast_enhancement(heatmap, method='aggressive'):
167
+ """Force better contrast in heatmap using various methods."""
168
+ if heatmap is None:
169
+ return None, "No heatmap provided"
 
 
 
 
 
 
170
 
171
+ original_analysis = analyze_heatmap_distribution(heatmap, "Original")
172
+
173
+ if method == 'aggressive':
174
+ # Method 1: Aggressive percentile stretching
175
+ p1, p99 = np.percentile(heatmap, [1, 99])
176
+ if p99 > p1:
177
+ enhanced = np.clip((heatmap - p1) / (p99 - p1), 0, 1)
178
+ else:
179
+ enhanced = heatmap
180
 
181
+ # Apply power transformation to spread values
182
+ enhanced = np.power(enhanced, 0.3) # Gamma < 1 spreads values
 
 
 
183
 
184
+ elif method == 'histogram_eq':
185
+ # Method 2: Histogram equalization
186
+ flat = heatmap.flatten()
187
+ hist, bins = np.histogram(flat, bins=256, range=(0, 1))
188
+ cdf = hist.cumsum()
189
+ cdf = cdf / cdf[-1] # Normalize
 
 
 
 
 
 
 
190
 
191
+ # Interpolate to get new values
192
+ enhanced = np.interp(flat, bins[:-1], cdf).reshape(heatmap.shape)
193
 
194
+ elif method == 'adaptive':
195
+ # Method 3: Adaptive enhancement based on local statistics
 
196
 
197
+ # Local mean and std
198
+ local_mean = ndimage.uniform_filter(heatmap, size=20)
199
+ local_std = ndimage.generic_filter(heatmap, np.std, size=20)
 
 
 
 
200
 
201
+ # Enhance based on local statistics
202
+ enhanced = (heatmap - local_mean) / (local_std + 1e-8)
203
+ enhanced = np.clip(enhanced, -3, 3) # Clip outliers
204
+ enhanced = (enhanced + 3) / 6 # Normalize to [0, 1]
205
 
206
+ elif method == 'artificial_peaks':
207
+ # Method 4: Create artificial peaks for visualization
208
+ enhanced = heatmap.copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
+ # Find top 10% of values and enhance them
211
+ threshold = np.percentile(enhanced, 90)
212
+ mask = enhanced >= threshold
213
+ enhanced[mask] = enhanced[mask] * 2
 
 
 
214
 
215
+ # Find bottom 10% and suppress them
216
+ threshold_low = np.percentile(enhanced, 10)
217
+ mask_low = enhanced <= threshold_low
218
+ enhanced[mask_low] = enhanced[mask_low] * 0.1
219
 
220
  # Normalize
221
+ enhanced = np.clip(enhanced, 0, 1)
 
 
 
 
 
222
 
223
+ else:
224
+ enhanced = heatmap
225
+
226
+ enhanced_analysis = analyze_heatmap_distribution(enhanced, f"Enhanced ({method})")
227
+
228
+ return enhanced, f"Enhanced using {method}", original_analysis, enhanced_analysis
 
 
 
 
 
 
 
229
 
230
+ def create_diagnostic_heatmap_visualization(heatmap, title="Heatmap Analysis"):
231
+ """Create a comprehensive diagnostic visualization of the heatmap."""
232
+ if not MPL_AVAILABLE or heatmap is None:
233
+ return None
234
+
235
+ fig, axes = plt.subplots(2, 3, figsize=(18, 12))
236
+
237
+ # Original heatmap
238
+ im1 = axes[0, 0].imshow(heatmap, cmap='hot', vmin=0, vmax=1)
239
+ axes[0, 0].set_title(f"{title} - Hot Colormap")
240
+ plt.colorbar(im1, ax=axes[0, 0])
241
+
242
+ # Different colormap
243
+ im2 = axes[0, 1].imshow(heatmap, cmap='viridis', vmin=0, vmax=1)
244
+ axes[0, 1].set_title(f"{title} - Viridis Colormap")
245
+ plt.colorbar(im2, ax=axes[0, 1])
246
+
247
+ # High contrast version
248
+ im3 = axes[0, 2].imshow(heatmap, cmap='RdYlBu_r', vmin=np.min(heatmap), vmax=np.max(heatmap))
249
+ axes[0, 2].set_title(f"{title} - Auto-scaled")
250
+ plt.colorbar(im3, ax=axes[0, 2])
251
+
252
+ # Histogram
253
+ axes[1, 0].hist(heatmap.flatten(), bins=50, alpha=0.7, color='blue')
254
+ axes[1, 0].set_title("Value Distribution")
255
+ axes[1, 0].set_xlabel("Heatmap Value")
256
+ axes[1, 0].set_ylabel("Frequency")
257
+
258
+ # 3D surface plot
259
+ x = np.arange(heatmap.shape[1])
260
+ y = np.arange(heatmap.shape[0])
261
+ X, Y = np.meshgrid(x, y)
262
+
263
+ ax_3d = fig.add_subplot(2, 3, 5, projection='3d')
264
+ surf = ax_3d.plot_surface(X[::8, ::8], Y[::8, ::8], heatmap[::8, ::8],
265
+ cmap='hot', alpha=0.8)
266
+ ax_3d.set_title("3D Surface View")
267
+
268
+ # Statistics text
269
+ analysis = analyze_heatmap_distribution(heatmap)
270
+ stats_text = f"""
271
+ Shape: {analysis['shape']}
272
+ Range: {analysis['range']:.4f}
273
+ Mean: {analysis['mean']:.4f}
274
+ Std: {analysis['std']:.4f}
275
+ Unique values: {analysis['unique_values']}
276
+ Contrast: {analysis['contrast_quality']}
277
+
278
+ Percentiles:
279
+ 1%: {analysis['percentiles']['1%']:.4f}
280
+ 25%: {analysis['percentiles']['25%']:.4f}
281
+ 75%: {analysis['percentiles']['75%']:.4f}
282
+ 99%: {analysis['percentiles']['99%']:.4f}
283
+ """
284
+
285
+ axes[1, 2].text(0.1, 0.9, stats_text, transform=axes[1, 2].transAxes,
286
+ fontsize=10, verticalalignment='top', fontfamily='monospace')
287
+ axes[1, 2].set_title("Statistics")
288
+ axes[1, 2].axis('off')
289
+
290
+ plt.tight_layout()
291
+ return fig
292
+
293
+ def create_multiple_enhancement_comparison(heatmap):
294
+ """Compare different enhancement methods side by side."""
295
+ if not MPL_AVAILABLE or heatmap is None:
296
+ return None
297
+
298
+ methods = ['aggressive', 'histogram_eq', 'adaptive', 'artificial_peaks']
299
+ enhanced_maps = {}
300
+
301
+ for method in methods:
302
+ enhanced, _, _, _ = force_contrast_enhancement(heatmap, method)
303
+ enhanced_maps[method] = enhanced
304
+
305
+ fig, axes = plt.subplots(2, 3, figsize=(18, 12))
306
+
307
+ # Original
308
+ im0 = axes[0, 0].imshow(heatmap, cmap='hot', vmin=0, vmax=1)
309
+ axes[0, 0].set_title("Original Heatmap")
310
+ plt.colorbar(im0, ax=axes[0, 0])
311
+
312
+ # Enhanced versions
313
+ positions = [(0, 1), (0, 2), (1, 0), (1, 1)]
314
+
315
+ for i, (method, enhanced) in enumerate(enhanced_maps.items()):
316
+ row, col = positions[i]
317
+ im = axes[row, col].imshow(enhanced, cmap='hot', vmin=0, vmax=1)
318
+ axes[row, col].set_title(f"Enhanced: {method}")
319
+ plt.colorbar(im, ax=axes[row, col])
320
+
321
+ # Comparison histogram
322
+ axes[1, 2].hist(heatmap.flatten(), bins=30, alpha=0.5, label='Original', color='blue')
323
+ for method, enhanced in enhanced_maps.items():
324
+ axes[1, 2].hist(enhanced.flatten(), bins=30, alpha=0.3, label=method)
325
+ axes[1, 2].set_title("Value Distributions")
326
+ axes[1, 2].legend()
327
+ axes[1, 2].set_xlabel("Value")
328
+ axes[1, 2].set_ylabel("Frequency")
329
+
330
+ plt.tight_layout()
331
+ return fig
332
 
333
  def predict_stroke(img, model):
334
  """Predict stroke type from image."""
 
355
  except Exception as e:
356
  return None, f"Prediction error: {str(e)}"
357
 
358
+ def create_test_heatmaps():
359
+ """Create test heatmaps with known patterns for comparison."""
360
+ test_maps = {}
361
+
362
+ # Test 1: High contrast pattern
363
+ test_maps['high_contrast'] = np.zeros((224, 224))
364
+ test_maps['high_contrast'][50:150, 50:150] = 1.0
365
+ test_maps['high_contrast'][75:125, 75:125] = 0.0
366
+
367
+ # Test 2: Gradient pattern
368
+ x = np.linspace(0, 1, 224)
369
+ y = np.linspace(0, 1, 224)
370
+ X, Y = np.meshgrid(x, y)
371
+ test_maps['gradient'] = X * Y
372
+
373
+ # Test 3: Gaussian blobs
374
+ test_maps['gaussian'] = np.zeros((224, 224))
375
+ centers = [(60, 60), (160, 160), (60, 160)]
376
+ for cx, cy in centers:
377
  y, x = np.ogrid[:224, :224]
378
+ mask = np.exp(-((x - cx)**2 + (y - cy)**2) / (2 * 30**2))
379
+ test_maps['gaussian'] += mask
380
+ test_maps['gaussian'] = test_maps['gaussian'] / np.max(test_maps['gaussian'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
+ # Test 4: Low contrast (similar to your issue)
383
+ test_maps['low_contrast'] = np.random.normal(0.5, 0.05, (224, 224))
384
+ test_maps['low_contrast'] = np.clip(test_maps['low_contrast'], 0, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
+ return test_maps
 
387
 
388
  # Main App
389
  def main():
390
  # Header
391
+ st.markdown('<h1 class="main-header">🧠 Heatmap Diagnostic System</h1>', unsafe_allow_html=True)
392
 
393
  # Auto-load model on startup
394
  if not st.session_state.model_loaded:
 
419
  else:
420
  st.markdown('<div class="status-box error">❌ Model Error</div>', unsafe_allow_html=True)
421
 
422
+ # Test heatmaps section
423
+ st.markdown("### πŸ§ͺ Test Heatmap Patterns")
424
+
425
+ test_maps = create_test_heatmaps()
426
+
427
+ col1, col2 = st.columns(2)
428
+
429
+ with col1:
430
+ st.write("**Test Pattern:**")
431
+ test_pattern = st.selectbox(
432
+ "Choose a test pattern",
433
+ list(test_maps.keys()),
434
+ help="Test different heatmap patterns to see how they display"
435
+ )
436
+
437
+ if test_pattern:
438
+ test_heatmap = test_maps[test_pattern]
439
 
440
+ # Show diagnostic visualization
441
+ diagnostic_fig = create_diagnostic_heatmap_visualization(test_heatmap, f"Test: {test_pattern}")
442
+ if diagnostic_fig:
443
+ st.pyplot(diagnostic_fig)
444
+ plt.close()
445
+
446
+ with col2:
447
+ st.write("**Enhancement Comparison:**")
448
+ if test_pattern:
449
+ test_heatmap = test_maps[test_pattern]
450
 
451
+ # Show enhancement comparison
452
+ comparison_fig = create_multiple_enhancement_comparison(test_heatmap)
453
+ if comparison_fig:
454
+ st.pyplot(comparison_fig)
455
+ plt.close()
 
 
 
 
 
456
 
457
  # Sidebar
458
  with st.sidebar:
 
464
  )
465
 
466
  st.markdown("---")
467
+ st.header("🎨 Enhancement Options")
468
 
469
+ enhancement_method = st.selectbox(
470
+ "Enhancement Method",
471
+ ['none', 'aggressive', 'histogram_eq', 'adaptive', 'artificial_peaks'],
472
+ index=1,
473
+ help="Choose how to enhance heatmap contrast"
474
  )
475
 
476
+ show_diagnostics = st.checkbox("Show Diagnostic Analysis", value=True)
477
+ show_comparisons = st.checkbox("Show Enhancement Comparisons", value=True)
 
 
 
 
 
 
 
 
 
478
 
479
  if uploaded_file is not None:
480
  # Load image
481
  image = Image.open(uploaded_file)
482
 
483
+ st.subheader("πŸ“‹ Classification Results")
 
484
 
485
+ if st.session_state.model is not None:
486
+ # Predict
487
+ with st.spinner("πŸ” Analyzing brain scan..."):
488
+ predictions, error = predict_stroke(image, st.session_state.model)
489
 
490
+ if error:
491
+ st.error(error)
492
+ else:
493
+ # Get top prediction
494
+ class_idx = np.argmax(predictions)
495
+ confidence = predictions[class_idx] * 100
496
+ predicted_class = STROKE_LABELS[class_idx]
497
+
498
+ # Display main result
499
+ st.markdown(f"""
500
+ <div class="prediction-box">
501
+ <h2>{predicted_class}</h2>
502
+ <h3>Confidence: {confidence:.1f}%</h3>
503
+ </div>
504
+ """, unsafe_allow_html=True)
505
+
506
+ # Create a simple test heatmap based on prediction
507
+ st.subheader("🎯 Simulated Attention Analysis")
508
+
509
+ # Create a realistic simulated heatmap
510
+ confidence_normalized = confidence / 100.0
511
+ predicted_class_idx = np.argmax(predictions)
512
+
513
+ # Create different patterns based on prediction
514
+ y, x = np.ogrid[:224, :224]
515
+ if predicted_class_idx == 0: # Hemorrhagic
516
+ center_x, center_y = 80, 112
517
+ elif predicted_class_idx == 1: # Ischemic
518
+ center_x, center_y = 150, 112
519
+ else: # No stroke
520
+ center_x, center_y = 112, 112
521
+
522
+ # Create base heatmap
523
+ heatmap = np.exp(-((x - center_x)**2 + (y - center_y)**2) / (2 * (40**2)))
524
+ heatmap = heatmap * confidence_normalized
525
+
526
+ # Add some realistic variation
527
+ np.random.seed(42)
528
+ noise = np.random.normal(0, 0.02, heatmap.shape)
529
+ heatmap = np.maximum(heatmap + noise, 0)
530
 
531
+ # Normalize
532
+ if np.max(heatmap) > 0:
533
+ heatmap = heatmap / np.max(heatmap)
534
+
535
+ # Show diagnostic analysis
536
+ if show_diagnostics:
537
+ st.write("**πŸ“Š Heatmap Diagnostic Analysis:**")
538
+ diagnostic_fig = create_diagnostic_heatmap_visualization(heatmap, "Your Model's Attention")
539
+ if diagnostic_fig:
540
+ st.pyplot(diagnostic_fig)
541
+ plt.close()
542
+
543
+ # Show enhancement comparisons
544
+ if show_comparisons:
545
+ st.write("**🎨 Enhancement Method Comparison:**")
546
+ comparison_fig = create_multiple_enhancement_comparison(heatmap)
547
+ if comparison_fig:
548
+ st.pyplot(comparison_fig)
549
+ plt.close()
550
+
551
+ # Apply selected enhancement
552
+ if enhancement_method != 'none':
553
+ enhanced_heatmap, enhancement_msg, orig_analysis, enh_analysis = force_contrast_enhancement(heatmap, enhancement_method)
554
 
555
+ st.write(f"**πŸ”§ Applied Enhancement: {enhancement_method}**")
 
 
 
 
 
 
556
 
557
+ # Show before/after comparison
558
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
 
560
+ # Original
561
+ im1 = axes[0].imshow(heatmap, cmap='hot', vmin=0, vmax=1)
562
+ axes[0].set_title("Original Heatmap")
563
+ axes[0].axis('off')
564
+ plt.colorbar(im1, ax=axes[0])
565
+
566
+ # Enhanced
567
+ im2 = axes[1].imshow(enhanced_heatmap, cmap='hot', vmin=0, vmax=1)
568
+ axes[1].set_title(f"Enhanced ({enhancement_method})")
569
+ axes[1].axis('off')
570
+ plt.colorbar(im2, ax=axes[1])
571
+
572
+ # Overlay on image
573
+ img_resized = image.resize((224, 224))
574
+ img_array = np.array(img_resized)
575
+ axes[2].imshow(img_array)
576
+ im3 = axes[2].imshow(enhanced_heatmap, cmap='hot', alpha=0.6, vmin=0, vmax=1)
577
+ axes[2].set_title("Enhanced Overlay")
578
+ axes[2].axis('off')
579
+ plt.colorbar(im3, ax=axes[2])
580
+
581
+ plt.tight_layout()
582
+ st.pyplot(fig)
583
+ plt.close()
584
+
585
+ # Show improvement statistics
586
+ col1, col2 = st.columns(2)
587
+ with col1:
588
+ st.write("**Original Stats:**")
589
+ st.write(f"Range: {orig_analysis['range']:.4f}")
590
+ st.write(f"Std: {orig_analysis['std']:.4f}")
591
+ st.write(f"Contrast: {orig_analysis['contrast_quality']}")
592
+
593
+ with col2:
594
+ st.write("**Enhanced Stats:**")
595
+ st.write(f"Range: {enh_analysis['range']:.4f}")
596
+ st.write(f"Std: {enh_analysis['std']:.4f}")
597
+ st.write(f"Contrast: {enh_analysis['contrast_quality']}")
598
+ else:
599
+ st.error("❌ Model not loaded.")
600
 
601
  else:
602
  # Welcome message
603
  st.markdown("""
604
+ ## πŸ‘‹ Welcome to the Heatmap Diagnostic System
605
+
606
+ This system helps you understand **why your heatmaps appear as one color** and how to fix it.
607
 
608
+ ### πŸ” What This Shows You:
609
+ - **Value distribution analysis** - See if your heatmap has variation
610
+ - **Multiple visualization methods** - Different ways to display the same data
611
+ - **Enhancement techniques** - Force better contrast and visibility
612
+ - **Test patterns** - Compare with known good patterns
613
 
614
+ ### 🎯 Common Issues:
615
+ - **Low variance** - All values are nearly the same
616
+ - **Poor normalization** - Values compressed into narrow range
617
+ - **Uniform attention** - Model doesn't focus on specific areas
 
618
 
619
+ ### πŸ› οΈ Solutions:
620
+ - **Aggressive enhancement** - Force contrast stretching
621
+ - **Histogram equalization** - Spread values evenly
622
+ - **Artificial peaks** - Enhance high-attention areas
 
623
 
624
+ **Try the test patterns above, then upload your image! πŸ‘†**
625
  """)
626
 
627
  # Medical disclaimer