import streamlit as st import numpy as np import os import sys from PIL import Image from scipy import ndimage import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D # Set environment variables to fix permission issues os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib' os.environ['STREAMLIT_SERVER_HEADLESS'] = 'true' # Minimal imports to avoid conflicts try: import tensorflow as tf TF_AVAILABLE = True except ImportError: TF_AVAILABLE = False st.error("TensorFlow not available") try: import matplotlib matplotlib.use('Agg') # Use non-interactive backend import matplotlib.cm as cm MPL_AVAILABLE = True except ImportError: MPL_AVAILABLE = False # Page config st.set_page_config( page_title="Stroke Classifier", page_icon="๐ง ", layout="wide") # Simple styling st.markdown(""" """, unsafe_allow_html=True) # Initialize session state if 'model_loaded' not in st.session_state: st.session_state.model_loaded = False st.session_state.model = None st.session_state.model_status = "Not loaded" STROKE_LABELS = ["Hemorrhagic Stroke", "Ischemic Stroke", "No Stroke"] def find_model_file(): """Find the model file in various possible locations.""" possible_paths = [ "stroke_classification_model.h5", "./stroke_classification_model.h5", "/app/stroke_classification_model.h5", "src/stroke_classification_model.h5", os.path.join(os.getcwd(), "stroke_classification_model.h5") ] # Also check all .h5 files in current directory and subdirectories for root, dirs, files in os.walk('.'): for file in files: if file.endswith('.h5'): possible_paths.append(os.path.join(root, file)) for path in possible_paths: if os.path.exists(path): return path return None @st.cache_resource def load_stroke_model(): """Load model with caching.""" if not TF_AVAILABLE: return None, "โ TensorFlow not available" try: # Find the model file model_path = find_model_file() if model_path is None: # List all files to help debug current_files = [] for root, dirs, files in os.walk('.'): for file in files: current_files.append(os.path.join(root, file)) return None, f"โ Model file not found. Available files: {current_files[:10]}" st.info(f"Found model at: {model_path}") # Load model with minimal custom objects model = tf.keras.models.load_model(model_path, compile=False) return model, f"โ Model loaded successfully from: {model_path}" except Exception as e: return None, f"โ Model loading failed: {str(e)}" def analyze_heatmap_distribution(heatmap, name="Heatmap"): """Analyze the distribution of heatmap values.""" if heatmap is None: return None flat_values = heatmap.flatten() analysis = { 'name': name, 'shape': heatmap.shape, 'total_pixels': heatmap.size, 'min': float(np.min(flat_values)), 'max': float(np.max(flat_values)), 'mean': float(np.mean(flat_values)), 'median': float(np.median(flat_values)), 'std': float(np.std(flat_values)), 'range': float(np.max(flat_values) - np.min(flat_values)), 'unique_values': len(np.unique(flat_values)), 'zero_pixels': int(np.sum(flat_values == 0)), 'non_zero_pixels': int(np.sum(flat_values > 0)), 'percentiles': { '1%': float(np.percentile(flat_values, 1)), '5%': float(np.percentile(flat_values, 5)), '25%': float(np.percentile(flat_values, 25)), '75%': float(np.percentile(flat_values, 75)), '95%': float(np.percentile(flat_values, 95)), '99%': float(np.percentile(flat_values, 99)) } } # Determine if heatmap has good contrast if analysis['range'] < 0.1: analysis['contrast_quality'] = 'Very Poor (range < 0.1)' elif analysis['range'] < 0.3: analysis['contrast_quality'] = 'Poor (range < 0.3)' elif analysis['range'] < 0.7: analysis['contrast_quality'] = 'Moderate (range < 0.7)' else: analysis['contrast_quality'] = 'Good (range >= 0.7)' return analysis def force_contrast_enhancement(heatmap, method='aggressive'): """Force better contrast in heatmap using various methods.""" if heatmap is None: return None, "No heatmap provided" original_analysis = analyze_heatmap_distribution(heatmap, "Original") if method == 'aggressive': # Method 1: Aggressive percentile stretching p1, p99 = np.percentile(heatmap, [1, 99]) if p99 > p1: enhanced = np.clip((heatmap - p1) / (p99 - p1), 0, 1) else: enhanced = heatmap # Apply power transformation to spread values enhanced = np.power(enhanced, 0.3) # Gamma < 1 spreads values elif method == 'histogram_eq': # Method 2: Histogram equalization flat = heatmap.flatten() hist, bins = np.histogram(flat, bins=256, range=(0, 1)) cdf = hist.cumsum() cdf = cdf / cdf[-1] # Normalize # Interpolate to get new values enhanced = np.interp(flat, bins[:-1], cdf).reshape(heatmap.shape) elif method == 'adaptive': # Method 3: Adaptive enhancement based on local statistics # Local mean and std local_mean = ndimage.uniform_filter(heatmap, size=20) local_std = ndimage.generic_filter(heatmap, np.std, size=20) # Enhance based on local statistics enhanced = (heatmap - local_mean) / (local_std + 1e-8) enhanced = np.clip(enhanced, -3, 3) # Clip outliers enhanced = (enhanced + 3) / 6 # Normalize to [0, 1] elif method == 'artificial_peaks': # Method 4: Create artificial peaks for visualization enhanced = heatmap.copy() # Find top 10% of values and enhance them threshold = np.percentile(enhanced, 90) mask = enhanced >= threshold enhanced[mask] = enhanced[mask] * 2 # Find bottom 10% and suppress them threshold_low = np.percentile(enhanced, 10) mask_low = enhanced <= threshold_low enhanced[mask_low] = enhanced[mask_low] * 0.1 # Normalize enhanced = np.clip(enhanced, 0, 1) else: enhanced = heatmap enhanced_analysis = analyze_heatmap_distribution(enhanced, f"Enhanced ({method})") return enhanced, f"Enhanced using {method}", original_analysis, enhanced_analysis def create_diagnostic_heatmap_visualization(heatmap, title="Heatmap Analysis"): """Create a comprehensive diagnostic visualization of the heatmap.""" if not MPL_AVAILABLE or heatmap is None: return None fig, axes = plt.subplots(2, 3, figsize=(18, 12)) # Original heatmap im1 = axes[0, 0].imshow(heatmap, cmap='hot', vmin=0, vmax=1) axes[0, 0].set_title(f"{title} - Hot Colormap") plt.colorbar(im1, ax=axes[0, 0]) # Different colormap im2 = axes[0, 1].imshow(heatmap, cmap='viridis', vmin=0, vmax=1) axes[0, 1].set_title(f"{title} - Viridis Colormap") plt.colorbar(im2, ax=axes[0, 1]) # High contrast version im3 = axes[0, 2].imshow(heatmap, cmap='RdYlBu_r', vmin=np.min(heatmap), vmax=np.max(heatmap)) axes[0, 2].set_title(f"{title} - Auto-scaled") plt.colorbar(im3, ax=axes[0, 2]) # Histogram axes[1, 0].hist(heatmap.flatten(), bins=50, alpha=0.7, color='blue') axes[1, 0].set_title("Value Distribution") axes[1, 0].set_xlabel("Heatmap Value") axes[1, 0].set_ylabel("Frequency") # 3D surface plot x = np.arange(heatmap.shape[1]) y = np.arange(heatmap.shape[0]) X, Y = np.meshgrid(x, y) ax_3d = fig.add_subplot(2, 3, 5, projection='3d') surf = ax_3d.plot_surface(X[::8, ::8], Y[::8, ::8], heatmap[::8, ::8], cmap='hot', alpha=0.8) ax_3d.set_title("3D Surface View") # Statistics text analysis = analyze_heatmap_distribution(heatmap) stats_text = f""" Shape: {analysis['shape']} Range: {analysis['range']:.4f} Mean: {analysis['mean']:.4f} Std: {analysis['std']:.4f} Unique values: {analysis['unique_values']} Contrast: {analysis['contrast_quality']} Percentiles: 1%: {analysis['percentiles']['1%']:.4f} 25%: {analysis['percentiles']['25%']:.4f} 75%: {analysis['percentiles']['75%']:.4f} 99%: {analysis['percentiles']['99%']:.4f} """ axes[1, 2].text(0.1, 0.9, stats_text, transform=axes[1, 2].transAxes, fontsize=10, verticalalignment='top', fontfamily='monospace') axes[1, 2].set_title("Statistics") axes[1, 2].axis('off') plt.tight_layout() return fig def create_multiple_enhancement_comparison(heatmap): """Compare different enhancement methods side by side.""" if not MPL_AVAILABLE or heatmap is None: return None methods = ['aggressive', 'histogram_eq', 'adaptive', 'artificial_peaks'] enhanced_maps = {} for method in methods: enhanced, _, _, _ = force_contrast_enhancement(heatmap, method) enhanced_maps[method] = enhanced fig, axes = plt.subplots(2, 3, figsize=(18, 12)) # Original im0 = axes[0, 0].imshow(heatmap, cmap='hot', vmin=0, vmax=1) axes[0, 0].set_title("Original Heatmap") plt.colorbar(im0, ax=axes[0, 0]) # Enhanced versions positions = [(0, 1), (0, 2), (1, 0), (1, 1)] for i, (method, enhanced) in enumerate(enhanced_maps.items()): row, col = positions[i] im = axes[row, col].imshow(enhanced, cmap='hot', vmin=0, vmax=1) axes[row, col].set_title(f"Enhanced: {method}") plt.colorbar(im, ax=axes[row, col]) # Comparison histogram axes[1, 2].hist(heatmap.flatten(), bins=30, alpha=0.5, label='Original', color='blue') for method, enhanced in enhanced_maps.items(): axes[1, 2].hist(enhanced.flatten(), bins=30, alpha=0.3, label=method) axes[1, 2].set_title("Value Distributions") axes[1, 2].legend() axes[1, 2].set_xlabel("Value") axes[1, 2].set_ylabel("Frequency") plt.tight_layout() return fig def predict_stroke(img, model): """Predict stroke type from image.""" if model is None: return None, "Model not loaded" try: # Preprocess image img_resized = img.resize((224, 224)) img_array = np.array(img_resized, dtype=np.float32) # Handle grayscale if len(img_array.shape) == 2: img_array = np.stack([img_array] * 3, axis=-1) # Normalize and add batch dimension img_array = np.expand_dims(img_array, axis=0) / 255.0 # Predict predictions = model.predict(img_array, verbose=0) return predictions[0], None except Exception as e: return None, f"Prediction error: {str(e)}" def create_test_heatmaps(): """Create test heatmaps with known patterns for comparison.""" test_maps = {} # Test 1: High contrast pattern test_maps['high_contrast'] = np.zeros((224, 224)) test_maps['high_contrast'][50:150, 50:150] = 1.0 test_maps['high_contrast'][75:125, 75:125] = 0.0 # Test 2: Gradient pattern x = np.linspace(0, 1, 224) y = np.linspace(0, 1, 224) X, Y = np.meshgrid(x, y) test_maps['gradient'] = X * Y # Test 3: Gaussian blobs test_maps['gaussian'] = np.zeros((224, 224)) centers = [(60, 60), (160, 160), (60, 160)] for cx, cy in centers: y, x = np.ogrid[:224, :224] mask = np.exp(-((x - cx)**2 + (y - cy)**2) / (2 * 30**2)) test_maps['gaussian'] += mask test_maps['gaussian'] = test_maps['gaussian'] / np.max(test_maps['gaussian']) # Test 4: Low contrast (similar to your issue) test_maps['low_contrast'] = np.random.normal(0.5, 0.05, (224, 224)) test_maps['low_contrast'] = np.clip(test_maps['low_contrast'], 0, 1) return test_maps # Main App def main(): # Header st.markdown('