""" GPU-Native Eye Image Processor for Color Fundus Photography (CFP) Images. This module implements a fully PyTorch-based image processor that: 1. Localizes the eye/fundus region using gradient-based radial symmetry 2. Crops to a border-minimized square centered on the eye 3. Applies CLAHE for contrast enhancement 4. Outputs tensors compatible with Hugging Face vision models Constraints: - PyTorch only (no OpenCV, PIL, NumPy in runtime) - CUDA-compatible, batch-friendly, deterministic """ from typing import Dict, List, Optional, Union import math import torch import torch.nn.functional as F from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_base import BatchFeature # Optional imports for broader input support try: from PIL import Image PIL_AVAILABLE = True except ImportError: PIL_AVAILABLE = False try: import numpy as np NUMPY_AVAILABLE = True except ImportError: NUMPY_AVAILABLE = False # ============================================================================= # PHASE 1: Input & Tensor Standardization # ============================================================================= def _pil_to_tensor(image: "Image.Image") -> torch.Tensor: """Convert PIL Image to tensor (C, H, W) in [0, 1].""" if not PIL_AVAILABLE: raise ImportError("PIL is required to process PIL Images") # Convert to RGB if necessary if image.mode != "RGB": image = image.convert("RGB") # Use numpy as intermediate if available, otherwise manual conversion if NUMPY_AVAILABLE: arr = np.array(image, dtype=np.float32) / 255.0 # (H, W, C) -> (C, H, W) tensor = torch.from_numpy(arr).permute(2, 0, 1) else: # Manual conversion without numpy width, height = image.size pixels = list(image.getdata()) tensor = torch.tensor(pixels, dtype=torch.float32).view(height, width, 3) / 255.0 tensor = tensor.permute(2, 0, 1) return tensor def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor: """Convert numpy array to tensor (C, H, W) in [0, 1].""" if not NUMPY_AVAILABLE: raise ImportError("NumPy is required to process numpy arrays") # Handle different array shapes if arr.ndim == 2: # Grayscale (H, W) -> (1, H, W) arr = arr[..., None] if arr.ndim == 3 and arr.shape[-1] in [1, 3, 4]: # (H, W, C) -> (C, H, W) arr = arr.transpose(2, 0, 1) # Convert to float and normalize if arr.dtype == np.uint8: arr = arr.astype(np.float32) / 255.0 elif arr.dtype != np.float32: arr = arr.astype(np.float32) return torch.from_numpy(arr.copy()) def standardize_input( images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]], device: Optional[torch.device] = None, ) -> torch.Tensor: """ Convert input images to standardized tensor format. 
    Args:
        images: Input as:
            - torch.Tensor (C,H,W), (B,C,H,W), or list of tensors
            - PIL.Image.Image or list of PIL Images
            - numpy.ndarray (H,W,C), (B,H,W,C), or list of arrays
        device: Target device (defaults to input device or CPU)

    Returns:
        Tensor of shape (B, C, H, W) in float32, range [0, 1]
    """
    # Handle single inputs by wrapping in a list
    if PIL_AVAILABLE and isinstance(images, Image.Image):
        images = [images]
    if NUMPY_AVAILABLE and isinstance(images, np.ndarray) and images.ndim == 3:
        # Could be single (H,W,C) or batch (B,H,W) grayscale - assume single if last dim is 1-4
        if images.shape[-1] in [1, 3, 4]:
            images = [images]

    # Convert list inputs to tensors
    if isinstance(images, list):
        converted = []
        for img in images:
            if PIL_AVAILABLE and isinstance(img, Image.Image):
                converted.append(_pil_to_tensor(img))
            elif NUMPY_AVAILABLE and isinstance(img, np.ndarray):
                converted.append(_numpy_to_tensor(img))
            elif isinstance(img, torch.Tensor):
                t = img if img.dim() == 3 else img.squeeze(0)
                converted.append(t)
            else:
                raise TypeError(f"Unsupported image type: {type(img)}")
        images = torch.stack(converted)
    elif NUMPY_AVAILABLE and isinstance(images, np.ndarray):
        # Batch of numpy arrays (B, H, W, C)
        if images.ndim == 4:
            images = images.transpose(0, 3, 1, 2)  # (B, C, H, W)
        if images.dtype == np.uint8:
            images = images.astype(np.float32) / 255.0
        images = torch.from_numpy(images.copy())

    if images.dim() == 3:
        # Add batch dimension: (C, H, W) -> (B, C, H, W)
        images = images.unsqueeze(0)

    # Move to target device if specified
    if device is not None:
        images = images.to(device)

    # Convert to float32 and normalize to [0, 1]
    if images.dtype == torch.uint8:
        images = images.float() / 255.0
    elif images.dtype != torch.float32:
        images = images.float()

    # Clamp to valid range
    images = images.clamp(0.0, 1.0)

    return images


def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
    """
    Convert RGB images to grayscale using the luminance formula:
    Y = 0.299 * R + 0.587 * G + 0.114 * B

    Args:
        images: Tensor of shape (B, 3, H, W)

    Returns:
        Tensor of shape (B, 1, H, W)
    """
    # Luminance weights
    weights = torch.tensor([0.299, 0.587, 0.114], device=images.device, dtype=images.dtype)
    weights = weights.view(1, 3, 1, 1)
    grayscale = (images * weights).sum(dim=1, keepdim=True)
    return grayscale


# =============================================================================
# PHASE 2: Eye Region Localization (GPU-Safe)
# =============================================================================

def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
    """
    Create Sobel kernels for gradient computation.

    Returns:
        Tuple of (sobel_x, sobel_y) kernels, each of shape (1, 1, 3, 3)
    """
    sobel_x = torch.tensor([
        [-1, 0, 1],
        [-2, 0, 2],
        [-1, 0, 1],
    ], device=device, dtype=dtype).view(1, 1, 3, 3)

    sobel_y = torch.tensor([
        [-1, -2, -1],
        [ 0,  0,  0],
        [ 1,  2,  1],
    ], device=device, dtype=dtype).view(1, 1, 3, 3)

    return sobel_x, sobel_y


def compute_gradients(grayscale: torch.Tensor) -> tuple:
    """
    Compute image gradients using Sobel filters.
    Args:
        grayscale: Tensor of shape (B, 1, H, W)

    Returns:
        Tuple of (grad_x, grad_y, grad_magnitude)
    """
    sobel_x, sobel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)

    # Apply Sobel filters with padding to maintain size
    grad_x = F.conv2d(grayscale, sobel_x, padding=1)
    grad_y = F.conv2d(grayscale, sobel_y, padding=1)

    # Compute gradient magnitude
    grad_magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)

    return grad_x, grad_y, grad_magnitude


def compute_radial_symmetry_response(
    grayscale: torch.Tensor,
    grad_x: torch.Tensor,
    grad_y: torch.Tensor,
    grad_magnitude: torch.Tensor,
) -> torch.Tensor:
    """
    Compute radial symmetry response for circle detection.

    This weights regions that:
    1. Are dark (low intensity - typical of pupil/iris)
    2. Have strong radial gradients pointing inward

    Args:
        grayscale: Grayscale image (B, 1, H, W)
        grad_x, grad_y: Gradient components
        grad_magnitude: Gradient magnitude

    Returns:
        Radial symmetry response map (B, 1, H, W)
    """
    B, _, H, W = grayscale.shape
    device = grayscale.device
    dtype = grayscale.dtype

    # Create coordinate grids
    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1).expand(B, 1, H, W)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W).expand(B, 1, H, W)

    # Compute center of mass of dark regions as initial estimate
    # Invert intensity so dark regions have high weight
    dark_weight = 1.0 - grayscale
    dark_weight = dark_weight ** 2  # Emphasize darker regions

    # Normalize weights
    weight_sum = dark_weight.sum(dim=(2, 3), keepdim=True) + 1e-8

    # Weighted center of mass
    cx_init = (dark_weight * x_coords).sum(dim=(2, 3), keepdim=True) / weight_sum
    cy_init = (dark_weight * y_coords).sum(dim=(2, 3), keepdim=True) / weight_sum

    # Compute vectors from each pixel to estimated center
    dx_to_center = cx_init - x_coords
    dy_to_center = cy_init - y_coords
    dist_to_center = torch.sqrt(dx_to_center ** 2 + dy_to_center ** 2 + 1e-8)

    # Normalize direction vectors
    dx_norm = dx_to_center / dist_to_center
    dy_norm = dy_to_center / dist_to_center

    # Normalize gradient vectors
    grad_norm = grad_magnitude + 1e-8
    gx_norm = grad_x / grad_norm
    gy_norm = grad_y / grad_norm

    # Radial symmetry: gradient should point toward center
    # Dot product between gradient and direction to center
    radial_alignment = gx_norm * dx_norm + gy_norm * dy_norm

    # Weight by gradient magnitude and darkness
    response = radial_alignment * grad_magnitude * dark_weight

    # Apply Gaussian smoothing to get a robust response
    kernel_size = max(H, W) // 8
    if kernel_size % 2 == 0:
        kernel_size += 1
    kernel_size = max(kernel_size, 5)
    sigma = kernel_size / 6.0

    # Create 1D Gaussian kernel
    x = torch.arange(kernel_size, device=device, dtype=dtype) - kernel_size // 2
    gaussian_1d = torch.exp(-x ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()

    # Separable 2D convolution
    gaussian_1d_h = gaussian_1d.view(1, 1, 1, kernel_size)
    gaussian_1d_v = gaussian_1d.view(1, 1, kernel_size, 1)

    pad_h = kernel_size // 2
    pad_v = kernel_size // 2

    response = F.pad(response, (pad_h, pad_h, 0, 0), mode='reflect')
    response = F.conv2d(response, gaussian_1d_h)
    response = F.pad(response, (0, 0, pad_v, pad_v), mode='reflect')
    response = F.conv2d(response, gaussian_1d_v)

    return response


def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
    """
    Compute soft argmax to find the center coordinates.
    Args:
        response: Response map (B, 1, H, W)
        temperature: Softmax temperature (lower = sharper)

    Returns:
        Tuple of (cx, cy) each of shape (B,)
    """
    B, _, H, W = response.shape
    device = response.device
    dtype = response.dtype

    # Flatten spatial dimensions
    response_flat = response.view(B, -1)

    # Apply softmax with temperature
    weights = F.softmax(response_flat / temperature, dim=1)
    weights = weights.view(B, 1, H, W)

    # Create coordinate grids
    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1).expand(B, 1, H, W)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W).expand(B, 1, H, W)

    # Weighted sum of coordinates
    cx = (weights * x_coords).sum(dim=(2, 3)).squeeze(-1)  # (B,)
    cy = (weights * y_coords).sum(dim=(2, 3)).squeeze(-1)  # (B,)

    return cx, cy


def estimate_eye_center(
    images: torch.Tensor,
    softmax_temperature: float = 0.1,
) -> tuple:
    """
    Estimate the center of the eye region in each image.

    Args:
        images: RGB images of shape (B, 3, H, W)
        softmax_temperature: Temperature for soft argmax (lower = sharper peak
            detection, higher = more averaging). Typical range: 0.01-1.0.
            Default 0.1 works well for most fundus images. Use higher values
            (0.3-0.5) for noisy images.

    Returns:
        Tuple of (cx, cy) each of shape (B,) in pixel coordinates
    """
    grayscale = rgb_to_grayscale(images)
    grad_x, grad_y, grad_magnitude = compute_gradients(grayscale)
    response = compute_radial_symmetry_response(grayscale, grad_x, grad_y, grad_magnitude)
    cx, cy = soft_argmax_2d(response, temperature=softmax_temperature)
    return cx, cy


# =============================================================================
# PHASE 2.3: Radius Estimation
# =============================================================================

def estimate_radius(
    images: torch.Tensor,
    cx: torch.Tensor,
    cy: torch.Tensor,
    num_radii: int = 100,
    num_angles: int = 36,
    min_radius_frac: float = 0.1,
    max_radius_frac: float = 0.5,
) -> torch.Tensor:
    """
    Estimate the radius of the eye region by analyzing radial intensity profiles.

    Args:
        images: RGB images (B, 3, H, W)
        cx, cy: Center coordinates (B,)
        num_radii: Number of radius samples
        num_angles: Number of angular samples
        min_radius_frac: Minimum radius as fraction of image size
        max_radius_frac: Maximum radius as fraction of image size

    Returns:
        Estimated radius for each image (B,)
    """
    B, _, H, W = images.shape
    device = images.device
    dtype = images.dtype

    grayscale = rgb_to_grayscale(images)  # (B, 1, H, W)

    min_dim = min(H, W)
    min_radius = int(min_radius_frac * min_dim)
    max_radius = int(max_radius_frac * min_dim)

    # Create radius and angle samples
    radii = torch.linspace(min_radius, max_radius, num_radii, device=device, dtype=dtype)
    angles = torch.linspace(0, 2 * math.pi, num_angles + 1, device=device, dtype=dtype)[:-1]

    # Create sampling grid: (num_angles, num_radii)
    cos_angles = torch.cos(angles).view(-1, 1)  # (num_angles, 1)
    sin_angles = torch.sin(angles).view(-1, 1)  # (num_angles, 1)

    # Offset coordinates from center
    dx = cos_angles * radii  # (num_angles, num_radii)
    dy = sin_angles * radii  # (num_angles, num_radii)

    # Compute absolute coordinates for each batch item
    # cx, cy: (B,) -> expand to (B, num_angles, num_radii)
    cx_expanded = cx.view(B, 1, 1).expand(B, num_angles, num_radii)
    cy_expanded = cy.view(B, 1, 1).expand(B, num_angles, num_radii)

    sample_x = cx_expanded + dx.unsqueeze(0)  # (B, num_angles, num_radii)
    sample_y = cy_expanded + dy.unsqueeze(0)  # (B, num_angles, num_radii)

    # Normalize to [-1, 1] for grid_sample
    sample_x_norm = 2.0 * sample_x / (W - 1) - 1.0
    sample_y_norm = 2.0 * sample_y / (H - 1) - 1.0

    # Create sampling grid: (B, num_angles, num_radii, 2)
    grid = torch.stack([sample_x_norm, sample_y_norm], dim=-1)

    # Sample intensities
    sampled = F.grid_sample(
        grayscale, grid,
        mode='bilinear', padding_mode='border', align_corners=True,
    )  # (B, 1, num_angles, num_radii)

    # Average over angles to get radial profile
    radial_profile = sampled.mean(dim=2).squeeze(1)  # (B, num_radii)

    # Compute gradient of radial profile (looking for strong negative gradient at iris edge)
    radial_gradient = radial_profile[:, 1:] - radial_profile[:, :-1]  # (B, num_radii-1)

    # Find the radius with strongest negative gradient (edge of iris)
    # Weight by radius to prefer larger circles (avoid pupil boundary)
    radius_weights = torch.linspace(0.5, 1.5, num_radii - 1, device=device, dtype=dtype)
    weighted_gradient = radial_gradient * radius_weights.unsqueeze(0)

    # Find minimum (strongest negative gradient)
    min_idx = weighted_gradient.argmin(dim=1)  # (B,)

    # Convert index to radius value
    estimated_radius = radii[min_idx + 1]  # +1 because gradient has one less element

    # Clamp to valid range
    estimated_radius = estimated_radius.clamp(min_radius, max_radius)

    return estimated_radius
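
# -----------------------------------------------------------------------------
# Illustrative sketch, not part of the processing pipeline: a minimal,
# hypothetical helper showing how `estimate_eye_center` and `estimate_radius`
# compose on a synthetic bright disc. The helper name `_demo_localization`
# and the synthetic-disc parameters are assumptions for demonstration only;
# real CFP images will of course behave differently.
# -----------------------------------------------------------------------------
def _demo_localization(
    height: int = 256,
    width: int = 320,
    device: Optional[torch.device] = None,
) -> tuple:
    """Run center/radius estimation on a synthetic disc image (demo sketch)."""
    device = device or torch.device("cpu")

    # Build a bright disc on a dark background (assumed proxy for a fundus image)
    yy = torch.arange(height, device=device).view(height, 1).float()
    xx = torch.arange(width, device=device).view(1, width).float()
    cx_true = 0.55 * width
    cy_true = 0.50 * height
    r_true = 0.35 * min(height, width)
    mask = ((xx - cx_true) ** 2 + (yy - cy_true) ** 2).sqrt() < r_true

    image = torch.zeros(1, 3, height, width, device=device)
    image[:, :, mask] = 0.8

    # Compose the two estimators exactly as `preprocess` does
    cx, cy = estimate_eye_center(image)
    radius = estimate_radius(image, cx, cy)
    return cx, cy, radius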

# =============================================================================
# PHASE 3: Border-Minimized Square Crop
# =============================================================================

def compute_crop_box(
    cx: torch.Tensor,
    cy: torch.Tensor,
    radius: torch.Tensor,
    H: int,
    W: int,
    scale_factor: float = 1.1,
    allow_overflow: bool = False,
) -> tuple:
    """
    Compute square bounding box for cropping.
    Args:
        cx, cy: Center coordinates (B,)
        radius: Estimated radius (B,)
        H, W: Image dimensions
        scale_factor: Multiply radius by this factor for padding
        allow_overflow: If True, don't clamp box to image bounds (for pre-cropped images)

    Returns:
        Tuple of (x1, y1, x2, y2) each of shape (B,)
    """
    # Compute half side length
    half_side = radius * scale_factor

    # Initial box centered on detected eye
    x1 = cx - half_side
    y1 = cy - half_side
    x2 = cx + half_side
    y2 = cy + half_side

    if allow_overflow:
        # Keep the box centered on the eye, don't clamp
        # Out-of-bounds regions will be filled with black during cropping
        return x1, y1, x2, y2

    # Clamp to image bounds while maintaining square shape
    # If box exceeds bounds, shift it
    x1 = x1.clamp(min=0)
    y1 = y1.clamp(min=0)
    x2 = x2.clamp(max=W - 1)
    y2 = y2.clamp(max=H - 1)

    # Ensure square by taking minimum side
    side_x = x2 - x1
    side_y = y2 - y1
    side = torch.minimum(side_x, side_y)

    # Recenter the box
    cx_new = (x1 + x2) / 2
    cy_new = (y1 + y2) / 2

    x1 = (cx_new - side / 2).clamp(min=0)
    y1 = (cy_new - side / 2).clamp(min=0)
    x2 = x1 + side
    y2 = y1 + side

    # Final clamp
    x2 = x2.clamp(max=W - 1)
    y2 = y2.clamp(max=H - 1)

    return x1, y1, x2, y2


def batch_crop_and_resize(
    images: torch.Tensor,
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = 'border',
) -> torch.Tensor:
    """
    Crop and resize images using grid_sample for GPU efficiency.

    Args:
        images: Input images (B, C, H, W)
        x1, y1, x2, y2: Crop coordinates (B,) - can extend beyond image bounds
        output_size: Output square size
        padding_mode: How to handle out-of-bounds sampling:
            - 'border': repeat edge pixels (default)
            - 'zeros': fill with black (useful for pre-cropped images)

    Returns:
        Cropped and resized images (B, C, output_size, output_size)
    """
    B, C, H, W = images.shape
    device = images.device
    dtype = images.dtype

    # Create output grid coordinates
    out_coords = torch.linspace(0, 1, output_size, device=device, dtype=dtype)
    out_y, out_x = torch.meshgrid(out_coords, out_coords, indexing='ij')
    out_grid = torch.stack([out_x, out_y], dim=-1)  # (output_size, output_size, 2)
    out_grid = out_grid.unsqueeze(0).expand(B, -1, -1, -1)  # (B, output_size, output_size, 2)

    # Scale grid to crop coordinates
    # out_grid is in [0, 1], need to map to [x1, x2] and [y1, y2]
    x1 = x1.view(B, 1, 1, 1)
    y1 = y1.view(B, 1, 1, 1)
    x2 = x2.view(B, 1, 1, 1)
    y2 = y2.view(B, 1, 1, 1)

    # Map [0, 1] to pixel coordinates
    sample_x = x1 + out_grid[..., 0:1] * (x2 - x1)
    sample_y = y1 + out_grid[..., 1:2] * (y2 - y1)

    # Normalize to [-1, 1] for grid_sample
    sample_x_norm = 2.0 * sample_x / (W - 1) - 1.0
    sample_y_norm = 2.0 * sample_y / (H - 1) - 1.0

    grid = torch.cat([sample_x_norm, sample_y_norm], dim=-1)  # (B, output_size, output_size, 2)

    # Sample with specified padding mode
    cropped = F.grid_sample(
        images, grid,
        mode='bilinear', padding_mode=padding_mode, align_corners=True,
    )

    return cropped


# =============================================================================
# PHASE 4: CLAHE (Torch-Native)
# =============================================================================

def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
    """Convert sRGB to linear RGB."""
    threshold = 0.04045
    linear = torch.where(
        rgb <= threshold,
        rgb / 12.92,
        ((rgb + 0.055) / 1.055) ** 2.4,
    )
    return linear


def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
    """Convert linear RGB to sRGB."""
    threshold = 0.0031308
    srgb = torch.where(
        linear <= threshold,
        linear * 12.92,
        1.055 * (linear ** (1.0 / 2.4)) - 0.055,
    )
    return srgb
def rgb_to_lab(images: torch.Tensor) -> tuple:
    """
    Convert sRGB images to CIE LAB color space.

    This is a proper LAB conversion that:
    1. Converts sRGB to linear RGB
    2. Converts linear RGB to XYZ
    3. Converts XYZ to LAB

    Args:
        images: RGB images (B, C, H, W) in [0, 1] sRGB

    Returns:
        Tuple of (L, a, b) where:
        - L: Luminance in [0, 1] (normalized from [0, 100])
        - a, b: Chrominance (normalized to roughly [0, 1])
    """
    device = images.device
    dtype = images.dtype

    # Step 1: sRGB to linear RGB
    linear_rgb = _srgb_to_linear(images)

    # Step 2: Linear RGB to XYZ (D65 illuminant)
    # RGB to XYZ matrix
    r = linear_rgb[:, 0:1, :, :]
    g = linear_rgb[:, 1:2, :, :]
    b = linear_rgb[:, 2:3, :, :]

    x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
    y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
    z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b

    # D65 reference white
    xn, yn, zn = 0.95047, 1.0, 1.08883
    x = x / xn
    y = y / yn
    z = z / zn

    # Step 3: XYZ to LAB
    delta = 6.0 / 29.0
    delta_cube = delta ** 3

    def f(t):
        return torch.where(
            t > delta_cube,
            t ** (1.0 / 3.0),
            t / (3.0 * delta ** 2) + 4.0 / 29.0,
        )

    fx = f(x)
    fy = f(y)
    fz = f(z)

    L = 116.0 * fy - 16.0      # Range [0, 100]
    a = 500.0 * (fx - fy)      # Range roughly [-128, 127]
    b_ch = 200.0 * (fy - fz)   # Range roughly [-128, 127]

    # Normalize to convenient ranges for processing
    L = L / 100.0              # [0, 1]
    a = a / 256.0 + 0.5        # Roughly [0, 1]
    b_ch = b_ch / 256.0 + 0.5  # Roughly [0, 1]

    return L, a, b_ch


def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
    """
    Convert CIE LAB to sRGB.

    Args:
        L: Luminance in [0, 1] (normalized from [0, 100])
        a, b_ch: Chrominance (normalized, roughly [0, 1])

    Returns:
        RGB images (B, 3, H, W) in [0, 1] sRGB
    """
    # Denormalize
    L_lab = L * 100.0
    a_lab = (a - 0.5) * 256.0
    b_lab = (b_ch - 0.5) * 256.0

    # LAB to XYZ
    fy = (L_lab + 16.0) / 116.0
    fx = a_lab / 500.0 + fy
    fz = fy - b_lab / 200.0

    delta = 6.0 / 29.0

    def f_inv(t):
        return torch.where(
            t > delta,
            t ** 3,
            3.0 * (delta ** 2) * (t - 4.0 / 29.0),
        )

    # D65 reference white
    xn, yn, zn = 0.95047, 1.0, 1.08883
    x = xn * f_inv(fx)
    y = yn * f_inv(fy)
    z = zn * f_inv(fz)

    # XYZ to linear RGB
    r = 3.2404542 * x - 1.5371385 * y - 0.4985314 * z
    g = -0.9692660 * x + 1.8760108 * y + 0.0415560 * z
    b = 0.0556434 * x - 0.2040259 * y + 1.0572252 * z

    linear_rgb = torch.cat([r, g, b], dim=1)

    # Clamp before gamma correction to avoid NaN from negative values
    linear_rgb = linear_rgb.clamp(0.0, 1.0)

    # Linear RGB to sRGB
    srgb = _linear_to_srgb(linear_rgb)

    return srgb.clamp(0.0, 1.0)


def compute_histogram(
    tensor: torch.Tensor,
    num_bins: int = 256,
) -> torch.Tensor:
    """
    Compute histogram for a batch of single-channel images.

    Args:
        tensor: Input tensor (B, 1, H, W) with values in [0, 1]
        num_bins: Number of histogram bins

    Returns:
        Histograms (B, num_bins)
    """
    B = tensor.shape[0]
    device = tensor.device
    dtype = tensor.dtype

    # Flatten spatial dimensions
    flat = tensor.view(B, -1)  # (B, H*W)

    # Bin indices
    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)

    # Compute histogram using scatter_add
    histograms = torch.zeros(B, num_bins, device=device, dtype=dtype)
    ones = torch.ones_like(flat, dtype=dtype)

    for i in range(B):
        histograms[i] = histograms[i].scatter_add(0, bin_indices[i], ones[i])

    return histograms


def clahe_single_tile(
    tile: torch.Tensor,
    clip_limit: float,
    num_bins: int = 256,
) -> torch.Tensor:
    """
    Apply CLAHE to a single tile.

    Args:
        tile: Input tile (B, 1, tile_h, tile_w)
        clip_limit: Histogram clip limit
        num_bins: Number of histogram bins

    Returns:
        CDF lookup table (B, num_bins)
    """
    B, _, tile_h, tile_w = tile.shape
    device = tile.device
    dtype = tile.dtype
    num_pixels = tile_h * tile_w

    # Compute histogram
    hist = compute_histogram(tile, num_bins)  # (B, num_bins)

    # Clip histogram
    clip_value = clip_limit * num_pixels / num_bins
    excess = (hist - clip_value).clamp(min=0).sum(dim=1, keepdim=True)  # (B, 1)
    hist = hist.clamp(max=clip_value)

    # Redistribute excess uniformly
    redistribution = excess / num_bins
    hist = hist + redistribution

    # Compute CDF
    cdf = hist.cumsum(dim=1)  # (B, num_bins)

    # Normalize CDF to [0, 1]
    cdf_min = cdf[:, 0:1]
    cdf_max = cdf[:, -1:]
    cdf = (cdf - cdf_min) / (cdf_max - cdf_min + 1e-8)

    return cdf
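
# -----------------------------------------------------------------------------
# Illustrative sketch (an assumption, not used by the pipeline below): one way
# the CDF lookup table returned by `clahe_single_tile` could be applied to
# remap a tile's intensities. `apply_clahe_vectorized` performs the equivalent
# lookup with bilinear blending across neighbouring tiles; this hypothetical
# helper (`_remap_tile_with_cdf`) only shows the per-tile mapping in isolation.
# -----------------------------------------------------------------------------
def _remap_tile_with_cdf(
    tile: torch.Tensor,
    cdf: torch.Tensor,
    num_bins: int = 256,
) -> torch.Tensor:
    """Map tile intensities through a per-image CDF lookup table (sketch).

    Args:
        tile: Input tile (B, 1, tile_h, tile_w) with values in [0, 1]
        cdf: Lookup table from `clahe_single_tile`, shape (B, num_bins)

    Returns:
        Remapped tile (B, 1, tile_h, tile_w) with values in [0, 1]
    """
    B = tile.shape[0]
    # Quantize intensities into histogram bins, then look up the equalized value
    bin_idx = (tile * (num_bins - 1)).long().clamp(0, num_bins - 1)  # (B, 1, th, tw)
    flat_idx = bin_idx.view(B, -1)                                    # (B, th*tw)
    remapped = torch.gather(cdf, 1, flat_idx)                         # (B, th*tw)
    return remapped.view_as(tile)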

def apply_clahe_vectorized(
    images: torch.Tensor,
    grid_size: int = 8,
    clip_limit: float = 2.0,
    num_bins: int = 256,
) -> torch.Tensor:
    """
    Vectorized CLAHE implementation (more efficient for GPU).

    Args:
        images: Input images (B, C, H, W)
        grid_size: Number of tiles in each dimension
        clip_limit: Histogram clip limit
        num_bins: Number of histogram bins

    Returns:
        CLAHE-enhanced images (B, C, H, W)
    """
    B, C, H, W = images.shape
    device = images.device
    dtype = images.dtype

    # Work on luminance only
    if C == 3:
        L, a, b_ch = rgb_to_lab(images)
    else:
        L = images.clone()
        a = b_ch = None

    # Ensure divisibility
    pad_h = (grid_size - H % grid_size) % grid_size
    pad_w = (grid_size - W % grid_size) % grid_size

    if pad_h > 0 or pad_w > 0:
        L_padded = F.pad(L, (0, pad_w, 0, pad_h), mode='reflect')
    else:
        L_padded = L

    _, _, H_pad, W_pad = L_padded.shape
    tile_h = H_pad // grid_size
    tile_w = W_pad // grid_size

    # Reshape into tiles: (B, 1, grid_size, tile_h, grid_size, tile_w)
    L_tiles = L_padded.view(B, 1, grid_size, tile_h, grid_size, tile_w)
    L_tiles = L_tiles.permute(0, 2, 4, 1, 3, 5)  # (B, grid_size, grid_size, 1, tile_h, tile_w)
    L_tiles = L_tiles.reshape(B * grid_size * grid_size, 1, tile_h, tile_w)

    # Compute histograms for all tiles at once
    num_pixels = tile_h * tile_w
    flat = L_tiles.view(B * grid_size * grid_size, -1)
    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)

    # Vectorized histogram computation
    histograms = torch.zeros(B * grid_size * grid_size, num_bins, device=device, dtype=dtype)
    histograms.scatter_add_(1, bin_indices, torch.ones_like(flat))

    # Clip and redistribute
    clip_value = clip_limit * num_pixels / num_bins
    excess = (histograms - clip_value).clamp(min=0).sum(dim=1, keepdim=True)
    histograms = histograms.clamp(max=clip_value)
    histograms = histograms + excess / num_bins

    # Compute CDFs
    cdfs = histograms.cumsum(dim=1)
    cdf_min = cdfs[:, 0:1]
    cdf_max = cdfs[:, -1:]
    cdfs = (cdfs - cdf_min) / (cdf_max - cdf_min + 1e-8)

    # Reshape CDFs: (B, grid_size, grid_size, num_bins)
    cdfs = cdfs.view(B, grid_size, grid_size, num_bins)

    # Create coordinate grids for interpolation
    y_coords = torch.arange(H_pad, device=device, dtype=dtype)
    x_coords = torch.arange(W_pad, device=device, dtype=dtype)

    # Map to tile coordinates (centered on tiles)
    tile_y = (y_coords + 0.5) / tile_h - 0.5
    tile_x = (x_coords + 0.5) / tile_w - 0.5

    tile_y = tile_y.clamp(0, grid_size - 1.001)
    tile_x = tile_x.clamp(0, grid_size - 1.001)

    # Integer indices and weights
    ty0 = tile_y.long().clamp(0, grid_size - 2)
    tx0 = tile_x.long().clamp(0, grid_size - 2)
    ty1 = (ty0 + 1).clamp(max=grid_size - 1)
    tx1 = (tx0 + 1).clamp(max=grid_size - 1)

    wy = (tile_y - ty0.float()).view(1, H_pad, 1, 1)
    wx = (tile_x - tx0.float()).view(1, 1, W_pad, 1)

    # Get bin indices for all pixels
    bin_idx = (L_padded * (num_bins - 1)).long().clamp(0, num_bins - 1)  # (B, 1, H_pad, W_pad)
    bin_idx = bin_idx.squeeze(1)  # (B, H_pad, W_pad)

    # Gather CDF values for each corner
    # We need cdfs[b, ty, tx, bin_idx[b, y, x]] for all combinations
    # Expand indices for gathering
    b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, H_pad, W_pad)
    ty0_exp = ty0.view(1, H_pad, 1).expand(B, H_pad, W_pad)
    ty1_exp = ty1.view(1, H_pad, 1).expand(B, H_pad, W_pad)
    tx0_exp = tx0.view(1, 1, W_pad).expand(B, H_pad, W_pad)
    tx1_exp = tx1.view(1, 1, W_pad).expand(B, H_pad, W_pad)

    # Gather using advanced indexing
    v00 = cdfs[b_idx, ty0_exp, tx0_exp, bin_idx]  # (B, H_pad, W_pad)
    v01 = cdfs[b_idx, ty0_exp, tx1_exp, bin_idx]
    v10 = cdfs[b_idx, ty1_exp, tx0_exp, bin_idx]
    v11 = cdfs[b_idx, ty1_exp, tx1_exp, bin_idx]

    # Bilinear interpolation
    wy = wy.squeeze(-1)  # (1, H_pad, 1)
    wx = wx.squeeze(-1)  # (1, 1, W_pad)

    L_out = (1 - wy) * (1 - wx) * v00 + (1 - wy) * wx * v01 + wy * (1 - wx) * v10 + wy * wx * v11
    L_out = L_out.unsqueeze(1)  # (B, 1, H_pad, W_pad)

    # Remove padding
    if pad_h > 0 or pad_w > 0:
        L_out = L_out[:, :, :H, :W]

    # Convert back to RGB
    if C == 3:
        output = lab_to_rgb(L_out, a, b_ch)
    else:
        output = L_out

    return output


# =============================================================================
# PHASE 5: Resize & Normalization
# =============================================================================

# ImageNet normalization constants
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def resize_images(
    images: torch.Tensor,
    size: int,
    mode: str = 'bilinear',
    antialias: bool = True,
) -> torch.Tensor:
    """
    Resize images to target size.

    Args:
        images: Input images (B, C, H, W)
        size: Target size (square)
        mode: Interpolation mode
        antialias: Whether to use antialiasing

    Returns:
        Resized images (B, C, size, size)
    """
    return F.interpolate(
        images,
        size=(size, size),
        mode=mode,
        align_corners=False if mode in ['bilinear', 'bicubic'] else None,
        antialias=antialias if mode in ['bilinear', 'bicubic'] else False,
    )


def normalize_images(
    images: torch.Tensor,
    mean: Optional[List[float]] = None,
    std: Optional[List[float]] = None,
    mode: str = 'imagenet',
) -> torch.Tensor:
    """
    Normalize images.

    Args:
        images: Input images (B, C, H, W) in [0, 1]
        mean: Custom mean (per channel)
        std: Custom std (per channel)
        mode: 'imagenet', 'none', or 'custom'

    Returns:
        Normalized images
    """
    if mode == 'none':
        return images

    if mode == 'imagenet':
        mean = IMAGENET_MEAN
        std = IMAGENET_STD
    elif mode == 'custom':
        if mean is None or std is None:
            raise ValueError("Custom mode requires mean and std")
    else:
        raise ValueError(f"Unknown normalization mode: {mode}")

    device = images.device
    dtype = images.dtype

    mean_tensor = torch.tensor(mean, device=device, dtype=dtype).view(1, -1, 1, 1)
    std_tensor = torch.tensor(std, device=device, dtype=dtype).view(1, -1, 1, 1)

    return (images - mean_tensor) / std_tensor


# =============================================================================
# PHASE 6: Hugging Face ImageProcessor Integration
# =============================================================================

class EyeCLAHEImageProcessor(BaseImageProcessor):
    """
    GPU-native image processor for Color Fundus Photography (CFP) images.

    This processor:
    1. Localizes the eye region using gradient-based radial symmetry
    2. Crops to a border-minimized square centered on the eye
    3. Applies CLAHE for contrast enhancement
    4. Resizes and normalizes for vision model input

    All operations are implemented in pure PyTorch and are CUDA-compatible.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size: int = 224,
        crop_scale_factor: float = 1.1,
        clahe_grid_size: int = 8,
        clahe_clip_limit: float = 2.0,
        normalization_mode: str = "imagenet",
        custom_mean: Optional[List[float]] = None,
        custom_std: Optional[List[float]] = None,
        do_clahe: bool = True,
        do_crop: bool = True,
        min_radius_frac: float = 0.1,
        max_radius_frac: float = 0.5,
        allow_overflow: bool = False,
        softmax_temperature: float = 0.1,
        **kwargs,
    ):
        """
        Initialize the EyeCLAHEImageProcessor.

        Args:
            size: Output image size (square)
            crop_scale_factor: Scale factor for crop box (relative to detected radius)
            clahe_grid_size: Number of tiles for CLAHE
            clahe_clip_limit: Histogram clip limit for CLAHE
            normalization_mode: 'imagenet', 'none', or 'custom'
            custom_mean: Custom normalization mean (if mode='custom')
            custom_std: Custom normalization std (if mode='custom')
            do_clahe: Whether to apply CLAHE
            do_crop: Whether to perform eye-centered cropping
            min_radius_frac: Minimum radius as fraction of image size
            max_radius_frac: Maximum radius as fraction of image size
            allow_overflow: If True, allow crop box to extend beyond image bounds
                and fill missing regions with black. Useful for pre-cropped images
                where the fundus circle is partially cut off.
            softmax_temperature: Temperature for soft argmax in eye center detection.
                Lower values (0.01-0.1) give sharper peak detection, higher values
                (0.3-0.5) provide more averaging for noisy images. Default: 0.1.
        """
        super().__init__(**kwargs)

        self.size = size
        self.crop_scale_factor = crop_scale_factor
        self.clahe_grid_size = clahe_grid_size
        self.clahe_clip_limit = clahe_clip_limit
        self.normalization_mode = normalization_mode
        self.custom_mean = custom_mean
        self.custom_std = custom_std
        self.do_clahe = do_clahe
        self.do_crop = do_crop
        self.min_radius_frac = min_radius_frac
        self.max_radius_frac = max_radius_frac
        self.allow_overflow = allow_overflow
        self.softmax_temperature = softmax_temperature

    def preprocess(
        self,
        images,
        return_tensors: str = "pt",
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess images for model input.
        Args:
            images: Input images in any of these formats:
                - torch.Tensor: (C,H,W), (B,C,H,W), or list of tensors
                - PIL.Image.Image: single image or list of images
                - numpy.ndarray: (H,W,C), (B,H,W,C), or list of arrays
            return_tensors: Return type (only "pt" supported)
            device: Target device for processing (e.g., "cuda", "cpu")

        Returns:
            BatchFeature with keys:
            - 'pixel_values': Processed images (B, C, size, size)
            - 'scale_x', 'scale_y': Scale factors for coordinate mapping (B,)
            - 'offset_x', 'offset_y': Offsets for coordinate mapping (B,)

            To map coordinates from the processed image back to the original:
                orig_x = offset_x + cropped_x * scale_x
                orig_y = offset_y + cropped_y * scale_y
        """
        if return_tensors != "pt":
            raise ValueError("Only 'pt' (PyTorch) tensors are supported")

        # Determine device
        if device is not None:
            device = torch.device(device)
        elif isinstance(images, torch.Tensor):
            device = images.device
        elif isinstance(images, list) and len(images) > 0 and isinstance(images[0], torch.Tensor):
            device = images[0].device
        else:
            # PIL images and numpy arrays default to CPU
            device = torch.device('cpu')

        # Standardize input
        images = standardize_input(images, device)
        B, C, H_orig, W_orig = images.shape

        if self.do_crop:
            # Estimate eye center
            cx, cy = estimate_eye_center(images, softmax_temperature=self.softmax_temperature)

            # Estimate radius
            radius = estimate_radius(
                images, cx, cy,
                min_radius_frac=self.min_radius_frac,
                max_radius_frac=self.max_radius_frac,
            )

            # Compute crop box
            x1, y1, x2, y2 = compute_crop_box(
                cx, cy, radius, H_orig, W_orig,
                scale_factor=self.crop_scale_factor,
                allow_overflow=self.allow_overflow,
            )

            # Compute coordinate mapping
            # For processed coordinates in [0, self.size-1], map back to original
            scale_x = (x2 - x1) / (self.size - 1)
            scale_y = (y2 - y1) / (self.size - 1)
            offset_x = x1
            offset_y = y1

            # Crop and resize
            # Use 'zeros' padding when allow_overflow is True to fill out-of-bounds with black
            padding_mode = 'zeros' if self.allow_overflow else 'border'
            images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)
        else:
            # Just resize - no crop
            # Compute coordinate mapping for direct resize
            scale_x = torch.full((B,), (W_orig - 1) / (self.size - 1), device=device, dtype=images.dtype)
            scale_y = torch.full((B,), (H_orig - 1) / (self.size - 1), device=device, dtype=images.dtype)
            offset_x = torch.zeros(B, device=device, dtype=images.dtype)
            offset_y = torch.zeros(B, device=device, dtype=images.dtype)

            images = resize_images(images, self.size)

        # Apply CLAHE
        if self.do_clahe:
            images = apply_clahe_vectorized(
                images,
                grid_size=self.clahe_grid_size,
                clip_limit=self.clahe_clip_limit,
            )

        # Normalize
        images = normalize_images(
            images,
            mean=self.custom_mean,
            std=self.custom_std,
            mode=self.normalization_mode,
        )

        # Return with coordinate mapping information (flattened structure)
        return BatchFeature(
            data={
                "pixel_values": images,
                "scale_x": scale_x,
                "scale_y": scale_y,
                "offset_x": offset_x,
                "offset_y": offset_y,
            },
            tensor_type="pt",
        )

    def __call__(
        self,
        images: Union[torch.Tensor, List[torch.Tensor]],
        **kwargs,
    ) -> BatchFeature:
        """
        Process images (alias for preprocess).
        """
        return self.preprocess(images, **kwargs)


# For AutoImageProcessor registration
EyeGPUImageProcessor = EyeCLAHEImageProcessor
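

# -----------------------------------------------------------------------------
# Minimal usage sketch, run only as a script (never on import). The random
# input below is an assumption standing in for real CFP images, purely to
# exercise the pipeline end to end; output values are not meaningful for
# random data.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    processor = EyeCLAHEImageProcessor(size=224)

    # Two fake RGB "fundus" images in (B, C, H, W), values in [0, 1]
    dummy = torch.rand(2, 3, 512, 512)

    batch = processor(dummy, return_tensors="pt")
    pixel_values = batch["pixel_values"]  # (2, 3, 224, 224), normalized
    print("pixel_values:", tuple(pixel_values.shape))

    # Map a point from processed coordinates back to the original image,
    # using the documented mapping: orig = offset + processed * scale
    x_proc, y_proc = 112.0, 112.0
    x_orig = batch["offset_x"][0] + x_proc * batch["scale_x"][0]
    y_orig = batch["offset_y"][0] + y_proc * batch["scale_y"][0]
    print("crop center in original coords:", float(x_orig), float(y_orig))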