Spaces:
Sleeping
Sleeping
| """ | |
| Stress Detection Preprocessing Module | |
| ====================================== | |
| Prepares Sentinel-2 multi-spectral data for stress detection model. | |
| Includes band harmonization and normalization. | |
| """ | |
| import numpy as np | |
| from typing import Tuple, List | |
# Position of each band along the last axis of the 12-band Sentinel-2 stack.
BAND_INDICES = dict(
    B02=1,   # Blue
    B03=2,   # Green
    B04=3,   # Red
    B05=4,   # Red Edge 1
    B08=7,   # NIR
    B8A=8,   # NIR Narrow
    B11=10,  # SWIR1
    B12=11,  # SWIR2
)

# The 8 major bands fed to the stress detection model, in output order.
SELECTED_BANDS = [
    "B02",
    "B03",
    "B04",
    "B05",
    "B08",
    "B8A",
    "B11",
    "B12",
]
def extract_major_bands(all_images: np.ndarray) -> np.ndarray:
    """
    Pick the stress-detection bands out of a 12-band Sentinel-2 stack.

    Args:
        all_images: Array of shape (time, height, width, 12)

    Returns:
        Array of shape (time, height, width, 8) holding the bands listed in
        SELECTED_BANDS, in that order
    """
    # Translate band names into positions along the last (band) axis.
    channel_positions = [BAND_INDICES[name] for name in SELECTED_BANDS]
    return all_images[:, :, :, channel_positions]
def harmonize_bands(images: np.ndarray) -> np.ndarray:
    """
    Bring every band onto a common [0, 1] scale.

    The data is first clipped to [0, 1]. Each band is then stretched so that
    its 2nd percentile maps to 0 and its 98th percentile maps to 1 (NaNs are
    ignored when computing the percentiles), and clipped again. A band whose
    98th percentile is not greater than its 2nd percentile is left as clipped.

    Args:
        images: Array of shape (time, height, width, bands)

    Returns:
        New array of the same shape with values in [0, 1]; the input is not
        modified
    """
    # np.clip returns a fresh array, so the in-place writes below never
    # touch the caller's data.
    result = np.clip(images, 0, 1)

    _, _, _, n_bands = result.shape
    for band in range(n_bands):
        channel = result[:, :, :, band]

        # Robust range of this band: percentiles instead of min/max so that
        # a few extreme pixels do not dominate the stretch.
        lower = np.nanpercentile(channel, 2)
        upper = np.nanpercentile(channel, 98)

        # Flat (or all-NaN) band: nothing to stretch, keep the clipped values.
        if not upper > lower:
            continue

        stretched = (channel - lower) / (upper - lower)
        result[:, :, :, band] = np.clip(stretched, 0, 1)

    return result
def handle_nan_values(images: np.ndarray, method='mean') -> np.ndarray:
    """
    Replace NaN entries using the chosen strategy.

    Args:
        images: Array of shape (time, height, width, bands)
        method: 'zero' substitutes 0.0; 'mean' substitutes the NaN-ignoring
            mean over the time axis for that pixel and band; 'interpolate'
            fills linearly along the time axis (a series with no valid value
            becomes 0.0). Any other value returns the input unchanged.

    Returns:
        Array with NaN values handled
    """
    if method == 'zero':
        return np.nan_to_num(images, nan=0.0)

    if method == 'mean':
        missing = np.isnan(images)
        # keepdims lets the (1, height, width, bands) mean broadcast over time.
        temporal_mean = np.nanmean(images, axis=0, keepdims=True)
        return np.where(missing, temporal_mean, images)

    if method == 'interpolate':
        filled = images.copy()
        time, height, width, bands = images.shape
        steps = np.arange(time)

        # Visit every (row, column, band) time series independently.
        for h, w, b in np.ndindex(height, width, bands):
            series = filled[:, h, w, b]
            gaps = np.isnan(series)
            if not np.any(gaps):
                continue

            known = ~gaps
            if np.any(known):
                # np.interp holds the first/last known value at the edges.
                filled[:, h, w, b] = np.interp(steps, steps[known], series[known])
            else:
                # No observation at all for this pixel/band.
                filled[:, h, w, b] = 0.0
        return filled

    # Unrecognised method: hand the data back untouched.
    return images
def create_patches(images: np.ndarray, patch_size: int = 4, stride: int = 2) -> Tuple[np.ndarray, List]:
    """
    Create overlapping patches from images for spatial analysis.

    Patches are taken on a regular grid and kept only when more than half of
    their values are non-NaN. If no patch qualifies, a single patch taken
    from the centre of the image is returned so the result is never empty.
    When the image is smaller than ``patch_size`` along height or width,
    that fallback patch starts at 0 on that axis and is therefore smaller
    than ``patch_size`` there.

    Args:
        images: Array of shape (time, height, width, bands)
        patch_size: Size of each patch
        stride: Stride for patch extraction

    Returns:
        patches: Array of shape (num_patches, time, patch_size, patch_size, bands)
        patch_coords: List of (h_start, w_start) coordinates for each patch

    Raises:
        ValueError: If ``patch_size`` or ``stride`` is not positive.
    """
    if patch_size < 1 or stride < 1:
        raise ValueError("patch_size and stride must be positive integers")

    _, height, width, _ = images.shape
    patches = []
    patch_coords = []

    for h in range(0, height - patch_size + 1, stride):
        for w in range(0, width - patch_size + 1, stride):
            patch = images[:, h:h + patch_size, w:w + patch_size, :]
            # Keep the patch only when more than 50% of its values are valid.
            valid_ratio = np.sum(~np.isnan(patch)) / patch.size
            if valid_ratio > 0.5:
                patches.append(patch)
                patch_coords.append((h, w))

    if not patches:
        # No valid patch: fall back to one patch from the centre. The offsets
        # are clamped at 0 because a negative start would be interpreted as
        # "count from the end" and silently return the wrong crop.
        h_center = max(0, (height - patch_size) // 2)
        w_center = max(0, (width - patch_size) // 2)
        patch = images[:, h_center:h_center + patch_size, w_center:w_center + patch_size, :]
        patches.append(patch)
        patch_coords.append((h_center, w_center))

    return np.array(patches), patch_coords
def preprocess_for_model(all_images: np.ndarray,
                         patch_size: int = 4,
                         stride: int = 2) -> Tuple[np.ndarray, List, dict]:
    """
    Run the full preprocessing chain for the stress detection model.

    The stages are, in order: band extraction, harmonization to [0, 1],
    NaN filling with the per-pixel temporal mean, and patch creation.
    Progress is reported on stdout.

    Args:
        all_images: Raw images of shape (time, height, width, 12)
        patch_size: Size of patches for spatial analysis
        stride: Stride for patch extraction

    Returns:
        patches: Preprocessed patches ready for model
        patch_coords: Coordinates of each patch
        metadata: Dictionary with preprocessing information
    """
    print("Preprocessing data for stress detection model...")

    # Steps 1-3 each take an array and return an array, so they are run as
    # a simple chain: announce the step, then apply it.
    array_stages = (
        (" [1/4] Extracting 8 major bands...", extract_major_bands),
        (" [2/4] Harmonizing bands to [0, 1] scale...", harmonize_bands),
        (" [3/4] Handling NaN values...",
         lambda arr: handle_nan_values(arr, method='mean')),
    )
    data = all_images
    for message, stage in array_stages:
        print(message)
        data = stage(data)

    # Step 4: cut the cleaned stack into spatial patches.
    print(" [4/4] Creating spatial patches...")
    patches, patch_coords = create_patches(data, patch_size, stride)

    patch_count = len(patches)
    metadata = {
        'original_shape': all_images.shape,
        'selected_bands': SELECTED_BANDS,
        'num_bands': len(SELECTED_BANDS),
        'patch_size': patch_size,
        'stride': stride,
        'num_patches': patch_count,
        'harmonized': True,
    }

    print(f"[OK] Preprocessing complete. Created {patch_count} patches.")
    print(f" Patch shape: {patches.shape}")
    return patches, patch_coords, metadata