# eye-clahe-processor / image_processing_eye_gpu.py
# Author: iszt — commit 9d66986 ("adding Coordinate Mapping", verified)
"""
GPU-Native Eye Image Processor for Color Fundus Photography (CFP) Images.
This module implements a fully PyTorch-based image processor that:
1. Localizes the eye/fundus region using gradient-based radial symmetry
2. Crops to a border-minimized square centered on the eye
3. Applies CLAHE for contrast enhancement
4. Outputs tensors compatible with Hugging Face vision models
Constraints:
- PyTorch-only processing (no OpenCV); PIL and NumPy are optional and used only to convert input images to tensors
- CUDA-compatible, batch-friendly, deterministic
"""
from typing import Dict, List, Optional, Union
import math
import torch
import torch.nn.functional as F
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_processing_base import BatchFeature
# Optional imports for broader input support
try:
from PIL import Image
PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False
try:
import numpy as np
NUMPY_AVAILABLE = True
except ImportError:
NUMPY_AVAILABLE = False
# =============================================================================
# PHASE 1: Input & Tensor Standardization
# =============================================================================
def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
    """
    Turn a PIL image into a float32 tensor laid out as (C, H, W) with values in [0, 1].

    Non-RGB modes are converted to RGB first, so the result always has 3 channels.
    NumPy is used as a fast intermediate when installed; otherwise pixels are read
    through ``Image.getdata()``.

    Raises:
        ImportError: if PIL is not installed.
    """
    if not PIL_AVAILABLE:
        raise ImportError("PIL is required to process PIL Images")

    rgb_image = image if image.mode == "RGB" else image.convert("RGB")

    if not NUMPY_AVAILABLE:
        # Pure-Python fallback: getdata() yields one (R, G, B) tuple per pixel in row-major order.
        cols, rows = rgb_image.size
        flat_pixels = list(rgb_image.getdata())
        hwc = torch.tensor(flat_pixels, dtype=torch.float32).view(rows, cols, 3) / 255.0
        return hwc.permute(2, 0, 1)

    hwc_array = np.array(rgb_image, dtype=np.float32) / 255.0
    # (H, W, C) -> (C, H, W)
    return torch.from_numpy(hwc_array).permute(2, 0, 1)
def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor:
    """
    Convert a single numpy image to a float32 tensor (C, H, W) in [0, 1].

    Accepted layouts:
      - (H, W) grayscale -> (1, H, W)
      - channel-last (H, W, C) with C in {1, 3, 4} -> (C', H, W)
      - anything else is assumed to already be channel-first and is not transposed.

    A 4-channel channel-last array is assumed to be RGBA; its alpha channel is
    dropped so the result is RGB. This matches the RGB conversion applied to PIL
    inputs, and it is needed because downstream stages (grayscale, LAB, ImageNet
    normalization) require 1 or 3 channels.

    Unsigned integer data (uint8, uint16, ...) is scaled by its dtype's maximum
    value. Any other dtype is cast to float32 without rescaling.

    Raises:
        ImportError: if NumPy is not installed.
    """
    if not NUMPY_AVAILABLE:
        raise ImportError("NumPy is required to process numpy arrays")

    if arr.ndim == 2:
        # Grayscale (H, W) -> (H, W, 1)
        arr = arr[..., None]

    if arr.ndim == 3 and arr.shape[-1] in (1, 3, 4):
        if arr.shape[-1] == 4:
            # Drop alpha: keep R, G, B only.
            arr = arr[..., :3]
        # (H, W, C) -> (C, H, W)
        arr = arr.transpose(2, 0, 1)

    if np.issubdtype(arr.dtype, np.unsignedinteger):
        # uint8 -> /255, uint16 -> /65535, ... (max is read before the cast).
        arr = arr.astype(np.float32) / np.float32(np.iinfo(arr.dtype).max)
    elif arr.dtype != np.float32:
        arr = arr.astype(np.float32)

    # copy() yields a C-contiguous array the tensor owns, independent of the caller's buffer.
    return torch.from_numpy(arr.copy())
def _to_unit_float(tensor: torch.Tensor) -> torch.Tensor:
    """Return *tensor* as float32, scaling uint8 data from [0, 255] to [0, 1]."""
    if tensor.dtype == torch.uint8:
        return tensor.float() / 255.0
    if tensor.dtype != torch.float32:
        return tensor.float()
    return tensor


def standardize_input(
    images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Convert input images to a standardized batched tensor.

    Args:
        images: Input as:
            - torch.Tensor (H,W), (C,H,W), (B,C,H,W), or a list/tuple of tensors
            - PIL.Image.Image or a list/tuple of PIL Images
            - numpy.ndarray (H,W), (H,W,C), (B,H,W) grayscale batch, (B,H,W,C),
              or a list/tuple of arrays
        device: Target device (defaults to the input's device, or CPU)

    Returns:
        Tensor of shape (B, C, H, W) in float32, range [0, 1]

    Raises:
        ValueError: if an empty list/tuple is given, or the resulting tensor is not 4-D.
        TypeError: if the input (or a list element) has an unsupported type.
    """
    # Wrap single images so they go through the per-image conversion path.
    if PIL_AVAILABLE and isinstance(images, Image.Image):
        images = [images]
    if NUMPY_AVAILABLE and isinstance(images, np.ndarray):
        # (H, W) is always one grayscale image. A 3-D array is ambiguous:
        # (H, W, C) single image vs (B, H, W) grayscale batch. Treat it as a
        # single image when the last axis looks like a channel axis.
        if images.ndim == 2 or (images.ndim == 3 and images.shape[-1] in (1, 3, 4)):
            images = [images]

    if isinstance(images, (list, tuple)):
        if len(images) == 0:
            raise ValueError("Expected at least one image, got an empty sequence")
        converted = []
        for img in images:
            if PIL_AVAILABLE and isinstance(img, Image.Image):
                t = _pil_to_tensor(img)
            elif NUMPY_AVAILABLE and isinstance(img, np.ndarray):
                t = _numpy_to_tensor(img)
            elif isinstance(img, torch.Tensor):
                if img.dim() == 2:
                    t = img.unsqueeze(0)  # (H, W) -> (1, H, W)
                elif img.dim() == 3:
                    t = img
                else:
                    t = img.squeeze(0)  # (1, C, H, W) -> (C, H, W)
                # Normalize per element so mixed uint8 / float inputs can be stacked.
                t = _to_unit_float(t)
            else:
                raise TypeError(f"Unsupported image type: {type(img)}")
            if device is not None:
                # Move per element so inputs from different devices can be stacked.
                t = t.to(device)
            converted.append(t)
        images = torch.stack(converted)
    elif NUMPY_AVAILABLE and isinstance(images, np.ndarray):
        if images.ndim == 3:
            # (B, H, W) grayscale batch -> (B, H, W, 1)
            images = images[..., None]
        if images.ndim == 4:
            images = images.transpose(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
        if np.issubdtype(images.dtype, np.unsignedinteger):
            images = images.astype(np.float32) / np.float32(np.iinfo(images.dtype).max)
        images = torch.from_numpy(images.copy())

    if not isinstance(images, torch.Tensor):
        raise TypeError(f"Unsupported image type: {type(images)}")

    if images.dim() == 2:
        images = images.unsqueeze(0)  # (H, W) -> (1, H, W)
    if images.dim() == 3:
        images = images.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
    if images.dim() != 4:
        raise ValueError(f"Expected images with 2 to 4 dimensions, got shape {tuple(images.shape)}")

    # Move to target device if specified
    if device is not None:
        images = images.to(device)

    # Convert to float32 and normalize to [0, 1]
    if images.dtype == torch.uint8:
        images = images.float() / 255.0
    elif images.dtype != torch.float32:
        images = images.float()

    return images.clamp(0.0, 1.0)
def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
    """
    Collapse the RGB channels into one luminance channel.

    Uses the weights Y = 0.299 * R + 0.587 * G + 0.114 * B.

    Args:
        images: Tensor of shape (B, 3, H, W)

    Returns:
        Tensor of shape (B, 1, H, W)
    """
    luma = torch.tensor([0.299, 0.587, 0.114], device=images.device, dtype=images.dtype)
    # Broadcast the per-channel weights over batch and spatial axes, then reduce channels.
    return torch.sum(images * luma[None, :, None, None], dim=1, keepdim=True)
# =============================================================================
# PHASE 2: Eye Region Localization (GPU-Safe)
# =============================================================================
def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
    """
    Build the 3x3 Sobel derivative kernels.

    Returns:
        (sobel_x, sobel_y), each of shape (1, 1, 3, 3) so they can be passed
        directly as weights to F.conv2d on single-channel input.
    """
    sobel_x, sobel_y = (
        torch.tensor(rows, device=device, dtype=dtype).view(1, 1, 3, 3)
        for rows in (
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],  # horizontal derivative
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],  # vertical derivative
        )
    )
    return sobel_x, sobel_y
def compute_gradients(grayscale: torch.Tensor) -> tuple:
    """
    Compute Sobel image gradients.

    Args:
        grayscale: Tensor of shape (B, 1, H, W)

    Returns:
        (grad_x, grad_y, grad_magnitude), each (B, 1, H, W). One pixel of zero
        padding keeps the spatial size. The magnitude includes a 1e-8 epsilon
        under the square root, so it is never exactly zero.
    """
    kernel_x, kernel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)
    gx, gy = (F.conv2d(grayscale, kernel, padding=1) for kernel in (kernel_x, kernel_y))
    magnitude = torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
    return gx, gy, magnitude
def compute_radial_symmetry_response(
    grayscale: torch.Tensor,
    grad_x: torch.Tensor,
    grad_y: torch.Tensor,
    grad_magnitude: torch.Tensor,
) -> torch.Tensor:
    """
    Compute a smoothed radial-symmetry response map used to locate the eye center.

    Each pixel's raw response is the product of:
      1. the cosine between its image gradient and the direction from the pixel
         toward a coarse center estimate (the darkness-weighted center of mass),
      2. the gradient magnitude, and
      3. a darkness weight ``(1 - intensity) ** 2``.
    The raw map is then blurred with a separable Gaussian. The kernel size is
    about 1/8 of the larger image side (forced odd, at least 5) and
    sigma = kernel_size / 6.

    Args:
        grayscale: Grayscale image (B, 1, H, W), presumably in [0, 1] (the
            darkness weight is computed as 1 - intensity)
        grad_x, grad_y: Gradient components (B, 1, H, W)
        grad_magnitude: Gradient magnitude (B, 1, H, W)

    Returns:
        Radial symmetry response map (B, 1, H, W)
    """
    B, _, H, W = grayscale.shape
    device = grayscale.device
    dtype = grayscale.dtype

    # Pixel coordinate grids, broadcast over the batch
    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1).expand(B, 1, H, W)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W).expand(B, 1, H, W)

    # Invert and square intensity so darker pixels dominate the weighting
    dark_weight = (1.0 - grayscale) ** 2

    # Coarse center estimate: darkness-weighted center of mass
    weight_sum = dark_weight.sum(dim=(2, 3), keepdim=True) + 1e-8
    cx_init = (dark_weight * x_coords).sum(dim=(2, 3), keepdim=True) / weight_sum
    cy_init = (dark_weight * y_coords).sum(dim=(2, 3), keepdim=True) / weight_sum

    # Unit vectors from each pixel toward the coarse center
    dx_to_center = cx_init - x_coords
    dy_to_center = cy_init - y_coords
    dist_to_center = torch.sqrt(dx_to_center ** 2 + dy_to_center ** 2 + 1e-8)
    dx_norm = dx_to_center / dist_to_center
    dy_norm = dy_to_center / dist_to_center

    # Unit gradient vectors
    grad_norm = grad_magnitude + 1e-8
    gx_norm = grad_x / grad_norm
    gy_norm = grad_y / grad_norm

    # Cosine between gradient and direction-to-center, weighted by strength and darkness
    radial_alignment = gx_norm * dx_norm + gy_norm * dy_norm
    response = radial_alignment * grad_magnitude * dark_weight

    # Separable Gaussian smoothing for a robust peak
    kernel_size = max(H, W) // 8
    if kernel_size % 2 == 0:
        kernel_size += 1
    kernel_size = max(kernel_size, 5)
    sigma = kernel_size / 6.0

    x = torch.arange(kernel_size, device=device, dtype=dtype) - kernel_size // 2
    gaussian_1d = torch.exp(-x ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()
    gaussian_1d_h = gaussian_1d.view(1, 1, 1, kernel_size)
    gaussian_1d_v = gaussian_1d.view(1, 1, kernel_size, 1)

    pad = kernel_size // 2
    # 'reflect' padding requires pad < size of the padded axis. The kernel is sized
    # from the *larger* side, so thin or tiny images would otherwise raise here.
    # Fall back to 'replicate' on such axes; the output is unchanged whenever
    # reflect is valid.
    mode_w = 'reflect' if pad < W else 'replicate'
    mode_h = 'reflect' if pad < H else 'replicate'

    response = F.pad(response, (pad, pad, 0, 0), mode=mode_w)
    response = F.conv2d(response, gaussian_1d_h)
    response = F.pad(response, (0, 0, pad, pad), mode=mode_h)
    response = F.conv2d(response, gaussian_1d_v)
    return response
def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
    """
    Differentiable spatial argmax.

    The response map is turned into a probability distribution over pixels with a
    temperature-scaled softmax. The returned coordinates are the expected x and y
    under that distribution.

    Args:
        response: Response map (B, 1, H, W)
        temperature: Softmax temperature; smaller values concentrate the weight
            on the maximum

    Returns:
        (cx, cy), each of shape (B,), in pixel coordinates
    """
    batch, _, height, width = response.shape
    probs = F.softmax(response.view(batch, -1) / temperature, dim=1).view(batch, 1, height, width)

    col_index = torch.arange(width, device=response.device, dtype=response.dtype).view(1, 1, 1, width)
    row_index = torch.arange(height, device=response.device, dtype=response.dtype).view(1, 1, height, 1)

    cx = (probs * col_index).sum(dim=(2, 3)).squeeze(-1)
    cy = (probs * row_index).sum(dim=(2, 3)).squeeze(-1)
    return cx, cy
def estimate_eye_center(
    images: torch.Tensor,
    softmax_temperature: float = 0.1,
) -> tuple:
    """
    Locate the eye center in each image of a batch.

    Pipeline: grayscale -> Sobel gradients -> radial-symmetry response -> soft argmax.

    Args:
        images: RGB images of shape (B, 3, H, W)
        softmax_temperature: Temperature for the soft argmax. Lower values give
            sharper peak detection; higher values average more. Typical range:
            0.01-1.0. The default 0.1 suits most fundus images; 0.3-0.5 can help
            with noisy images.

    Returns:
        (cx, cy), each of shape (B,), in pixel coordinates
    """
    gray = rgb_to_grayscale(images)
    response = compute_radial_symmetry_response(gray, *compute_gradients(gray))
    return soft_argmax_2d(response, temperature=softmax_temperature)
# =============================================================================
# PHASE 2.3: Radius Estimation
# =============================================================================
def estimate_radius(
    images: torch.Tensor,
    cx: torch.Tensor,
    cy: torch.Tensor,
    num_radii: int = 100,
    num_angles: int = 36,
    min_radius_frac: float = 0.1,
    max_radius_frac: float = 0.5,
) -> torch.Tensor:
    """
    Estimate the eye-region radius from the angle-averaged radial intensity profile.

    The grayscale image is sampled on a polar grid around (cx, cy). The samples are
    averaged over angles, and the radius with the steepest (radius-weighted)
    intensity drop is selected.

    Args:
        images: RGB images (B, 3, H, W)
        cx, cy: Center coordinates (B,)
        num_radii: Number of radius samples
        num_angles: Number of angular samples
        min_radius_frac: Minimum radius as a fraction of min(H, W)
        max_radius_frac: Maximum radius as a fraction of min(H, W)

    Returns:
        Estimated radius for each image (B,), in pixels
    """
    batch, _, height, width = images.shape
    dev, dt = images.device, images.dtype
    gray = rgb_to_grayscale(images)  # (B, 1, H, W)

    shorter_side = min(height, width)
    r_lo = int(min_radius_frac * shorter_side)
    r_hi = int(max_radius_frac * shorter_side)

    radii = torch.linspace(r_lo, r_hi, num_radii, device=dev, dtype=dt)
    # num_angles evenly spaced angles in [0, 2*pi); the duplicate 2*pi endpoint is dropped.
    theta = torch.linspace(0, 2 * math.pi, num_angles + 1, device=dev, dtype=dt)[:-1]

    # Polar offsets from the center, shape (num_angles, num_radii)
    offset_x = torch.cos(theta)[:, None] * radii
    offset_y = torch.sin(theta)[:, None] * radii

    # Absolute sample positions, shape (B, num_angles, num_radii)
    px = cx.view(batch, 1, 1) + offset_x[None]
    py = cy.view(batch, 1, 1) + offset_y[None]

    # grid_sample expects (x, y) in [-1, 1]; align_corners=True maps -1/+1 to pixel centers 0 and size-1.
    grid = torch.stack(
        (2.0 * px / (width - 1) - 1.0, 2.0 * py / (height - 1) - 1.0),
        dim=-1,
    )
    samples = F.grid_sample(
        gray, grid, mode='bilinear', padding_mode='border', align_corners=True
    )  # (B, 1, num_angles, num_radii)

    profile = samples.mean(dim=2).squeeze(1)  # (B, num_radii)
    step = profile[:, 1:] - profile[:, :-1]   # (B, num_radii - 1)

    # Weights rising from 0.5 to 1.5 favor drops at larger radii.
    bias = torch.linspace(0.5, 1.5, num_radii - 1, device=dev, dtype=dt)
    edge_idx = torch.argmin(step * bias, dim=1)  # most negative weighted step

    # step[i] lies between radii[i] and radii[i + 1]; report the outer one.
    return radii[edge_idx + 1].clamp(r_lo, r_hi)
# =============================================================================
# PHASE 3: Border-Minimized Square Crop
# =============================================================================
def compute_crop_box(
    cx: torch.Tensor,
    cy: torch.Tensor,
    radius: torch.Tensor,
    H: int,
    W: int,
    scale_factor: float = 1.1,
    allow_overflow: bool = False,
) -> tuple:
    """
    Compute a square crop box around the detected eye.

    The initial box is centered on (cx, cy) with half-side ``radius * scale_factor``.

    With ``allow_overflow=True`` this box is returned unchanged and may extend past
    the image. Otherwise the box is handled in three steps:
      1. It is clipped to [0, W-1] x [0, H-1].
      2. It is shrunk (not shifted) to the largest square that fits inside the
         clipped rectangle.
      3. That square is centered on the clipped rectangle's center, so the crop
         may end up smaller than the detected eye near image borders.

    Args:
        cx, cy: Center coordinates (B,)
        radius: Estimated radius (B,)
        H, W: Image dimensions
        scale_factor: Multiplier on the radius, for padding around the eye
        allow_overflow: If True, skip the clipping (for pre-cropped images)

    Returns:
        (x1, y1, x2, y2), each of shape (B,), in pixel coordinates
    """
    half = radius * scale_factor
    left, top = cx - half, cy - half
    right, bottom = cx + half, cy + half

    if allow_overflow:
        # Out-of-bounds regions are filled during cropping.
        return left, top, right, bottom

    left = left.clamp(min=0)
    top = top.clamp(min=0)
    right = right.clamp(max=W - 1)
    bottom = bottom.clamp(max=H - 1)

    side = torch.minimum(right - left, bottom - top)
    mid_x = (left + right) / 2
    mid_y = (top + bottom) / 2

    new_left = (mid_x - side / 2).clamp(min=0)
    new_top = (mid_y - side / 2).clamp(min=0)
    new_right = (new_left + side).clamp(max=W - 1)
    new_bottom = (new_top + side).clamp(max=H - 1)
    return new_left, new_top, new_right, new_bottom
def batch_crop_and_resize(
    images: torch.Tensor,
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = 'border',
) -> torch.Tensor:
    """
    Crop and resize a batch in a single grid_sample call.

    Output pixel (i, j) samples the source at
    ``(x1 + j / (S - 1) * (x2 - x1), y1 + i / (S - 1) * (y2 - y1))`` with bilinear
    interpolation, where S = output_size (align_corners=True convention).

    Args:
        images: Input images (B, C, H, W)
        x1, y1, x2, y2: Crop coordinates (B,); may extend beyond the image bounds
        output_size: Output square size
        padding_mode: Out-of-bounds handling passed to grid_sample:
            - 'border': repeat edge pixels (default)
            - 'zeros': fill with black (useful for pre-cropped images)

    Returns:
        Cropped and resized images (B, C, output_size, output_size)
    """
    batch, _, height, width = images.shape
    ramp = torch.linspace(0, 1, output_size, device=images.device, dtype=images.dtype)
    frac_y, frac_x = torch.meshgrid(ramp, ramp, indexing='ij')  # (S, S) each

    # (S, S) -> (1, S, S, 1) so they broadcast against per-image (B, 1, 1, 1) boxes.
    u = frac_x.unsqueeze(0).unsqueeze(-1)
    v = frac_y.unsqueeze(0).unsqueeze(-1)

    x_lo, x_hi = x1.view(batch, 1, 1, 1), x2.view(batch, 1, 1, 1)
    y_lo, y_hi = y1.view(batch, 1, 1, 1), y2.view(batch, 1, 1, 1)

    px = x_lo + u * (x_hi - x_lo)  # (B, S, S, 1), source pixel x
    py = y_lo + v * (y_hi - y_lo)  # (B, S, S, 1), source pixel y

    grid = torch.cat(
        (2.0 * px / (width - 1) - 1.0, 2.0 * py / (height - 1) - 1.0),
        dim=-1,
    )  # (B, S, S, 2) in grid_sample's [-1, 1] space
    return F.grid_sample(
        images, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True
    )
# =============================================================================
# PHASE 4: CLAHE (Torch-Native)
# =============================================================================
def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
"""Convert sRGB to linear RGB."""
threshold = 0.04045
linear = torch.where(
rgb <= threshold,
rgb / 12.92,
((rgb + 0.055) / 1.055) ** 2.4
)
return linear
def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
"""Convert linear RGB to sRGB."""
threshold = 0.0031308
srgb = torch.where(
linear <= threshold,
linear * 12.92,
1.055 * (linear ** (1.0 / 2.4)) - 0.055
)
return srgb
def rgb_to_lab(images: torch.Tensor) -> tuple:
    """
    Convert sRGB images to CIE LAB (D65 white point) with normalized channels.

    Steps:
      1. sRGB -> linear RGB
      2. linear RGB -> XYZ (sRGB primaries, D65)
      3. XYZ -> LAB

    The LAB channels are then rescaled; this is the normalization lab_to_rgb undoes:
      - L: [0, 100] -> [0, 1] (divided by 100)
      - a, b: roughly [-128, 127] -> roughly [0, 1] (divided by 256, offset by +0.5)

    Args:
        images: sRGB images (B, 3, H, W) in [0, 1]; only channels 0..2 are read.

    Returns:
        Tuple (L, a, b), each of shape (B, 1, H, W), normalized as described above.
    """
    # Step 1: sRGB to linear RGB
    linear_rgb = _srgb_to_linear(images)

    # Step 2: linear RGB to XYZ (D65)
    r = linear_rgb[:, 0:1, :, :]
    g = linear_rgb[:, 1:2, :, :]
    b = linear_rgb[:, 2:3, :, :]
    x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
    y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
    z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b

    # Normalize by the D65 reference white
    xn, yn, zn = 0.95047, 1.0, 1.08883
    x = x / xn
    y = y / yn
    z = z / zn

    # Step 3: XYZ to LAB (cube root above (6/29)^3, linear segment below)
    delta = 6.0 / 29.0
    delta_cube = delta ** 3

    def f(t):
        return torch.where(
            t > delta_cube,
            t ** (1.0 / 3.0),
            t / (3.0 * delta ** 2) + 4.0 / 29.0
        )

    fx = f(x)
    fy = f(y)
    fz = f(z)

    L = 116.0 * fy - 16.0       # [0, 100]
    a = 500.0 * (fx - fy)       # roughly [-128, 127]
    b_ch = 200.0 * (fy - fz)    # roughly [-128, 127]

    # Normalize to convenient ranges for processing
    L = L / 100.0               # [0, 1]
    a = a / 256.0 + 0.5         # roughly [0, 1]
    b_ch = b_ch / 256.0 + 0.5   # roughly [0, 1]
    return L, a, b_ch
def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
    """
    Convert normalized CIE LAB (D65) back to sRGB.

    Inverse of rgb_to_lab's normalization: L is multiplied by 100, and a and b are
    mapped back via ``(value - 0.5) * 256``.

    Args:
        L: Luminance in [0, 1] (i.e. LAB L / 100)
        a, b_ch: Normalized chrominance, roughly [0, 1]

    Returns:
        RGB images (B, 3, H, W) in [0, 1] sRGB, clamped
    """
    # Denormalize and invert the LAB nonlinearity inputs
    fy = (L * 100.0 + 16.0) / 116.0
    fx = fy + (a - 0.5) * 256.0 / 500.0
    fz = fy - (b_ch - 0.5) * 256.0 / 200.0

    eps = 6.0 / 29.0

    def inverse_f(t):
        cubic = t ** 3
        linear_part = 3.0 * (eps ** 2) * (t - 4.0 / 29.0)
        return torch.where(t > eps, cubic, linear_part)

    # Scale by the D65 reference white
    x = 0.95047 * inverse_f(fx)
    y = 1.0 * inverse_f(fy)
    z = 1.08883 * inverse_f(fz)

    # XYZ -> linear RGB (sRGB primaries)
    red = 3.2404542 * x - 1.5371385 * y - 0.4985314 * z
    green = -0.9692660 * x + 1.8760108 * y + 0.0415560 * z
    blue = 0.0556434 * x - 0.2040259 * y + 1.0572252 * z

    # Clamp before the gamma curve: negative values would produce NaN under the fractional power.
    linear_rgb = torch.cat((red, green, blue), dim=1).clamp(0.0, 1.0)
    return _linear_to_srgb(linear_rgb).clamp(0.0, 1.0)
def compute_histogram(
    tensor: torch.Tensor,
    num_bins: int = 256,
) -> torch.Tensor:
    """
    Compute one histogram per batch item for single-channel images.

    A value v is assigned to bin ``int(v * (num_bins - 1))``, clamped to
    [0, num_bins - 1]. This is the same binning convention used for the CLAHE
    lookups in this module.

    Args:
        tensor: Input tensor (B, 1, H, W) with values in [0, 1]; need not be contiguous
        num_bins: Number of histogram bins

    Returns:
        Histograms (B, num_bins), same dtype as ``tensor``
    """
    batch = tensor.shape[0]
    # reshape (not view) so non-contiguous inputs such as transposed crops are accepted
    flat = tensor.reshape(batch, -1)  # (B, H*W)

    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)

    # Single scatter over the bin axis for the whole batch (no Python loop over B)
    histograms = torch.zeros(batch, num_bins, device=tensor.device, dtype=tensor.dtype)
    histograms.scatter_add_(1, bin_indices, torch.ones_like(flat))
    return histograms
def clahe_single_tile(
    tile: torch.Tensor,
    clip_limit: float,
    num_bins: int = 256,
) -> torch.Tensor:
    """
    Build the CLAHE lookup table (normalized clipped CDF) for a batch of tiles.

    The histogram is capped at ``clip_limit * tile_pixels / num_bins`` per bin.
    The clipped excess is spread uniformly over all bins, and the cumulative sum
    is rescaled to [0, 1].

    Args:
        tile: Input tile (B, 1, tile_h, tile_w)
        clip_limit: Histogram clip limit, relative to a uniform histogram
        num_bins: Number of histogram bins

    Returns:
        CDF lookup table (B, num_bins)
    """
    _, _, tile_h, tile_w = tile.shape
    hist = compute_histogram(tile, num_bins)  # (B, num_bins)

    ceiling = clip_limit * (tile_h * tile_w) / num_bins
    overflow = torch.clamp(hist - ceiling, min=0).sum(dim=1, keepdim=True)  # (B, 1)
    hist = torch.clamp(hist, max=ceiling) + overflow / num_bins

    cdf = torch.cumsum(hist, dim=1)
    lowest = cdf[:, 0:1]
    highest = cdf[:, -1:]
    return (cdf - lowest) / (highest - lowest + 1e-8)
def apply_clahe_vectorized(
    images: torch.Tensor,
    grid_size: int = 8,
    clip_limit: float = 2.0,
    num_bins: int = 256,
) -> torch.Tensor:
    """
    Vectorized CLAHE (Contrast Limited Adaptive Histogram Equalization).

    The image is split into a grid_size x grid_size grid of tiles. Each tile gets
    a clipped, normalized histogram CDF that serves as its intensity lookup
    table. Every pixel's output is bilinearly interpolated from the LUTs of the
    four nearest tile centers. All tiles of all batch items are processed
    together, with no Python loop over tiles or images.

    For 3-channel input, equalization is applied to the L channel from
    rgb_to_lab. The result is converted back with lab_to_rgb using the original
    a/b channels. For any other channel count, the input itself is equalized.
    The tile reshape below hard-codes a single channel, so that path effectively
    assumes C == 1.

    Args:
        images: Input images (B, C, H, W); values presumably in [0, 1] (bin
            indices are value * (num_bins - 1), clamped to the valid range)
        grid_size: Number of tiles along each spatial dimension
        clip_limit: Histogram clip limit relative to a uniform histogram (each
            bin is capped at clip_limit * tile_pixels / num_bins)
        num_bins: Number of histogram bins

    Returns:
        CLAHE-enhanced images (B, C, H, W)
    """
    B, C, H, W = images.shape
    device = images.device
    dtype = images.dtype
    # Work on luminance only: LAB L for RGB input, the raw plane otherwise
    if C == 3:
        L, a, b_ch = rgb_to_lab(images)
    else:
        L = images.clone()
        a = b_ch = None
    # Pad bottom/right so H and W become multiples of grid_size. F.pad's reflect
    # mode requires the pad to be smaller than the padded dimension, so inputs
    # smaller than grid_size along an axis may fail here.
    pad_h = (grid_size - H % grid_size) % grid_size
    pad_w = (grid_size - W % grid_size) % grid_size
    if pad_h > 0 or pad_w > 0:
        L_padded = F.pad(L, (0, pad_w, 0, pad_h), mode='reflect')
    else:
        L_padded = L
    _, _, H_pad, W_pad = L_padded.shape
    tile_h = H_pad // grid_size
    tile_w = W_pad // grid_size
    # Reshape into tiles: (B, 1, grid_size, tile_h, grid_size, tile_w),
    # then move the tile-grid axes forward and flatten to one tile per row.
    L_tiles = L_padded.view(B, 1, grid_size, tile_h, grid_size, tile_w)
    L_tiles = L_tiles.permute(0, 2, 4, 1, 3, 5)  # (B, grid_size, grid_size, 1, tile_h, tile_w)
    L_tiles = L_tiles.reshape(B * grid_size * grid_size, 1, tile_h, tile_w)
    # Histograms for all tiles at once (one row per tile)
    num_pixels = tile_h * tile_w
    flat = L_tiles.view(B * grid_size * grid_size, -1)
    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)
    histograms = torch.zeros(B * grid_size * grid_size, num_bins, device=device, dtype=dtype)
    histograms.scatter_add_(1, bin_indices, torch.ones_like(flat))
    # Clip each bin and spread the clipped excess uniformly over all bins
    clip_value = clip_limit * num_pixels / num_bins
    excess = (histograms - clip_value).clamp(min=0).sum(dim=1, keepdim=True)
    histograms = histograms.clamp(max=clip_value)
    histograms = histograms + excess / num_bins
    # Per-tile CDFs rescaled to [0, 1]; these are the lookup tables
    cdfs = histograms.cumsum(dim=1)
    cdf_min = cdfs[:, 0:1]
    cdf_max = cdfs[:, -1:]
    cdfs = (cdfs - cdf_min) / (cdf_max - cdf_min + 1e-8)
    # Reshape CDFs: (B, grid_size, grid_size, num_bins)
    cdfs = cdfs.view(B, grid_size, grid_size, num_bins)
    # Per-row / per-column pixel coordinates
    y_coords = torch.arange(H_pad, device=device, dtype=dtype)
    x_coords = torch.arange(W_pad, device=device, dtype=dtype)
    # Continuous tile coordinates where integer k is the center of tile k
    tile_y = (y_coords + 0.5) / tile_h - 0.5
    tile_x = (x_coords + 0.5) / tile_w - 0.5
    # Clamp into [0, grid_size - 1.001]: pixels outside the span of tile centers
    # are served (almost entirely) by the nearest edge tile instead of extrapolating
    tile_y = tile_y.clamp(0, grid_size - 1.001)
    tile_x = tile_x.clamp(0, grid_size - 1.001)
    # Lower/upper neighbor tile indices and fractional interpolation weights
    ty0 = tile_y.long().clamp(0, grid_size - 2)
    tx0 = tile_x.long().clamp(0, grid_size - 2)
    ty1 = (ty0 + 1).clamp(max=grid_size - 1)
    tx1 = (tx0 + 1).clamp(max=grid_size - 1)
    wy = (tile_y - ty0.float()).view(1, H_pad, 1, 1)
    wx = (tile_x - tx0.float()).view(1, 1, W_pad, 1)
    # LUT bin index for every pixel (same binning as the histograms)
    bin_idx = (L_padded * (num_bins - 1)).long().clamp(0, num_bins - 1)  # (B, 1, H_pad, W_pad)
    bin_idx = bin_idx.squeeze(1)  # (B, H_pad, W_pad)
    # Gather cdfs[b, ty, tx, bin_idx[b, y, x]] for the four neighboring tiles.
    # All index tensors are broadcast to (B, H_pad, W_pad) for advanced indexing.
    b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, H_pad, W_pad)
    ty0_exp = ty0.view(1, H_pad, 1).expand(B, H_pad, W_pad)
    ty1_exp = ty1.view(1, H_pad, 1).expand(B, H_pad, W_pad)
    tx0_exp = tx0.view(1, 1, W_pad).expand(B, H_pad, W_pad)
    tx1_exp = tx1.view(1, 1, W_pad).expand(B, H_pad, W_pad)
    v00 = cdfs[b_idx, ty0_exp, tx0_exp, bin_idx]  # (B, H_pad, W_pad)
    v01 = cdfs[b_idx, ty0_exp, tx1_exp, bin_idx]
    v10 = cdfs[b_idx, ty1_exp, tx0_exp, bin_idx]
    v11 = cdfs[b_idx, ty1_exp, tx1_exp, bin_idx]
    # Bilinear blend of the four LUT outputs
    wy = wy.squeeze(-1)  # (1, H_pad, 1)
    wx = wx.squeeze(-1)  # (1, 1, W_pad)
    L_out = (1 - wy) * (1 - wx) * v00 + (1 - wy) * wx * v01 + wy * (1 - wx) * v10 + wy * wx * v11
    L_out = L_out.unsqueeze(1)  # (B, 1, H_pad, W_pad)
    # Remove padding
    if pad_h > 0 or pad_w > 0:
        L_out = L_out[:, :, :H, :W]
    # Convert back to RGB (reusing the untouched a/b channels)
    if C == 3:
        output = lab_to_rgb(L_out, a, b_ch)
    else:
        output = L_out
    return output
# =============================================================================
# PHASE 5: Resize & Normalization
# =============================================================================
# Per-channel (R, G, B) mean and standard deviation used for the standard
# ImageNet normalization; normalize_images applies them when mode='imagenet'.
IMAGENET_MEAN = [
    0.485,
    0.456,
    0.406,
]
IMAGENET_STD = [
    0.229,
    0.224,
    0.225,
]
def resize_images(
    images: torch.Tensor,
    size: int,
    mode: str = 'bilinear',
    antialias: bool = True,
) -> torch.Tensor:
    """
    Resize a batch to a square ``size x size`` with F.interpolate.

    For 'bilinear' and 'bicubic', align_corners=False is used and ``antialias`` is
    honored. For every other mode, align_corners is left unset and antialiasing is
    disabled, since those modes do not accept these options.

    Args:
        images: Input images (B, C, H, W)
        size: Target side length
        mode: Interpolation mode
        antialias: Whether to antialias (bilinear/bicubic only)

    Returns:
        Resized images (B, C, size, size)
    """
    if mode in ('bilinear', 'bicubic'):
        corner_flag, aa_flag = False, antialias
    else:
        corner_flag, aa_flag = None, False

    return F.interpolate(
        images,
        size=(size, size),
        mode=mode,
        align_corners=corner_flag,
        antialias=aa_flag,
    )
def normalize_images(
    images: torch.Tensor,
    mean: Optional[List[float]] = None,
    std: Optional[List[float]] = None,
    mode: str = 'imagenet',
) -> torch.Tensor:
    """
    Per-channel standardization: ``(images - mean) / std``.

    Args:
        images: Input images (B, C, H, W) in [0, 1]
        mean: Per-channel mean, used only when mode='custom'
        std: Per-channel std, used only when mode='custom'
        mode: 'imagenet' (ImageNet statistics; mean/std arguments are ignored),
            'none' (input returned unchanged), or 'custom' (uses mean/std)

    Returns:
        Normalized images

    Raises:
        ValueError: for an unknown mode, or mode='custom' without both mean and std.
    """
    if mode == 'none':
        return images
    if mode not in ('imagenet', 'custom'):
        raise ValueError(f"Unknown normalization mode: {mode}")

    if mode == 'imagenet':
        mean, std = IMAGENET_MEAN, IMAGENET_STD
    elif mean is None or std is None:
        raise ValueError("Custom mode requires mean and std")

    def as_channel_tensor(values):
        # (C,) -> (1, C, 1, 1) so it broadcasts over batch and spatial axes
        return torch.tensor(values, device=images.device, dtype=images.dtype).view(1, -1, 1, 1)

    return (images - as_channel_tensor(mean)) / as_channel_tensor(std)
# =============================================================================
# PHASE 6: Hugging Face ImageProcessor Integration
# =============================================================================
class EyeCLAHEImageProcessor(BaseImageProcessor):
    """
    GPU-native image processor for Color Fundus Photography (CFP) images.

    This processor:
    1. Localizes the eye region using gradient-based radial symmetry
    2. Crops to a border-minimized square centered on the eye
    3. Applies CLAHE for contrast enhancement
    4. Resizes and normalizes for vision model input

    All operations are implemented in pure PyTorch and are CUDA-compatible.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size: int = 224,
        crop_scale_factor: float = 1.1,
        clahe_grid_size: int = 8,
        clahe_clip_limit: float = 2.0,
        normalization_mode: str = "imagenet",
        custom_mean: Optional[List[float]] = None,
        custom_std: Optional[List[float]] = None,
        do_clahe: bool = True,
        do_crop: bool = True,
        min_radius_frac: float = 0.1,
        max_radius_frac: float = 0.5,
        allow_overflow: bool = False,
        softmax_temperature: float = 0.1,
        **kwargs,
    ):
        """
        Initialize the EyeCLAHEImageProcessor.

        Args:
            size: Output image size (square)
            crop_scale_factor: Scale factor for the crop box (relative to detected radius)
            clahe_grid_size: Number of tiles per dimension for CLAHE
            clahe_clip_limit: Histogram clip limit for CLAHE
            normalization_mode: 'imagenet', 'none', or 'custom'
            custom_mean: Custom normalization mean (if mode='custom')
            custom_std: Custom normalization std (if mode='custom')
            do_clahe: Whether to apply CLAHE
            do_crop: Whether to perform eye-centered cropping
            min_radius_frac: Minimum radius as a fraction of the shorter image side
            max_radius_frac: Maximum radius as a fraction of the shorter image side
            allow_overflow: If True, allow the crop box to extend beyond image bounds
                and fill missing regions with black. Useful for pre-cropped
                images where the fundus circle is partially cut off.
            softmax_temperature: Temperature for the soft argmax in eye center detection.
                Lower values (0.01-0.1) give sharper peak detection; higher values
                (0.3-0.5) average more, for noisy images. Default: 0.1.
            **kwargs: Forwarded to BaseImageProcessor.
        """
        super().__init__(**kwargs)
        self.size = size
        self.crop_scale_factor = crop_scale_factor
        self.clahe_grid_size = clahe_grid_size
        self.clahe_clip_limit = clahe_clip_limit
        self.normalization_mode = normalization_mode
        self.custom_mean = custom_mean
        self.custom_std = custom_std
        self.do_clahe = do_clahe
        self.do_crop = do_crop
        self.min_radius_frac = min_radius_frac
        self.max_radius_frac = max_radius_frac
        self.allow_overflow = allow_overflow
        self.softmax_temperature = softmax_temperature

    def preprocess(
        self,
        images,
        return_tensors: str = "pt",
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess images for model input.

        Args:
            images: Input images in any of these formats:
                - torch.Tensor: (C,H,W), (B,C,H,W), or list of tensors
                - PIL.Image.Image: single image or list of images
                - numpy.ndarray: (H,W,C), (B,H,W,C), or list of arrays
            return_tensors: Return type (only "pt" supported)
            device: Target device for processing (e.g., "cuda", "cpu"). Defaults to
                the device of tensor input, otherwise CPU.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            BatchFeature with keys:
            - 'pixel_values': Processed images (B, C, size, size)
            - 'scale_x', 'scale_y': Scale factors for coordinate mapping (B,)
            - 'offset_x', 'offset_y': Offsets for coordinate mapping (B,)

            To map a pixel coordinate (x, y) in the processed image back to the
            original image (both in pixel-index coordinates, 0 = first pixel center):
                orig_x = offset_x + x * scale_x
                orig_y = offset_y + y * scale_y

        Raises:
            ValueError: if return_tensors is not "pt".
        """
        if return_tensors != "pt":
            raise ValueError("Only 'pt' (PyTorch) tensors are supported")

        # Determine device
        if device is not None:
            device = torch.device(device)
        elif isinstance(images, torch.Tensor):
            device = images.device
        elif isinstance(images, list) and len(images) > 0 and isinstance(images[0], torch.Tensor):
            device = images[0].device
        else:
            # PIL images and numpy arrays default to CPU
            device = torch.device('cpu')

        # Standardize input
        images = standardize_input(images, device)
        B, C, H_orig, W_orig = images.shape

        if self.do_crop:
            # Estimate eye center
            cx, cy = estimate_eye_center(images, softmax_temperature=self.softmax_temperature)

            # Estimate radius
            radius = estimate_radius(
                images, cx, cy,
                min_radius_frac=self.min_radius_frac,
                max_radius_frac=self.max_radius_frac,
            )

            # Compute crop box
            x1, y1, x2, y2 = compute_crop_box(
                cx, cy, radius, H_orig, W_orig,
                scale_factor=self.crop_scale_factor,
                allow_overflow=self.allow_overflow,
            )

            # batch_crop_and_resize samples with align_corners=True, so processed pixel i
            # lands exactly on x1 + i * (x2 - x1) / (size - 1). Guard size == 1, where
            # only pixel 0 (= x1) exists and the scale is irrelevant.
            denom = max(self.size - 1, 1)
            scale_x = (x2 - x1) / denom
            scale_y = (y2 - y1) / denom
            offset_x = x1
            offset_y = y1

            # Crop and resize.
            # Use 'zeros' padding when allow_overflow is True to fill out-of-bounds with black.
            padding_mode = 'zeros' if self.allow_overflow else 'border'
            images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)
        else:
            # Plain resize with no crop. resize_images uses F.interpolate with
            # align_corners=False, where processed pixel i samples source coordinate
            # (i + 0.5) * (in / out) - 0.5. Hence scale = in / out and
            # offset = scale / 2 - 0.5.
            step_x = W_orig / self.size
            step_y = H_orig / self.size
            scale_x = torch.full((B,), step_x, device=device, dtype=images.dtype)
            scale_y = torch.full((B,), step_y, device=device, dtype=images.dtype)
            offset_x = torch.full((B,), 0.5 * step_x - 0.5, device=device, dtype=images.dtype)
            offset_y = torch.full((B,), 0.5 * step_y - 0.5, device=device, dtype=images.dtype)
            images = resize_images(images, self.size)

        # Apply CLAHE
        if self.do_clahe:
            images = apply_clahe_vectorized(
                images,
                grid_size=self.clahe_grid_size,
                clip_limit=self.clahe_clip_limit,
            )

        # Normalize
        images = normalize_images(
            images,
            mean=self.custom_mean,
            std=self.custom_std,
            mode=self.normalization_mode,
        )

        # Return with coordinate mapping information (flattened structure)
        return BatchFeature(
            data={
                "pixel_values": images,
                "scale_x": scale_x,
                "scale_y": scale_y,
                "offset_x": offset_x,
                "offset_y": offset_y,
            },
            tensor_type="pt"
        )

    def __call__(
        self,
        images,
        **kwargs,
    ) -> BatchFeature:
        """
        Process images (alias for preprocess).

        Accepts the same inputs as preprocess: tensors, PIL images, numpy arrays,
        or lists of these.
        """
        return self.preprocess(images, **kwargs)
# Alternate public name bound to the same class object as EyeCLAHEImageProcessor.
# The original note says it exists for AutoImageProcessor registration; presumably
# the repo's processor config refers to this name — confirm before renaming or removing.
EyeGPUImageProcessor = EyeCLAHEImageProcessor