"""
GPU-Native Eye Image Processor for Color Fundus Photography (CFP) Images.

This module implements a fully PyTorch-based image processor that:

1. Localizes the eye/fundus region using gradient-based radial symmetry
2. Crops to a border-minimized square centered on the eye
3. Applies CLAHE for contrast enhancement
4. Outputs tensors compatible with Hugging Face vision models

Constraints:
- PyTorch-only processing: no OpenCV; PIL and NumPy are optional and used only
  to convert inputs into tensors
- CUDA-compatible, batch-friendly, deterministic
"""

from typing import Dict, List, Optional, Union
import math

import torch
import torch.nn.functional as F
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_processing_base import BatchFeature

try:
    from PIL import Image

    PIL_AVAILABLE = True
except ImportError:
    PIL_AVAILABLE = False

try:
    import numpy as np

    NUMPY_AVAILABLE = True
except ImportError:
    NUMPY_AVAILABLE = False


def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
    """Convert a PIL Image to a tensor (C, H, W) in [0, 1]."""
    if not PIL_AVAILABLE:
        raise ImportError("PIL is required to process PIL Images")

    if image.mode != "RGB":
        image = image.convert("RGB")

    if NUMPY_AVAILABLE:
        arr = np.array(image, dtype=np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1)
    else:
        # Pure-Python fallback: slower, but avoids the NumPy dependency.
        width, height = image.size
        pixels = list(image.getdata())
        tensor = torch.tensor(pixels, dtype=torch.float32).view(height, width, 3) / 255.0
        tensor = tensor.permute(2, 0, 1)

    return tensor


def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor:
    """Convert a numpy array to a tensor (C, H, W) in [0, 1]."""
    if not NUMPY_AVAILABLE:
        raise ImportError("NumPy is required to process numpy arrays")

    if arr.ndim == 2:
        # Grayscale (H, W) -> (H, W, 1)
        arr = arr[..., None]

    if arr.ndim == 3 and arr.shape[-1] in [1, 3, 4]:
        # Channels-last (H, W, C) -> channels-first (C, H, W)
        arr = arr.transpose(2, 0, 1)

    if arr.dtype == np.uint8:
        arr = arr.astype(np.float32) / 255.0
    elif arr.dtype != np.float32:
        arr = arr.astype(np.float32)

    return torch.from_numpy(arr.copy())


def standardize_input(
    images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Convert input images to a standardized tensor format.

    Args:
        images: Input as:
            - torch.Tensor (C,H,W), (B,C,H,W), or list of tensors
            - PIL.Image.Image or list of PIL Images
            - numpy.ndarray (H,W,C), (B,H,W,C), or list of arrays
        device: Target device (defaults to input device or CPU)

    Returns:
        Tensor of shape (B, C, H, W) in float32, range [0, 1]
    """
    # Wrap single images in a list so they share the list-handling path below.
    if PIL_AVAILABLE and isinstance(images, Image.Image):
        images = [images]
    if NUMPY_AVAILABLE and isinstance(images, np.ndarray) and images.ndim == 3:
        # A 3D array with a trailing channel dimension is a single (H, W, C) image.
        if images.shape[-1] in [1, 3, 4]:
            images = [images]

    if isinstance(images, list):
        converted = []
        for img in images:
            if PIL_AVAILABLE and isinstance(img, Image.Image):
                converted.append(_pil_to_tensor(img))
            elif NUMPY_AVAILABLE and isinstance(img, np.ndarray):
                converted.append(_numpy_to_tensor(img))
            elif isinstance(img, torch.Tensor):
                t = img if img.dim() == 3 else img.squeeze(0)
                converted.append(t)
            else:
                raise TypeError(f"Unsupported image type: {type(img)}")
        images = torch.stack(converted)
    elif NUMPY_AVAILABLE and isinstance(images, np.ndarray):
        # Batched numpy input: (B, H, W, C) -> (B, C, H, W)
        if images.ndim == 4:
            images = images.transpose(0, 3, 1, 2)
        if images.dtype == np.uint8:
            images = images.astype(np.float32) / 255.0
        images = torch.from_numpy(images.copy())

    if images.dim() == 3:
        # Single (C, H, W) tensor -> add a batch dimension.
        images = images.unsqueeze(0)

    if device is not None:
        images = images.to(device)

    if images.dtype == torch.uint8:
        images = images.float() / 255.0
    elif images.dtype != torch.float32:
        images = images.float()

    images = images.clamp(0.0, 1.0)

    return images


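# A minimal sketch of the accepted input formats (assumes PIL and NumPy are
# installed for the non-tensor cases; shapes are illustrative):
#
#   t = standardize_input(torch.rand(3, 512, 512))                # -> (1, 3, 512, 512)
#   t = standardize_input([torch.rand(3, 64, 64)] * 4)            # -> (4, 3, 64, 64)
#   t = standardize_input(np.zeros((64, 64, 3), dtype=np.uint8))  # -> (1, 3, 64, 64), float in [0, 1]

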
def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
    """
    Convert RGB images to grayscale using the luminance formula:

        Y = 0.299 * R + 0.587 * G + 0.114 * B

    Args:
        images: Tensor of shape (B, 3, H, W)

    Returns:
        Tensor of shape (B, 1, H, W)
    """
    weights = torch.tensor([0.299, 0.587, 0.114], device=images.device, dtype=images.dtype)
    weights = weights.view(1, 3, 1, 1)

    grayscale = (images * weights).sum(dim=1, keepdim=True)
    return grayscale


def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
    """
    Create Sobel kernels for gradient computation.

    Returns:
        Tuple of (sobel_x, sobel_y) kernels, each of shape (1, 1, 3, 3)
    """
    sobel_x = torch.tensor([
        [-1, 0, 1],
        [-2, 0, 2],
        [-1, 0, 1],
    ], device=device, dtype=dtype).view(1, 1, 3, 3)

    sobel_y = torch.tensor([
        [-1, -2, -1],
        [ 0,  0,  0],
        [ 1,  2,  1],
    ], device=device, dtype=dtype).view(1, 1, 3, 3)

    return sobel_x, sobel_y


def compute_gradients(grayscale: torch.Tensor) -> tuple:
    """
    Compute image gradients using Sobel filters.

    Args:
        grayscale: Tensor of shape (B, 1, H, W)

    Returns:
        Tuple of (grad_x, grad_y, grad_magnitude)
    """
    sobel_x, sobel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)

    grad_x = F.conv2d(grayscale, sobel_x, padding=1)
    grad_y = F.conv2d(grayscale, sobel_y, padding=1)

    # Small epsilon keeps the sqrt well-behaved where the gradient is zero.
    grad_magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)

    return grad_x, grad_y, grad_magnitude


def compute_radial_symmetry_response(
    grayscale: torch.Tensor,
    grad_x: torch.Tensor,
    grad_y: torch.Tensor,
    grad_magnitude: torch.Tensor,
) -> torch.Tensor:
    """
    Compute a radial symmetry response for circle detection.

    This weights regions that:
    1. Are dark (low intensity, typical of the pupil/iris)
    2. Have strong radial gradients pointing inward

    Args:
        grayscale: Grayscale image (B, 1, H, W)
        grad_x, grad_y: Gradient components
        grad_magnitude: Gradient magnitude

    Returns:
        Radial symmetry response map (B, 1, H, W)
    """
    B, _, H, W = grayscale.shape
    device = grayscale.device
    dtype = grayscale.dtype

    # Pixel coordinate grids, broadcast over the batch.
    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1).expand(B, 1, H, W)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W).expand(B, 1, H, W)

    # Darkness weight: emphasize low-intensity regions.
    dark_weight = 1.0 - grayscale
    dark_weight = dark_weight ** 2

    weight_sum = dark_weight.sum(dim=(2, 3), keepdim=True) + 1e-8

    # Initial center estimate: darkness-weighted centroid.
    cx_init = (dark_weight * x_coords).sum(dim=(2, 3), keepdim=True) / weight_sum
    cy_init = (dark_weight * y_coords).sum(dim=(2, 3), keepdim=True) / weight_sum

    # Unit vectors pointing from each pixel toward the initial center.
    dx_to_center = cx_init - x_coords
    dy_to_center = cy_init - y_coords
    dist_to_center = torch.sqrt(dx_to_center ** 2 + dy_to_center ** 2 + 1e-8)

    dx_norm = dx_to_center / dist_to_center
    dy_norm = dy_to_center / dist_to_center

    # Unit gradient vectors.
    grad_norm = grad_magnitude + 1e-8
    gx_norm = grad_x / grad_norm
    gy_norm = grad_y / grad_norm

    # Alignment between the gradient direction and the direction toward the center.
    radial_alignment = gx_norm * dx_norm + gy_norm * dy_norm

    response = radial_alignment * grad_magnitude * dark_weight

    # Smooth the response with a separable Gaussian sized relative to the image.
    kernel_size = max(H, W) // 8
    if kernel_size % 2 == 0:
        kernel_size += 1
    kernel_size = max(kernel_size, 5)

    sigma = kernel_size / 6.0

    x = torch.arange(kernel_size, device=device, dtype=dtype) - kernel_size // 2
    gaussian_1d = torch.exp(-x ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()

    gaussian_1d_h = gaussian_1d.view(1, 1, 1, kernel_size)
    gaussian_1d_v = gaussian_1d.view(1, 1, kernel_size, 1)

    pad_h = kernel_size // 2
    pad_v = kernel_size // 2

    response = F.pad(response, (pad_h, pad_h, 0, 0), mode='reflect')
    response = F.conv2d(response, gaussian_1d_h)
    response = F.pad(response, (0, 0, pad_v, pad_v), mode='reflect')
    response = F.conv2d(response, gaussian_1d_v)

    return response


def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
    """
    Compute a soft argmax to find the center coordinates.

    Args:
        response: Response map (B, 1, H, W)
        temperature: Softmax temperature (lower = sharper)

    Returns:
        Tuple of (cx, cy), each of shape (B,)
    """
    B, _, H, W = response.shape
    device = response.device
    dtype = response.dtype

    response_flat = response.view(B, -1)

    # Softmax over all spatial positions turns the response into weights.
    weights = F.softmax(response_flat / temperature, dim=1)
    weights = weights.view(B, 1, H, W)

    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1).expand(B, 1, H, W)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W).expand(B, 1, H, W)

    # Expected coordinates under the softmax weights.
    cx = (weights * x_coords).sum(dim=(2, 3)).squeeze(-1)
    cy = (weights * y_coords).sum(dim=(2, 3)).squeeze(-1)

    return cx, cy


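# A small worked example (hypothetical values): with a low temperature the
# softmax concentrates on the peak, so the expected coordinates approach the
# hard argmax location.
#
#   r = torch.zeros(1, 1, 4, 5)
#   r[0, 0, 1, 3] = 10.0
#   cx, cy = soft_argmax_2d(r, temperature=0.01)  # cx ~ 3.0, cy ~ 1.0

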
def estimate_eye_center(
    images: torch.Tensor,
    softmax_temperature: float = 0.1,
) -> tuple:
    """
    Estimate the center of the eye region in each image.

    Args:
        images: RGB images of shape (B, 3, H, W)
        softmax_temperature: Temperature for the soft argmax (lower = sharper peak
            detection, higher = more averaging). Typical range: 0.01-1.0. The default
            of 0.1 works well for most fundus images; use higher values (0.3-0.5)
            for noisy images.

    Returns:
        Tuple of (cx, cy), each of shape (B,) in pixel coordinates
    """
    grayscale = rgb_to_grayscale(images)
    grad_x, grad_y, grad_magnitude = compute_gradients(grayscale)
    response = compute_radial_symmetry_response(grayscale, grad_x, grad_y, grad_magnitude)
    cx, cy = soft_argmax_2d(response, temperature=softmax_temperature)

    return cx, cy


def estimate_radius(
    images: torch.Tensor,
    cx: torch.Tensor,
    cy: torch.Tensor,
    num_radii: int = 100,
    num_angles: int = 36,
    min_radius_frac: float = 0.1,
    max_radius_frac: float = 0.5,
) -> torch.Tensor:
    """
    Estimate the radius of the eye region by analyzing radial intensity profiles.

    Args:
        images: RGB images (B, 3, H, W)
        cx, cy: Center coordinates (B,)
        num_radii: Number of radius samples
        num_angles: Number of angular samples
        min_radius_frac: Minimum radius as a fraction of image size
        max_radius_frac: Maximum radius as a fraction of image size

    Returns:
        Estimated radius for each image (B,)
    """
    B, _, H, W = images.shape
    device = images.device
    dtype = images.dtype

    grayscale = rgb_to_grayscale(images)

    min_dim = min(H, W)
    min_radius = int(min_radius_frac * min_dim)
    max_radius = int(max_radius_frac * min_dim)

    # Candidate radii and sampling angles.
    radii = torch.linspace(min_radius, max_radius, num_radii, device=device, dtype=dtype)
    angles = torch.linspace(0, 2 * math.pi, num_angles + 1, device=device, dtype=dtype)[:-1]

    cos_angles = torch.cos(angles).view(-1, 1)
    sin_angles = torch.sin(angles).view(-1, 1)

    # (num_angles, num_radii) offsets from the center.
    dx = cos_angles * radii
    dy = sin_angles * radii

    cx_expanded = cx.view(B, 1, 1).expand(B, num_angles, num_radii)
    cy_expanded = cy.view(B, 1, 1).expand(B, num_angles, num_radii)

    sample_x = cx_expanded + dx.unsqueeze(0)
    sample_y = cy_expanded + dy.unsqueeze(0)

    # Normalize to [-1, 1] for grid_sample.
    sample_x_norm = 2.0 * sample_x / (W - 1) - 1.0
    sample_y_norm = 2.0 * sample_y / (H - 1) - 1.0

    grid = torch.stack([sample_x_norm, sample_y_norm], dim=-1)

    sampled = F.grid_sample(
        grayscale, grid, mode='bilinear', padding_mode='border', align_corners=True
    )

    # Average over angles to get a radial intensity profile per image.
    radial_profile = sampled.mean(dim=2).squeeze(1)

    # Discrete derivative along the radius.
    radial_gradient = radial_profile[:, 1:] - radial_profile[:, :-1]

    # Bias toward larger radii so the fundus boundary (a strong intensity drop)
    # is preferred over interior structures.
    radius_weights = torch.linspace(0.5, 1.5, num_radii - 1, device=device, dtype=dtype)
    weighted_gradient = radial_gradient * radius_weights.unsqueeze(0)

    # The steepest intensity drop marks the boundary.
    min_idx = weighted_gradient.argmin(dim=1)

    estimated_radius = radii[min_idx + 1]
    estimated_radius = estimated_radius.clamp(min_radius, max_radius)

    return estimated_radius


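# A minimal sketch of the localization pair (random stand-in images rather than
# real fundus photographs, so the values are only meaningful in shape):
#
#   imgs = torch.rand(2, 3, 256, 256)
#   cx, cy = estimate_eye_center(imgs)       # each (2,), in pixel coordinates
#   radius = estimate_radius(imgs, cx, cy)   # (2,), clamped to [0.1, 0.5] * min(H, W)

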
def compute_crop_box(
    cx: torch.Tensor,
    cy: torch.Tensor,
    radius: torch.Tensor,
    H: int,
    W: int,
    scale_factor: float = 1.1,
    allow_overflow: bool = False,
) -> tuple:
    """
    Compute a square bounding box for cropping.

    Args:
        cx, cy: Center coordinates (B,)
        radius: Estimated radius (B,)
        H, W: Image dimensions
        scale_factor: Multiply the radius by this factor for padding
        allow_overflow: If True, don't clamp the box to the image bounds
            (useful for pre-cropped images)

    Returns:
        Tuple of (x1, y1, x2, y2), each of shape (B,)
    """
    half_side = radius * scale_factor

    x1 = cx - half_side
    y1 = cy - half_side
    x2 = cx + half_side
    y2 = cy + half_side

    if allow_overflow:
        # Keep the box square and centered even if it extends past the image;
        # out-of-bounds regions are handled later by grid_sample padding.
        return x1, y1, x2, y2

    # Clamp the box to the image bounds.
    x1 = x1.clamp(min=0)
    y1 = y1.clamp(min=0)
    x2 = x2.clamp(max=W - 1)
    y2 = y2.clamp(max=H - 1)

    # Clamping may have made the box non-square; shrink to the smaller side.
    side_x = x2 - x1
    side_y = y2 - y1
    side = torch.minimum(side_x, side_y)

    # Re-center the square box and clamp once more.
    cx_new = (x1 + x2) / 2
    cy_new = (y1 + y2) / 2

    x1 = (cx_new - side / 2).clamp(min=0)
    y1 = (cy_new - side / 2).clamp(min=0)
    x2 = x1 + side
    y2 = y1 + side

    x2 = x2.clamp(max=W - 1)
    y2 = y2.clamp(max=H - 1)

    return x1, y1, x2, y2


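# A worked example (hypothetical values): with the default scale_factor=1.1 the
# half-side is 1.1 * 90 = 99 px, and the resulting box already fits inside a
# 256x256 image, so no clamping or re-squaring is needed.
#
#   cx, cy = torch.tensor([128.0]), torch.tensor([120.0])
#   radius = torch.tensor([90.0])
#   x1, y1, x2, y2 = compute_crop_box(cx, cy, radius, H=256, W=256)
#   # x1=29, y1=21, x2=227, y2=219

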
def batch_crop_and_resize(
    images: torch.Tensor,
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = 'border',
) -> torch.Tensor:
    """
    Crop and resize images using grid_sample for GPU efficiency.

    Args:
        images: Input images (B, C, H, W)
        x1, y1, x2, y2: Crop coordinates (B,) - may extend beyond the image bounds
        output_size: Output square size
        padding_mode: How to handle out-of-bounds sampling:
            - 'border': repeat edge pixels (default)
            - 'zeros': fill with black (useful for pre-cropped images)

    Returns:
        Cropped and resized images (B, C, output_size, output_size)
    """
    B, C, H, W = images.shape
    device = images.device
    dtype = images.dtype

    # Normalized output grid in [0, 1].
    out_coords = torch.linspace(0, 1, output_size, device=device, dtype=dtype)
    out_y, out_x = torch.meshgrid(out_coords, out_coords, indexing='ij')
    out_grid = torch.stack([out_x, out_y], dim=-1)
    out_grid = out_grid.unsqueeze(0).expand(B, -1, -1, -1)

    # Per-image crop boxes, broadcast over the output grid.
    x1 = x1.view(B, 1, 1, 1)
    y1 = y1.view(B, 1, 1, 1)
    x2 = x2.view(B, 1, 1, 1)
    y2 = y2.view(B, 1, 1, 1)

    # Map output grid coordinates into source pixel coordinates.
    sample_x = x1 + out_grid[..., 0:1] * (x2 - x1)
    sample_y = y1 + out_grid[..., 1:2] * (y2 - y1)

    # Normalize to [-1, 1] for grid_sample.
    sample_x_norm = 2.0 * sample_x / (W - 1) - 1.0
    sample_y_norm = 2.0 * sample_y / (H - 1) - 1.0

    grid = torch.cat([sample_x_norm, sample_y_norm], dim=-1)

    cropped = F.grid_sample(
        images, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True
    )

    return cropped


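# A minimal sketch: boxes from compute_crop_box feed straight into
# batch_crop_and_resize, which resamples every crop to the same square size.
#
#   imgs = torch.rand(2, 3, 512, 512)
#   x1 = torch.tensor([10.0, 40.0]);  y1 = torch.tensor([10.0, 40.0])
#   x2 = torch.tensor([400.0, 460.0]); y2 = torch.tensor([400.0, 460.0])
#   crops = batch_crop_and_resize(imgs, x1, y1, x2, y2, output_size=224)
#   # crops.shape == (2, 3, 224, 224)

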
def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
    """Convert sRGB to linear RGB."""
    threshold = 0.04045
    linear = torch.where(
        rgb <= threshold,
        rgb / 12.92,
        ((rgb + 0.055) / 1.055) ** 2.4
    )
    return linear


def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
    """Convert linear RGB to sRGB."""
    threshold = 0.0031308
    srgb = torch.where(
        linear <= threshold,
        linear * 12.92,
        1.055 * (linear ** (1.0 / 2.4)) - 0.055
    )
    return srgb


def rgb_to_lab(images: torch.Tensor) -> tuple:
    """
    Convert sRGB images to the CIE LAB color space.

    This is a proper LAB conversion that:
    1. Converts sRGB to linear RGB
    2. Converts linear RGB to XYZ
    3. Converts XYZ to LAB

    Args:
        images: RGB images (B, C, H, W) in [0, 1] sRGB

    Returns:
        Tuple of (L, a, b) where:
        - L: Luminance in [0, 1] (normalized from [0, 100])
        - a, b: Chrominance, normalized to roughly [0, 1] (centered at 0.5)
    """
    linear_rgb = _srgb_to_linear(images)

    # Linear RGB -> XYZ.
    r = linear_rgb[:, 0:1, :, :]
    g = linear_rgb[:, 1:2, :, :]
    b = linear_rgb[:, 2:3, :, :]

    x = 0.4124564 * r + 0.3575761 * g + 0.1804375 * b
    y = 0.2126729 * r + 0.7151522 * g + 0.0721750 * b
    z = 0.0193339 * r + 0.1191920 * g + 0.9503041 * b

    # Normalize by the D65 white point.
    xn, yn, zn = 0.95047, 1.0, 1.08883

    x = x / xn
    y = y / yn
    z = z / zn

    # XYZ -> LAB.
    delta = 6.0 / 29.0
    delta_cube = delta ** 3

    def f(t):
        return torch.where(
            t > delta_cube,
            t ** (1.0 / 3.0),
            t / (3.0 * delta ** 2) + 4.0 / 29.0
        )

    fx = f(x)
    fy = f(y)
    fz = f(z)

    L = 116.0 * fy - 16.0
    a = 500.0 * (fx - fy)
    b_ch = 200.0 * (fy - fz)

    # Normalize: L to [0, 1]; a and b to roughly [0, 1] centered at 0.5.
    L = L / 100.0
    a = a / 256.0 + 0.5
    b_ch = b_ch / 256.0 + 0.5

    return L, a, b_ch


def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
    """
    Convert CIE LAB to sRGB.

    Args:
        L: Luminance in [0, 1] (normalized from [0, 100])
        a, b_ch: Chrominance (normalized, roughly [0, 1])

    Returns:
        RGB images (B, 3, H, W) in [0, 1] sRGB
    """
    # Undo the normalization applied in rgb_to_lab.
    L_lab = L * 100.0
    a_lab = (a - 0.5) * 256.0
    b_lab = (b_ch - 0.5) * 256.0

    # LAB -> XYZ.
    fy = (L_lab + 16.0) / 116.0
    fx = a_lab / 500.0 + fy
    fz = fy - b_lab / 200.0

    delta = 6.0 / 29.0

    def f_inv(t):
        return torch.where(
            t > delta,
            t ** 3,
            3.0 * (delta ** 2) * (t - 4.0 / 29.0)
        )

    # D65 white point.
    xn, yn, zn = 0.95047, 1.0, 1.08883

    x = xn * f_inv(fx)
    y = yn * f_inv(fy)
    z = zn * f_inv(fz)

    # XYZ -> linear RGB.
    r = 3.2404542 * x - 1.5371385 * y - 0.4985314 * z
    g = -0.9692660 * x + 1.8760108 * y + 0.0415560 * z
    b = 0.0556434 * x - 0.2040259 * y + 1.0572252 * z

    linear_rgb = torch.cat([r, g, b], dim=1)
    linear_rgb = linear_rgb.clamp(0.0, 1.0)

    srgb = _linear_to_srgb(linear_rgb)

    return srgb.clamp(0.0, 1.0)


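# Round-trip sanity check (a sketch; a small tolerance absorbs the float error
# from the sRGB gamma curve and the clamping):
#
#   rgb = torch.rand(1, 3, 8, 8)
#   L, a, b_ch = rgb_to_lab(rgb)
#   restored = lab_to_rgb(L, a, b_ch)
#   assert torch.allclose(rgb, restored, atol=1e-2)

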
def compute_histogram(
    tensor: torch.Tensor,
    num_bins: int = 256,
) -> torch.Tensor:
    """
    Compute histograms for a batch of single-channel images.

    Args:
        tensor: Input tensor (B, 1, H, W) with values in [0, 1]
        num_bins: Number of histogram bins

    Returns:
        Histograms (B, num_bins)
    """
    B = tensor.shape[0]
    device = tensor.device
    dtype = tensor.dtype

    flat = tensor.view(B, -1)

    # Map [0, 1] values to integer bin indices.
    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)

    histograms = torch.zeros(B, num_bins, device=device, dtype=dtype)
    ones = torch.ones_like(flat, dtype=dtype)

    for i in range(B):
        histograms[i] = histograms[i].scatter_add(0, bin_indices[i], ones[i])

    return histograms


def clahe_single_tile(
    tile: torch.Tensor,
    clip_limit: float,
    num_bins: int = 256,
) -> torch.Tensor:
    """
    Apply CLAHE to a single tile.

    Args:
        tile: Input tile (B, 1, tile_h, tile_w)
        clip_limit: Histogram clip limit
        num_bins: Number of histogram bins

    Returns:
        CDF lookup table (B, num_bins)
    """
    _, _, tile_h, tile_w = tile.shape
    num_pixels = tile_h * tile_w

    hist = compute_histogram(tile, num_bins)

    # Clip the histogram and redistribute the excess uniformly across bins.
    clip_value = clip_limit * num_pixels / num_bins
    excess = (hist - clip_value).clamp(min=0).sum(dim=1, keepdim=True)
    hist = hist.clamp(max=clip_value)

    redistribution = excess / num_bins
    hist = hist + redistribution

    cdf = hist.cumsum(dim=1)

    # Normalize the CDF to [0, 1].
    cdf_min = cdf[:, 0:1]
    cdf_max = cdf[:, -1:]
    cdf = (cdf - cdf_min) / (cdf_max - cdf_min + 1e-8)

    return cdf


def apply_clahe_vectorized(
    images: torch.Tensor,
    grid_size: int = 8,
    clip_limit: float = 2.0,
    num_bins: int = 256,
) -> torch.Tensor:
    """
    Vectorized CLAHE implementation (efficient on GPU).

    Args:
        images: Input images (B, C, H, W)
        grid_size: Number of tiles in each dimension
        clip_limit: Histogram clip limit
        num_bins: Number of histogram bins

    Returns:
        CLAHE-enhanced images (B, C, H, W)
    """
    B, C, H, W = images.shape
    device = images.device
    dtype = images.dtype

    # Work on the luminance channel; keep chrominance untouched for RGB input.
    if C == 3:
        L, a, b_ch = rgb_to_lab(images)
    else:
        L = images.clone()
        a = b_ch = None

    # Pad so H and W are divisible by the tile grid.
    pad_h = (grid_size - H % grid_size) % grid_size
    pad_w = (grid_size - W % grid_size) % grid_size

    if pad_h > 0 or pad_w > 0:
        L_padded = F.pad(L, (0, pad_w, 0, pad_h), mode='reflect')
    else:
        L_padded = L

    _, _, H_pad, W_pad = L_padded.shape
    tile_h = H_pad // grid_size
    tile_w = W_pad // grid_size

    # Split into tiles: (B * grid_size * grid_size, 1, tile_h, tile_w).
    L_tiles = L_padded.view(B, 1, grid_size, tile_h, grid_size, tile_w)
    L_tiles = L_tiles.permute(0, 2, 4, 1, 3, 5)
    L_tiles = L_tiles.reshape(B * grid_size * grid_size, 1, tile_h, tile_w)

    # Per-tile histograms via scatter_add.
    num_pixels = tile_h * tile_w
    flat = L_tiles.view(B * grid_size * grid_size, -1)
    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)

    histograms = torch.zeros(B * grid_size * grid_size, num_bins, device=device, dtype=dtype)
    histograms.scatter_add_(1, bin_indices, torch.ones_like(flat))

    # Clip histograms and redistribute the excess uniformly.
    clip_value = clip_limit * num_pixels / num_bins
    excess = (histograms - clip_value).clamp(min=0).sum(dim=1, keepdim=True)
    histograms = histograms.clamp(max=clip_value)
    histograms = histograms + excess / num_bins

    # Per-tile CDFs, normalized to [0, 1].
    cdfs = histograms.cumsum(dim=1)
    cdf_min = cdfs[:, 0:1]
    cdf_max = cdfs[:, -1:]
    cdfs = (cdfs - cdf_min) / (cdf_max - cdf_min + 1e-8)

    cdfs = cdfs.view(B, grid_size, grid_size, num_bins)

    # Bilinear interpolation between neighboring tile CDFs.
    y_coords = torch.arange(H_pad, device=device, dtype=dtype)
    x_coords = torch.arange(W_pad, device=device, dtype=dtype)

    # Fractional tile coordinates of each pixel (tile centers at integer positions).
    tile_y = (y_coords + 0.5) / tile_h - 0.5
    tile_x = (x_coords + 0.5) / tile_w - 0.5

    tile_y = tile_y.clamp(0, grid_size - 1.001)
    tile_x = tile_x.clamp(0, grid_size - 1.001)

    # Indices of the four surrounding tiles.
    ty0 = tile_y.long().clamp(0, grid_size - 2)
    tx0 = tile_x.long().clamp(0, grid_size - 2)
    ty1 = (ty0 + 1).clamp(max=grid_size - 1)
    tx1 = (tx0 + 1).clamp(max=grid_size - 1)

    wy = (tile_y - ty0.float()).view(1, H_pad, 1, 1)
    wx = (tile_x - tx0.float()).view(1, 1, W_pad, 1)

    # Bin index of each pixel's luminance.
    bin_idx = (L_padded * (num_bins - 1)).long().clamp(0, num_bins - 1)
    bin_idx = bin_idx.squeeze(1)

    # Gather CDF values from the four neighboring tiles for every pixel.
    b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, H_pad, W_pad)
    ty0_exp = ty0.view(1, H_pad, 1).expand(B, H_pad, W_pad)
    ty1_exp = ty1.view(1, H_pad, 1).expand(B, H_pad, W_pad)
    tx0_exp = tx0.view(1, 1, W_pad).expand(B, H_pad, W_pad)
    tx1_exp = tx1.view(1, 1, W_pad).expand(B, H_pad, W_pad)

    v00 = cdfs[b_idx, ty0_exp, tx0_exp, bin_idx]
    v01 = cdfs[b_idx, ty0_exp, tx1_exp, bin_idx]
    v10 = cdfs[b_idx, ty1_exp, tx0_exp, bin_idx]
    v11 = cdfs[b_idx, ty1_exp, tx1_exp, bin_idx]

    # Bilinear blend of the four mapped values.
    wy = wy.squeeze(-1)
    wx = wx.squeeze(-1)

    L_out = (1 - wy) * (1 - wx) * v00 + (1 - wy) * wx * v01 + wy * (1 - wx) * v10 + wy * wx * v11
    L_out = L_out.unsqueeze(1)

    # Remove padding.
    if pad_h > 0 or pad_w > 0:
        L_out = L_out[:, :, :H, :W]

    # Recombine with the original chrominance.
    if C == 3:
        output = lab_to_rgb(L_out, a, b_ch)
    else:
        output = L_out

    return output


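# A minimal sketch: CLAHE runs on the luminance channel only, so RGB output
# keeps its chrominance and stays in [0, 1].
#
#   enhanced = apply_clahe_vectorized(torch.rand(2, 3, 224, 224),
#                                     grid_size=8, clip_limit=2.0)
#   # enhanced.shape == (2, 3, 224, 224)

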
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def resize_images(
    images: torch.Tensor,
    size: int,
    mode: str = 'bilinear',
    antialias: bool = True,
) -> torch.Tensor:
    """
    Resize images to a target size.

    Args:
        images: Input images (B, C, H, W)
        size: Target size (square)
        mode: Interpolation mode
        antialias: Whether to use antialiasing

    Returns:
        Resized images (B, C, size, size)
    """
    return F.interpolate(
        images,
        size=(size, size),
        mode=mode,
        align_corners=False if mode in ['bilinear', 'bicubic'] else None,
        antialias=antialias if mode in ['bilinear', 'bicubic'] else False,
    )


def normalize_images(
    images: torch.Tensor,
    mean: Optional[List[float]] = None,
    std: Optional[List[float]] = None,
    mode: str = 'imagenet',
) -> torch.Tensor:
    """
    Normalize images.

    Args:
        images: Input images (B, C, H, W) in [0, 1]
        mean: Custom mean (per channel)
        std: Custom std (per channel)
        mode: 'imagenet', 'none', or 'custom'

    Returns:
        Normalized images
    """
    if mode == 'none':
        return images

    if mode == 'imagenet':
        mean = IMAGENET_MEAN
        std = IMAGENET_STD
    elif mode == 'custom':
        if mean is None or std is None:
            raise ValueError("Custom mode requires mean and std")
    else:
        raise ValueError(f"Unknown normalization mode: {mode}")

    device = images.device
    dtype = images.dtype

    mean_tensor = torch.tensor(mean, device=device, dtype=dtype).view(1, -1, 1, 1)
    std_tensor = torch.tensor(std, device=device, dtype=dtype).view(1, -1, 1, 1)

    return (images - mean_tensor) / std_tensor


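# A minimal sketch of the three normalization modes:
#
#   x = normalize_images(torch.rand(1, 3, 224, 224))               # ImageNet stats
#   x = normalize_images(torch.rand(1, 3, 224, 224), mode='none')  # pass-through
#   x = normalize_images(torch.rand(1, 3, 224, 224),
#                        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], mode='custom')

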
class EyeCLAHEImageProcessor(BaseImageProcessor):
    """
    GPU-native image processor for Color Fundus Photography (CFP) images.

    This processor:
    1. Localizes the eye region using gradient-based radial symmetry
    2. Crops to a border-minimized square centered on the eye
    3. Applies CLAHE for contrast enhancement
    4. Resizes and normalizes for vision model input

    All operations are implemented in pure PyTorch and are CUDA-compatible.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size: int = 224,
        crop_scale_factor: float = 1.1,
        clahe_grid_size: int = 8,
        clahe_clip_limit: float = 2.0,
        normalization_mode: str = "imagenet",
        custom_mean: Optional[List[float]] = None,
        custom_std: Optional[List[float]] = None,
        do_clahe: bool = True,
        do_crop: bool = True,
        min_radius_frac: float = 0.1,
        max_radius_frac: float = 0.5,
        allow_overflow: bool = False,
        softmax_temperature: float = 0.1,
        **kwargs,
    ):
        """
        Initialize the EyeCLAHEImageProcessor.

        Args:
            size: Output image size (square)
            crop_scale_factor: Scale factor for the crop box (relative to the detected radius)
            clahe_grid_size: Number of tiles for CLAHE
            clahe_clip_limit: Histogram clip limit for CLAHE
            normalization_mode: 'imagenet', 'none', or 'custom'
            custom_mean: Custom normalization mean (if mode='custom')
            custom_std: Custom normalization std (if mode='custom')
            do_clahe: Whether to apply CLAHE
            do_crop: Whether to perform eye-centered cropping
            min_radius_frac: Minimum radius as a fraction of image size
            max_radius_frac: Maximum radius as a fraction of image size
            allow_overflow: If True, allow the crop box to extend beyond the image
                bounds and fill missing regions with black. Useful for pre-cropped
                images where the fundus circle is partially cut off.
            softmax_temperature: Temperature for the soft argmax in eye center
                detection. Lower values (0.01-0.1) give sharper peak detection;
                higher values (0.3-0.5) provide more averaging for noisy images.
                Default: 0.1.
        """
        super().__init__(**kwargs)

        self.size = size
        self.crop_scale_factor = crop_scale_factor
        self.clahe_grid_size = clahe_grid_size
        self.clahe_clip_limit = clahe_clip_limit
        self.normalization_mode = normalization_mode
        self.custom_mean = custom_mean
        self.custom_std = custom_std
        self.do_clahe = do_clahe
        self.do_crop = do_crop
        self.min_radius_frac = min_radius_frac
        self.max_radius_frac = max_radius_frac
        self.allow_overflow = allow_overflow
        self.softmax_temperature = softmax_temperature

    def preprocess(
        self,
        images,
        return_tensors: str = "pt",
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess images for model input.

        Args:
            images: Input images in any of these formats:
                - torch.Tensor: (C,H,W), (B,C,H,W), or list of tensors
                - PIL.Image.Image: single image or list of images
                - numpy.ndarray: (H,W,C), (B,H,W,C), or list of arrays
            return_tensors: Return type (only "pt" is supported)
            device: Target device for processing (e.g., "cuda", "cpu")

        Returns:
            BatchFeature with keys:
            - 'pixel_values': Processed images (B, C, size, size)
            - 'scale_x', 'scale_y': Scale factors for coordinate mapping (B,)
            - 'offset_x', 'offset_y': Offsets for coordinate mapping (B,)

            To map coordinates from the processed image back to the original:
                orig_x = offset_x + cropped_x * scale_x
                orig_y = offset_y + cropped_y * scale_y
        """
        if return_tensors != "pt":
            raise ValueError("Only 'pt' (PyTorch) tensors are supported")

        # Resolve the target device: an explicit argument wins, then the input's device.
        if device is not None:
            device = torch.device(device)
        elif isinstance(images, torch.Tensor):
            device = images.device
        elif isinstance(images, list) and len(images) > 0 and isinstance(images[0], torch.Tensor):
            device = images[0].device
        else:
            device = torch.device('cpu')

        images = standardize_input(images, device)
        B, C, H_orig, W_orig = images.shape

        if self.do_crop:
            # Localize the eye center and estimate the fundus radius.
            cx, cy = estimate_eye_center(images, softmax_temperature=self.softmax_temperature)

            radius = estimate_radius(
                images, cx, cy,
                min_radius_frac=self.min_radius_frac,
                max_radius_frac=self.max_radius_frac,
            )

            x1, y1, x2, y2 = compute_crop_box(
                cx, cy, radius, H_orig, W_orig,
                scale_factor=self.crop_scale_factor,
                allow_overflow=self.allow_overflow,
            )

            # Record the affine mapping from processed-image coordinates back to
            # original-image coordinates.
            scale_x = (x2 - x1) / (self.size - 1)
            scale_y = (y2 - y1) / (self.size - 1)
            offset_x = x1
            offset_y = y1

            padding_mode = 'zeros' if self.allow_overflow else 'border'
            images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)
        else:
            # No cropping: the mapping is a plain resize.
            scale_x = torch.full((B,), (W_orig - 1) / (self.size - 1), device=device, dtype=images.dtype)
            scale_y = torch.full((B,), (H_orig - 1) / (self.size - 1), device=device, dtype=images.dtype)
            offset_x = torch.zeros(B, device=device, dtype=images.dtype)
            offset_y = torch.zeros(B, device=device, dtype=images.dtype)
            images = resize_images(images, self.size)

        if self.do_clahe:
            images = apply_clahe_vectorized(
                images,
                grid_size=self.clahe_grid_size,
                clip_limit=self.clahe_clip_limit,
            )

        images = normalize_images(
            images,
            mean=self.custom_mean,
            std=self.custom_std,
            mode=self.normalization_mode,
        )

        return BatchFeature(
            data={
                "pixel_values": images,
                "scale_x": scale_x,
                "scale_y": scale_y,
                "offset_x": offset_x,
                "offset_y": offset_y,
            },
            tensor_type="pt"
        )

    def __call__(
        self,
        images: Union[torch.Tensor, List[torch.Tensor]],
        **kwargs,
    ) -> BatchFeature:
        """
        Process images (alias for preprocess).
        """
        return self.preprocess(images, **kwargs)


EyeGPUImageProcessor = EyeCLAHEImageProcessor
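

# A minimal end-to-end sketch. Assumptions: the input is a random stand-in
# batch rather than real fundus photographs, and CUDA may be unavailable, so
# the demo falls back to CPU.
if __name__ == "__main__":
    demo_device = "cuda" if torch.cuda.is_available() else "cpu"

    processor = EyeCLAHEImageProcessor(size=224)
    dummy = torch.rand(2, 3, 512, 512)  # (B, C, H, W) in [0, 1]

    batch = processor(dummy, device=demo_device)
    print(batch["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])

    # Map the processed-image center back to original coordinates using the
    # returned affine parameters.
    x_proc = y_proc = (processor.size - 1) / 2
    x_orig = batch["offset_x"] + x_proc * batch["scale_x"]
    y_orig = batch["offset_y"] + y_proc * batch["scale_y"]
    print(x_orig, y_orig)  # (2,) tensors in original pixel coordinates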