| | """Image processor for MVANet model.""" |
| |
|
| | from typing import Dict, List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from PIL import Image |
| | from transformers import BaseImageProcessor |
| | from transformers.image_processing_utils import BatchFeature |
| | from transformers.image_utils import ( |
| | ImageInput, |
| | PILImageResampling, |
| | ) |
| | from transformers.utils import TensorType |
| |
|
| |
|
def to_pil_image(image: Union[np.ndarray, torch.Tensor, Image.Image]) -> Image.Image:
    """Convert a PIL image, NumPy array or PyTorch tensor to a PIL image.

    Accepted layouts:
      * 2-D ``(H, W)`` -> grayscale (``"L"``).
      * 3-D with 1, 3 or 4 channels, either channels-last ``(H, W, C)`` or
        channels-first ``(C, H, W)``. Tensors are tested for channels-first
        before channels-last; NumPy arrays are tested for channels-last first.
        Single-channel images become ``"L"``, 3-channel ``"RGB"`` and
        4-channel ``"RGBA"``.

    Value handling:
      * Floating-point data is assumed to be in ``[0, 1]`` and is scaled to
        ``[0, 255]`` (then clipped).
      * Boolean data is mapped to ``{0, 255}``.
      * Other integer data is clipped to ``[0, 255]`` (uint8 is used as-is).

    Raises:
        ValueError: If the input type or shape is not supported.
    """
    if isinstance(image, Image.Image):
        return image

    if isinstance(image, torch.Tensor):
        # Detach/move to host so .numpy() works for autograd and GPU tensors.
        tensor = image.detach().cpu()
        # NumPy has no bfloat16; upcasting floating tensors keeps .numpy() valid.
        if tensor.is_floating_point():
            tensor = tensor.float()
        image = tensor.numpy()
        # Tensors are conventionally channels-first (C, H, W).
        if image.ndim == 3 and image.shape[0] in (1, 3, 4):
            image = image.transpose(1, 2, 0)

    if not isinstance(image, np.ndarray):
        raise ValueError(f"Unsupported image type: {type(image)}")

    # Channels-first arrays: move the channel axis to the end.
    if (
        image.ndim == 3
        and image.shape[2] not in (1, 3, 4)
        and image.shape[0] in (1, 3, 4)
    ):
        image = image.transpose(1, 2, 0)

    # A single trailing channel is plain grayscale.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]

    if image.ndim not in (2, 3) or (image.ndim == 3 and image.shape[2] not in (3, 4)):
        raise ValueError(f"Unsupported image shape: {image.shape}")

    # Bring the data into uint8 so Pillow can infer the mode (L / RGB / RGBA)
    # from the array shape; non-uint8 buffers would otherwise be misread.
    if image.dtype == np.bool_:
        image = image.astype(np.uint8) * 255
    elif np.issubdtype(image.dtype, np.floating):
        image = (image * 255).clip(0, 255).astype(np.uint8)
    elif image.dtype != np.uint8:
        image = image.clip(0, 255).astype(np.uint8)

    return Image.fromarray(np.ascontiguousarray(image))
| |
|
| |
|
class MVANetImageProcessor(BaseImageProcessor):
    """
    Constructs a MVANet image processor.

    Args:
        do_resize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to resize the image.
        size (:obj:`Dict[str, int]`, `optional`, defaults to :obj:`{"height": 1024, "width": 1024}`):
            Target size for resizing, as a dict with ``"height"`` and ``"width"`` keys.
            MVANet was trained on 1024x1024 images.
        resample (:obj:`PILImageResampling`, `optional`, defaults to :obj:`PILImageResampling.BILINEAR`):
            Resampling filter to use when resizing the image.
        do_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to normalize the image. Normalization first rescales pixel values
            from [0, 255] to [0, 1] and then applies mean/std, so with
            ``do_normalize=False`` pixel values stay in [0, 255].
        image_mean (:obj:`List[float]`, `optional`, defaults to :obj:`[0.485, 0.456, 0.406]`):
            Mean to use for normalization (ImageNet mean).
        image_std (:obj:`List[float]`, `optional`, defaults to :obj:`[0.229, 0.224, 0.225]`):
            Standard deviation to use for normalization (ImageNet std).
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_normalize: bool = True,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 1024, "width": 1024}
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_normalize = do_normalize
        self.image_mean = (
            image_mean if image_mean is not None else [0.485, 0.456, 0.406]
        )
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]

    def resize(
        self,
        image: Image.Image,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
    ) -> Image.Image:
        """Resize a PIL image to ``(size["width"], size["height"])`` with ``resample``."""
        target_height = size["height"]
        target_width = size["width"]
        # PIL expects (width, height).
        return image.resize((target_width, target_height), resample)

    def normalize(
        self,
        image: np.ndarray,
        mean: List[float],
        std: List[float],
    ) -> np.ndarray:
        """Rescale to [0, 1] (divide by 255) and standardize with per-channel mean/std.

        ``mean`` and ``std`` broadcast along the last axis, so ``image`` is
        expected to be channels-last ``(H, W, C)`` with ``C == len(mean)``.
        Returns a float32 array.
        """
        image = image.astype(np.float32) / 255.0
        mean = np.array(mean, dtype=np.float32)
        std = np.array(std, dtype=np.float32)
        image = (image - mean) / std
        return image

    @staticmethod
    def _make_list_of_images(images: ImageInput) -> List:
        """Turn ``images`` into a flat list of single images.

        Lists and tuples are used element-wise, 4-D arrays/tensors are split
        along their first (batch) axis, and anything else is one image.
        """
        if isinstance(images, (list, tuple)):
            return list(images)
        if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
            return list(images)
        return [images]

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess images for MVANet.

        Args:
            images (:obj:`ImageInput`):
                Images to preprocess: a single image, a list/tuple of images, or a
                4-D batched NumPy array / PyTorch tensor.
            do_resize (:obj:`bool`, `optional`):
                Whether to resize the image(s). Defaults to :obj:`self.do_resize`.
            size (:obj:`Dict[str, int]`, `optional`):
                Target size for resizing. Defaults to :obj:`self.size`.
            resample (:obj:`PILImageResampling`, `optional`):
                Resampling filter to use. Defaults to :obj:`self.resample`.
            do_normalize (:obj:`bool`, `optional`):
                Whether to rescale to [0, 1] and normalize the image(s).
                Defaults to :obj:`self.do_normalize`.
            image_mean (:obj:`List[float]`, `optional`):
                Mean for normalization. Defaults to :obj:`self.image_mean`.
            image_std (:obj:`List[float]`, `optional`):
                Std for normalization. Defaults to :obj:`self.image_std`.
            return_tensors (:obj:`str` or :obj:`TensorType`, `optional`):
                Type of tensors to return. With 'pt', ``pixel_values`` is a float32
                PyTorch tensor; otherwise a stacked NumPy array is handed to
                :obj:`BatchFeature` with this value as ``tensor_type``.

        Returns:
            :obj:`BatchFeature`: A :obj:`BatchFeature` with the following fields:
                - pixel_values: Preprocessed images of shape ``(N, 3, H, W)``.

        Raises:
            ValueError: If ``images`` contains no images.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        image_list = self._make_list_of_images(images)
        if not image_list:
            raise ValueError("`images` must contain at least one image.")

        pil_images = []
        for img in image_list:
            pil_img = to_pil_image(img)
            # Force 3-channel RGB (drops alpha, expands grayscale).
            if pil_img.mode != "RGB":
                pil_img = pil_img.convert("RGB")
            pil_images.append(pil_img)

        if do_resize:
            pil_images = [self.resize(img, size, resample) for img in pil_images]

        # (H, W, 3) uint8 arrays.
        np_images = [np.array(img) for img in pil_images]

        if do_normalize:
            np_images = [
                self.normalize(img, image_mean, image_std) for img in np_images
            ]

        # HWC -> CHW, then stack into (N, 3, H, W). All images must share a size
        # here, which is guaranteed when do_resize is True.
        batch = np.stack([img.transpose(2, 0, 1) for img in np_images])

        if return_tensors == "pt":
            # Always float32 for PyTorch, even when normalization was skipped.
            pixel_values = torch.tensor(batch, dtype=torch.float32)
        else:
            pixel_values = batch

        data = {"pixel_values": pixel_values}
        return BatchFeature(data=data, tensor_type=return_tensors)

    def post_process_semantic_segmentation(
        self,
        outputs,
        target_sizes: Optional[List[Tuple[int, int]]] = None,
    ) -> List[torch.Tensor]:
        """
        Post-process model outputs to semantic segmentation masks.

        Args:
            outputs (:obj:`SemanticSegmenterOutput` or :obj:`torch.Tensor`):
                Model outputs containing logits (an object with a ``logits``
                attribute, or the logits tensor itself). Logits of shape
                ``(B, C, H, W)`` are used as-is; 3-D logits are treated as
                ``(B, H, W)`` single-channel maps.
            target_sizes (:obj:`List[Tuple[int, int]]`, `optional`):
                One ``(width, height)`` pair per image — note the order is width
                first, the same as PIL's ``Image.size``. If not provided, masks
                are returned at the model output size.

        Returns:
            :obj:`List[torch.Tensor]`: One sigmoid probability mask per image
            (values in [0, 1]); ``(H, W)`` for single-channel logits.

        Raises:
            ValueError: If ``target_sizes`` does not have one entry per image.
        """
        logits = outputs.logits if hasattr(outputs, "logits") else outputs

        # Give 3-D logits an explicit channel axis so interpolate sees 4-D input.
        if logits.ndim == 3:
            logits = logits.unsqueeze(1)

        probs = torch.sigmoid(logits)

        if target_sizes is None:
            return [probs[i].squeeze(0) for i in range(probs.shape[0])]

        if len(target_sizes) != probs.shape[0]:
            raise ValueError(
                f"Got {len(target_sizes)} target sizes for a batch of "
                f"{probs.shape[0]} predictions; pass one (width, height) per image."
            )

        masks = []
        for prob, (target_w, target_h) in zip(probs, target_sizes):
            mask = F.interpolate(
                prob.unsqueeze(0),
                # int() also accepts 0-d tensors (e.g. rows of a tensor of sizes).
                size=(int(target_h), int(target_w)),
                mode="bilinear",
                align_corners=False,
            )
            # Drop the batch axis, then the channel axis when it has size 1.
            masks.append(mask.squeeze(0).squeeze(0))
        return masks
| |
|