"""Streamlit app for stroke classification and heatmap-contrast diagnostics.

Loads a Keras stroke classifier (when TensorFlow and a ``.h5`` model file are
available), classifies an uploaded brain-scan image, and renders a *simulated*
attention heatmap together with several contrast-enhancement methods and
diagnostic plots.
"""

import itertools
import os
import sys  # noqa: F401  -- currently unused; kept from the original imports

# matplotlib resolves its config/cache directory (MPLCONFIGDIR) when it is
# first imported, so this must be set before the matplotlib imports below.
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
# NOTE(review): Streamlit reads server options before it executes this script,
# so this assignment likely has no effect here; prefer setting it on the
# command line or in .streamlit/config.toml.
os.environ['STREAMLIT_SERVER_HEADLESS'] = 'true'

# The imports below intentionally follow the environment setup above.
import matplotlib

matplotlib.use('Agg')  # non-interactive backend; selected before pyplot loads

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
# Imported for its side effect of registering the '3d' projection
# (required on older matplotlib versions).
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from PIL import Image
from scipy import ndimage

try:
    import matplotlib.cm as cm  # noqa: F401
    MPL_AVAILABLE = True
except ImportError:
    MPL_AVAILABLE = False

# st.set_page_config must be the first Streamlit command the script issues,
# so it runs before the TensorFlow check below (which may call st.error).
st.set_page_config(
    page_title="Stroke Classifier",
    page_icon="🧠",
    layout="wide",
)

try:
    import tensorflow as tf
    TF_AVAILABLE = True
except ImportError:
    TF_AVAILABLE = False
    st.error("TensorFlow not available")

# Per-session state, initialised once and reused across script reruns.
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False
    st.session_state.model = None
    st.session_state.model_status = "Not loaded"

STROKE_LABELS = ["Hemorrhagic Stroke", "Ischemic Stroke", "No Stroke"]

MODEL_FILENAME = "stroke_classification_model.h5"
# (width, height) the classifier input and the overlay image are resized to.
_INPUT_SIZE = (224, 224)

_PERCENTILES = (1, 5, 25, 75, 95, 99)
# (exclusive upper bound on value range, label); first match wins.
_CONTRAST_BANDS = (
    (0.1, 'Very Poor (range < 0.1)'),
    (0.3, 'Poor (range < 0.3)'),
    (0.7, 'Moderate (range < 0.7)'),
)

_WELCOME_MARKDOWN = """
## 👋 Welcome to the Heatmap Diagnostic System

This system helps you understand **why your heatmaps appear as one color** and how to fix it.

### 🔍 What This Shows You:
- **Value distribution analysis** - See if your heatmap has variation
- **Multiple visualization methods** - Different ways to display the same data
- **Enhancement techniques** - Force better contrast and visibility
- **Test patterns** - Compare with known good patterns

### 🎯 Common Issues:
- **Low variance** - All values are nearly the same
- **Poor normalization** - Values compressed into narrow range
- **Uniform attention** - Model doesn't focus on specific areas

### 🛠️ Solutions:
- **Aggressive enhancement** - Force contrast stretching
- **Histogram equalization** - Spread values evenly
- **Artificial peaks** - Enhance high-attention areas

**Try the test patterns above, then upload your image! 👆**
"""

_DISCLAIMER = (
    "⚠️ **Medical Disclaimer:** This AI system is for educational and "
    "research purposes only. It should not be used for actual medical "
    "diagnosis. Always consult qualified healthcare professionals for "
    "medical decisions."
)


# ---------------------------------------------------------------------------
# Model discovery and loading
# ---------------------------------------------------------------------------

def _iter_files(root='.'):
    """Yield the path of every file under *root*, in ``os.walk`` order."""
    for dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            yield os.path.join(dirpath, filename)


def find_model_file():
    """Return the path of the model file, or ``None`` if none is found.

    The well-known locations are checked first.  Only if none of them exists
    is the working directory walked, returning the first ``.h5`` file found.
    """
    candidates = (
        MODEL_FILENAME,
        f"./{MODEL_FILENAME}",
        f"/app/{MODEL_FILENAME}",
        f"src/{MODEL_FILENAME}",
        os.path.join(os.getcwd(), MODEL_FILENAME),
    )
    for path in candidates:
        if os.path.exists(path):
            return path
    # Fallback: walk lazily so the search stops at the first hit instead of
    # traversing the entire tree up front.
    for path in _iter_files('.'):
        # exists() filters out broken symlinks that os.walk still lists.
        if path.endswith('.h5') and os.path.exists(path):
            return path
    return None


@st.cache_resource
def load_stroke_model():
    """Load the Keras model once per server process.

    Returns:
        ``(model, status)`` where ``model`` is ``None`` on failure and
        ``status`` is a user-facing message prefixed with ✅ or ❌.
    """
    if not TF_AVAILABLE:
        return None, "❌ TensorFlow not available"
    try:
        model_path = find_model_file()
        if model_path is None:
            # Show a few files to help debug deployment layouts.
            current_files = list(itertools.islice(_iter_files('.'), 10))
            return None, f"❌ Model file not found. Available files: {current_files}"
        st.info(f"Found model at: {model_path}")
        model = tf.keras.models.load_model(model_path, compile=False)
        return model, f"✅ Model loaded successfully from: {model_path}"
    except Exception as e:  # UI boundary: report any load failure as status
        return None, f"❌ Model loading failed: {e}"


# ---------------------------------------------------------------------------
# Heatmap statistics and enhancement
# ---------------------------------------------------------------------------

def analyze_heatmap_distribution(heatmap, name="Heatmap"):
    """Summarise the value distribution of *heatmap*.

    Args:
        heatmap: array-like of numeric values, or ``None``.
        name: label stored under the ``'name'`` key.

    Returns:
        ``None`` if *heatmap* is ``None``; otherwise a dict with shape, pixel
        counts, min/max/mean/median/std, value range, unique-value count,
        selected percentiles and a ``'contrast_quality'`` label derived from
        the value range.
    """
    if heatmap is None:
        return None
    heatmap = np.asarray(heatmap)
    flat_values = heatmap.ravel()
    lo, hi = np.min(flat_values), np.max(flat_values)
    value_range = float(hi - lo)
    pct_values = np.percentile(flat_values, _PERCENTILES)
    contrast = next(
        (label for limit, label in _CONTRAST_BANDS if value_range < limit),
        'Good (range >= 0.7)',
    )
    return {
        'name': name,
        'shape': heatmap.shape,
        'total_pixels': heatmap.size,
        'min': float(lo),
        'max': float(hi),
        'mean': float(np.mean(flat_values)),
        'median': float(np.median(flat_values)),
        'std': float(np.std(flat_values)),
        'range': value_range,
        'unique_values': len(np.unique(flat_values)),
        'zero_pixels': int(np.sum(flat_values == 0)),
        # Counts strictly positive values (negatives are not included).
        'non_zero_pixels': int(np.sum(flat_values > 0)),
        'percentiles': {f'{p}%': float(v) for p, v in zip(_PERCENTILES, pct_values)},
        'contrast_quality': contrast,
    }


def _enhance_aggressive(heatmap):
    """Stretch the 1st-99th percentile band to [0, 1], then apply gamma 0.3."""
    p1, p99 = np.percentile(heatmap, [1, 99])
    if p99 > p1:
        stretched = np.clip((heatmap - p1) / (p99 - p1), 0, 1)
    else:
        stretched = heatmap
    # Gamma < 1 lifts low values, spreading them across the colour scale.
    return np.power(stretched, 0.3)


def _enhance_histogram_eq(heatmap):
    """Histogram-equalise values in [0, 1] using a 256-bin CDF."""
    flat = heatmap.ravel()
    hist, bins = np.histogram(flat, bins=256, range=(0, 1))
    cdf = hist.cumsum()
    if cdf[-1] == 0:
        # No values fall inside [0, 1]: equalisation is undefined, so return
        # an unchanged copy rather than dividing by zero.
        return heatmap.copy()
    cdf = cdf / cdf[-1]
    return np.interp(flat, bins[:-1], cdf).reshape(heatmap.shape)


def _local_std(values, size):
    """Windowed population std, vectorised as sqrt(E[x^2] - E[x]^2).

    Uses the same window and 'reflect' boundary mode as
    ``ndimage.generic_filter(values, np.std, size=size)``, without a Python
    callback per pixel.
    """
    # Variance is shift-invariant; centring first reduces cancellation error.
    centred = values - np.mean(values)
    mean = ndimage.uniform_filter(centred, size=size)
    mean_sq = ndimage.uniform_filter(centred * centred, size=size)
    return np.sqrt(np.maximum(mean_sq - mean * mean, 0.0))


def _enhance_adaptive(heatmap, window=20):
    """Local z-score over a *window*-sized neighbourhood, mapped to [0, 1]."""
    local_mean = ndimage.uniform_filter(heatmap, size=window)
    local_std = _local_std(heatmap, window)
    z = (heatmap - local_mean) / (local_std + 1e-8)
    # Clip outliers to +/-3 sigma, then map [-3, 3] -> [0, 1].
    return (np.clip(z, -3, 3) + 3) / 6


def _enhance_artificial_peaks(heatmap):
    """Double the top 10% of values, scale the bottom 10% by 0.1, clip to [0, 1]."""
    enhanced = heatmap.copy()
    high = enhanced >= np.percentile(enhanced, 90)
    enhanced[high] *= 2
    # The low threshold is taken after boosting, matching the original order.
    low = enhanced <= np.percentile(enhanced, 10)
    enhanced[low] *= 0.1
    return np.clip(enhanced, 0, 1)


_ENHANCERS = {
    'aggressive': _enhance_aggressive,
    'histogram_eq': _enhance_histogram_eq,
    'adaptive': _enhance_adaptive,
    'artificial_peaks': _enhance_artificial_peaks,
}
# Public, ordered list of supported enhancement method names.
ENHANCEMENT_METHODS = tuple(_ENHANCERS)


def force_contrast_enhancement(heatmap, method='aggressive'):
    """Return a contrast-enhanced copy of *heatmap*.

    Args:
        heatmap: 2-D array-like of values, or ``None``.
        method: one of ``ENHANCEMENT_METHODS``; any other value returns the
            heatmap unchanged.

    Returns:
        Always a 4-tuple ``(enhanced, message, original_analysis,
        enhanced_analysis)``.  For ``None`` input this is
        ``(None, "No heatmap provided", None, None)``.
    """
    if heatmap is None:
        # Keep the 4-tuple shape: every caller unpacks four values.
        return None, "No heatmap provided", None, None

    heatmap = np.asarray(heatmap)
    if not np.issubdtype(heatmap.dtype, np.floating):
        # In-place scaling below would silently truncate integer arrays.
        heatmap = heatmap.astype(float)

    original_analysis = analyze_heatmap_distribution(heatmap, "Original")
    enhancer = _ENHANCERS.get(method)
    enhanced = heatmap if enhancer is None else enhancer(heatmap)
    enhanced_analysis = analyze_heatmap_distribution(enhanced, f"Enhanced ({method})")
    return enhanced, f"Enhanced using {method}", original_analysis, enhanced_analysis


# ---------------------------------------------------------------------------
# Figures
# ---------------------------------------------------------------------------

def create_diagnostic_heatmap_visualization(heatmap, title="Heatmap Analysis"):
    """Build a 2x3 diagnostic figure for *heatmap*.

    Panels: hot colormap, viridis colormap, auto-scaled colormap, value
    histogram, 3D surface, and a statistics text panel.  Returns ``None`` when
    matplotlib is unavailable or *heatmap* is ``None``.
    """
    if not MPL_AVAILABLE or heatmap is None:
        return None

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    for ax, cmap, label in (
        (axes[0, 0], 'hot', 'Hot Colormap'),
        (axes[0, 1], 'viridis', 'Viridis Colormap'),
    ):
        im = ax.imshow(heatmap, cmap=cmap, vmin=0, vmax=1)
        ax.set_title(f"{title} - {label}")
        fig.colorbar(im, ax=ax)

    im_auto = axes[0, 2].imshow(
        heatmap, cmap='RdYlBu_r', vmin=np.min(heatmap), vmax=np.max(heatmap))
    axes[0, 2].set_title(f"{title} - Auto-scaled")
    fig.colorbar(im_auto, ax=axes[0, 2])

    axes[1, 0].hist(np.ravel(heatmap), bins=50, alpha=0.7, color='blue')
    axes[1, 0].set_title("Value Distribution")
    axes[1, 0].set_xlabel("Heatmap Value")
    axes[1, 0].set_ylabel("Frequency")

    # Replace the 2-D placeholder in cell (1, 1) with a 3-D axes; otherwise
    # the empty 2-D frame is drawn underneath the surface plot.
    axes[1, 1].remove()
    ax_3d = fig.add_subplot(2, 3, 5, projection='3d')
    rows, cols = np.shape(heatmap)
    grid_x, grid_y = np.meshgrid(np.arange(cols), np.arange(rows))
    # Subsample for speed (stride 8 for a 224-pixel map).
    stride = max(1, max(rows, cols) // 28)
    ax_3d.plot_surface(
        grid_x[::stride, ::stride],
        grid_y[::stride, ::stride],
        np.asarray(heatmap)[::stride, ::stride],
        cmap='hot',
        alpha=0.8,
    )
    ax_3d.set_title("3D Surface View")

    analysis = analyze_heatmap_distribution(heatmap)
    pct = analysis['percentiles']
    stats_text = "\n".join((
        f"Shape: {analysis['shape']}",
        f"Range: {analysis['range']:.4f}",
        f"Mean: {analysis['mean']:.4f}",
        f"Std: {analysis['std']:.4f}",
        f"Unique values: {analysis['unique_values']}",
        f"Contrast: {analysis['contrast_quality']}",
        "",
        "Percentiles:",
        f"  1%: {pct['1%']:.4f}",
        f"  25%: {pct['25%']:.4f}",
        f"  75%: {pct['75%']:.4f}",
        f"  99%: {pct['99%']:.4f}",
    ))
    axes[1, 2].text(0.1, 0.9, stats_text, transform=axes[1, 2].transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace')
    axes[1, 2].set_title("Statistics")
    axes[1, 2].axis('off')

    fig.tight_layout()
    return fig


def create_multiple_enhancement_comparison(heatmap):
    """Build a 2x3 figure comparing every enhancement method on *heatmap*.

    Returns ``None`` when matplotlib is unavailable or *heatmap* is ``None``.
    """
    if not MPL_AVAILABLE or heatmap is None:
        return None

    enhanced_maps = {
        method: force_contrast_enhancement(heatmap, method)[0]
        for method in ENHANCEMENT_METHODS
    }

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    im0 = axes[0, 0].imshow(heatmap, cmap='hot', vmin=0, vmax=1)
    axes[0, 0].set_title("Original Heatmap")
    fig.colorbar(im0, ax=axes[0, 0])

    # One grid cell per method; cell (1, 2) is reserved for the histogram.
    positions = ((0, 1), (0, 2), (1, 0), (1, 1))
    for (row, col), (method, enhanced) in zip(positions, enhanced_maps.items()):
        im = axes[row, col].imshow(enhanced, cmap='hot', vmin=0, vmax=1)
        axes[row, col].set_title(f"Enhanced: {method}")
        fig.colorbar(im, ax=axes[row, col])

    hist_ax = axes[1, 2]
    hist_ax.hist(np.ravel(heatmap), bins=30, alpha=0.5, label='Original', color='blue')
    for method, enhanced in enhanced_maps.items():
        hist_ax.hist(np.ravel(enhanced), bins=30, alpha=0.3, label=method)
    hist_ax.set_title("Value Distributions")
    hist_ax.legend()
    hist_ax.set_xlabel("Value")
    hist_ax.set_ylabel("Frequency")

    fig.tight_layout()
    return fig


# ---------------------------------------------------------------------------
# Prediction and synthetic data
# ---------------------------------------------------------------------------

def _to_rgb(img):
    """Return *img* as a 3-channel RGB PIL image (no-op if already RGB)."""
    return img if img.mode == "RGB" else img.convert("RGB")


def predict_stroke(img, model):
    """Classify a brain-scan PIL image.

    Returns:
        ``(probabilities, None)`` on success, where ``probabilities`` is the
        model's output row for this image; ``(None, message)`` on failure.
    """
    if model is None:
        return None, "Model not loaded"
    try:
        # Convert first: grayscale, palette ('P'), RGBA and LA images would
        # otherwise yield palette indices or the wrong channel count.
        resized = _to_rgb(img).resize(_INPUT_SIZE)
        batch = np.expand_dims(np.array(resized, dtype=np.float32), axis=0) / 255.0
        predictions = model.predict(batch, verbose=0)
        return predictions[0], None
    except Exception as e:  # UI boundary: surface any failure as a message
        return None, f"Prediction error: {e}"


def create_test_heatmaps():
    """Return a dict of 224x224 synthetic heatmaps with known patterns.

    Keys (in order): ``'high_contrast'``, ``'gradient'``, ``'gaussian'``,
    ``'low_contrast'``.
    """
    test_maps = {}

    # Square ring: 1.0 with a zero-valued hole.
    high = np.zeros((224, 224))
    high[50:150, 50:150] = 1.0
    high[75:125, 75:125] = 0.0
    test_maps['high_contrast'] = high

    axis = np.linspace(0, 1, 224)
    grid_x, grid_y = np.meshgrid(axis, axis)
    test_maps['gradient'] = grid_x * grid_y

    # Three Gaussian blobs (sigma = 30 px), peak-normalised to 1.
    y, x = np.ogrid[:224, :224]
    blobs = sum(
        np.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * 30 ** 2))
        for cx, cy in ((60, 60), (160, 160), (60, 160))
    )
    test_maps['gaussian'] = blobs / np.max(blobs)

    # Narrow noise around 0.5.  Seeded locally so the pattern does not change
    # on every Streamlit rerun (and the global RNG is left untouched).
    rng = np.random.default_rng(0)
    test_maps['low_contrast'] = np.clip(rng.normal(0.5, 0.05, (224, 224)), 0, 1)

    return test_maps


def _simulate_attention_heatmap(class_idx, confidence_fraction):
    """Build a synthetic 224x224 'attention' map for the predicted class.

    This is NOT derived from the model; it is a Gaussian blob whose centre
    depends on the class index, scaled by confidence, with fixed noise, and
    peak-normalised to 1.
    """
    if class_idx == 0:    # Hemorrhagic
        center_x, center_y = 80, 112
    elif class_idx == 1:  # Ischemic
        center_x, center_y = 150, 112
    else:                 # No stroke
        center_x, center_y = 112, 112

    y, x = np.ogrid[:224, :224]
    heatmap = np.exp(-((x - center_x) ** 2 + (y - center_y) ** 2) / (2 * (40 ** 2)))
    heatmap = heatmap * confidence_fraction

    # A private legacy generator seeded with 42 reproduces the values the
    # former global np.random.seed(42) call produced, without mutating
    # global RNG state.
    noise = np.random.RandomState(42).normal(0, 0.02, heatmap.shape)
    heatmap = np.maximum(heatmap + noise, 0)

    peak = np.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

def _show_figure(fig):
    """Render *fig* in Streamlit (if any) and release its memory."""
    if fig is not None:
        st.pyplot(fig)
        plt.close(fig)


def _render_system_status():
    """Show TensorFlow / matplotlib / model readiness in three columns."""
    st.markdown("### 🔧 System Status")
    col1, col2, col3 = st.columns(3)
    with col1:
        if TF_AVAILABLE:
            st.success("✅ TensorFlow Ready")
            st.write(f"TF Version: {tf.__version__}")
        else:
            st.error("❌ TensorFlow Error")
    with col2:
        if MPL_AVAILABLE:
            st.success("✅ Matplotlib Ready")
        else:
            st.error("❌ Matplotlib Error")
    with col3:
        if st.session_state.model is not None:
            st.success("✅ Model Loaded")
        else:
            st.error("❌ Model Error")
            st.caption(st.session_state.model_status)


def _render_test_patterns():
    """Let the user pick a synthetic pattern and show its diagnostics."""
    st.markdown("### 🧪 Test Heatmap Patterns")
    test_maps = create_test_heatmaps()
    col1, col2 = st.columns(2)
    with col1:
        st.write("**Test Pattern:**")
        test_pattern = st.selectbox(
            "Choose a test pattern",
            list(test_maps),
            help="Test different heatmap patterns to see how they display",
        )
        if test_pattern:
            _show_figure(create_diagnostic_heatmap_visualization(
                test_maps[test_pattern], f"Test: {test_pattern}"))
    with col2:
        st.write("**Enhancement Comparison:**")
        if test_pattern:
            _show_figure(create_multiple_enhancement_comparison(test_maps[test_pattern]))


def _render_sidebar():
    """Render sidebar controls.

    Returns:
        ``(uploaded_file, enhancement_method, show_diagnostics,
        show_comparisons)``.
    """
    with st.sidebar:
        st.header("📤 Upload Brain Scan")
        uploaded_file = st.file_uploader(
            "Choose a brain scan image...",
            type=['png', 'jpg', 'jpeg', 'bmp', 'tiff'],
            help="Upload a brain scan image for stroke classification",
        )
        st.markdown("---")
        st.header("🎨 Enhancement Options")
        enhancement_method = st.selectbox(
            "Enhancement Method",
            ['none', *ENHANCEMENT_METHODS],
            index=1,
            help="Choose how to enhance heatmap contrast",
        )
        show_diagnostics = st.checkbox("Show Diagnostic Analysis", value=True)
        show_comparisons = st.checkbox("Show Enhancement Comparisons", value=True)
    return uploaded_file, enhancement_method, show_diagnostics, show_comparisons


def _render_enhancement(heatmap, image, method):
    """Apply *method* to *heatmap*; show before/after/overlay plots and stats."""
    enhanced, _message, orig_analysis, enh_analysis = force_contrast_enhancement(heatmap, method)
    st.write(f"**🔧 Applied Enhancement: {method}**")

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    im1 = axes[0].imshow(heatmap, cmap='hot', vmin=0, vmax=1)
    axes[0].set_title("Original Heatmap")
    im2 = axes[1].imshow(enhanced, cmap='hot', vmin=0, vmax=1)
    axes[1].set_title(f"Enhanced ({method})")
    # RGB conversion keeps grayscale scans from being drawn with matplotlib's
    # default colormap and handles RGBA/palette uploads.
    axes[2].imshow(np.array(_to_rgb(image).resize(_INPUT_SIZE)))
    im3 = axes[2].imshow(enhanced, cmap='hot', alpha=0.6, vmin=0, vmax=1)
    axes[2].set_title("Enhanced Overlay")
    for ax, im in zip(axes, (im1, im2, im3)):
        ax.axis('off')
        fig.colorbar(im, ax=ax)
    fig.tight_layout()
    _show_figure(fig)

    col1, col2 = st.columns(2)
    for column, heading, analysis in (
        (col1, "**Original Stats:**", orig_analysis),
        (col2, "**Enhanced Stats:**", enh_analysis),
    ):
        with column:
            st.write(heading)
            st.write(f"Range: {analysis['range']:.4f}")
            st.write(f"Std: {analysis['std']:.4f}")
            st.write(f"Contrast: {analysis['contrast_quality']}")


def _render_uploaded_scan(uploaded_file, enhancement_method, show_diagnostics, show_comparisons):
    """Classify the uploaded scan and show the simulated-attention analysis."""
    try:
        image = Image.open(uploaded_file)
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        st.error(f"❌ Could not read image: {exc}")
        return

    st.subheader("📋 Classification Results")
    model = st.session_state.model
    if model is None:
        st.error("❌ Model not loaded.")
        return

    with st.spinner("🔍 Analyzing brain scan..."):
        predictions, error = predict_stroke(image, model)
    if error:
        st.error(error)
        return

    class_idx = int(np.argmax(predictions))
    confidence = predictions[class_idx] * 100
    predicted_class = STROKE_LABELS[class_idx]
    st.markdown(f"## {predicted_class}")
    st.markdown(f"**Confidence: {confidence:.1f}%**")

    st.subheader("🎯 Simulated Attention Analysis")
    heatmap = _simulate_attention_heatmap(class_idx, confidence / 100.0)

    if show_diagnostics:
        st.write("**📊 Heatmap Diagnostic Analysis:**")
        _show_figure(create_diagnostic_heatmap_visualization(heatmap, "Your Model's Attention"))

    if show_comparisons:
        st.write("**🎨 Enhancement Method Comparison:**")
        _show_figure(create_multiple_enhancement_comparison(heatmap))

    if enhancement_method != 'none':
        _render_enhancement(heatmap, image, enhancement_method)


def main():
    """Streamlit entry point: render the whole page for one script run."""
    st.title("🧠 Heatmap Diagnostic System")

    # Load the model once per session (the loader itself is also cached
    # process-wide by st.cache_resource).
    if not st.session_state.model_loaded:
        with st.spinner("Loading AI model..."):
            st.session_state.model, st.session_state.model_status = load_stroke_model()
            st.session_state.model_loaded = True

    _render_system_status()
    _render_test_patterns()
    uploaded_file, enhancement_method, show_diagnostics, show_comparisons = _render_sidebar()

    if uploaded_file is not None:
        _render_uploaded_scan(uploaded_file, enhancement_method, show_diagnostics, show_comparisons)
    else:
        st.markdown(_WELCOME_MARKDOWN)

    st.markdown("---")
    st.warning(_DISCLAIMER)


if __name__ == "__main__":
    main()