import gradio as gr
import torch
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import plotly.graph_objects as go
import numpy as np
import os
import random
import torch.nn as nn
import matplotlib.pyplot as plt
from functools import partial
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam import GradCAM
import cv2
import transformers
from torchvision import transforms
import albumentations as A

device = "cuda" if torch.cuda.is_available() else "cpu"
data_folder = "data_sample"
id2label = {
    0: "void",
    1: "flat",
    2: "construction",
    3: "object",
    4: "nature",
    5: "sky",
    6: "human",
    7: "vehicle",
}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)
checkpoint = "nvidia/segformer-b3-finetuned-cityscapes-1024-1024"
image_processor = SegformerImageProcessor(do_resize=False)
state_dict_path = f"runs/{checkpoint}/best_model.pt"
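
# Load the Cityscapes-pretrained SegFormer-B3 and swap its 19-class decode head
# for an 8-class head matching the coarse category groups above;
# `ignore_mismatched_sizes=True` lets the head be re-initialised before the
# fine-tuned weights are restored from `best_model.pt` below.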
model = SegformerForSemanticSegmentation.from_pretrained(
    checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
loaded_state_dict = torch.load(
    state_dict_path, map_location=torch.device("cpu"), weights_only=True
)
model.load_state_dict(loaded_state_dict)
model = model.to(device)
model.eval()


# ---- Segmentation section
def load_and_prepare_images(image_name, segformer=False):
    """
    Load and prepare the image, the ground-truth mask and the stored
    predictions for a given image.

    Args:
        image_name (str): File name of the image to load.
        segformer (bool, optional): If True, also predict the mask with
            SegFormer. Defaults to False.

    Returns:
        tuple: The resized original image, the ground-truth mask, the FPN
            prediction, plus the SegFormer prediction when `segformer` is True.
    """
    image_path = os.path.join(data_folder, "images", image_name)
    mask_name = image_name.replace("_leftImg8bit.png", "_gtFine_labelIds.png")
    mask_path = os.path.join(data_folder, "masks", mask_name)
    fpn_pred_path = os.path.join(data_folder, "resnet101_mask", image_name)
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    if not os.path.exists(mask_path):
        raise FileNotFoundError(f"Mask not found: {mask_path}")
    if not os.path.exists(fpn_pred_path):
        raise FileNotFoundError(f"FPN prediction not found: {fpn_pred_path}")
    original_image = Image.open(image_path).convert("RGB")
    original = original_image.resize((1024, 512))
    true_mask = np.array(Image.open(mask_path))
    fpn_pred = np.array(Image.open(fpn_pred_path))
    if segformer:
        segformer_pred = predict_segmentation(original)
        return original, true_mask, fpn_pred, segformer_pred
    return original, true_mask, fpn_pred


def predict_segmentation(image):
    """
    Predict the segmentation of an image with the fine-tuned SegFormer model.

    Args:
        image (PIL.Image.Image): The image to segment.

    Returns:
        numpy.ndarray: The predicted segmentation map.
    """
    inputs = image_processor(images=image, return_tensors="pt")
    # The model is already on `device` (moved once at start-up)
    pixel_values = inputs.pixel_values.to(device)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
        logits = outputs.logits
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # (height, width)
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
    return pred_seg
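
# Example usage (hypothetical file name):
#   pred = predict_segmentation(Image.open("data_sample/images/example.png").convert("RGB"))
#   pred  # (H, W) array of class ids in [0, 7]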


def process_image(image_name):
    """
    Load the original image, the ground-truth mask and the mask predictions,
    and send the list of (image, caption) tuples to the "Predictions" Gradio tab.

    Args:
        image_name (str): Name of the image to process.

    Returns:
        list: A list of (image, caption) tuples.
    """
    original, true_mask, fpn_pred, segformer_pred = load_and_prepare_images(
        image_name, segformer=True
    )
    true_mask_colored = colorize_mask(true_mask)
    true_mask_colored = Image.fromarray(true_mask_colored.astype("uint8"))
    true_mask_colored = true_mask_colored.resize((1024, 512))
    segformer_pred_colored = colorize_mask(segformer_pred)
    segformer_pred_colored = Image.fromarray(segformer_pred_colored.astype("uint8"))
    segformer_pred_colored = segformer_pred_colored.resize((1024, 512))
    return [
        (original, "Original image"),
        (true_mask_colored, "Ground-truth mask"),
        (fpn_pred, "FPN prediction"),  # stored mask, used as saved on disk
        (segformer_pred_colored, "SegFormer prediction"),
    ]


def create_cityscapes_label_colormap():
    """
    Create a colormap for the Cityscapes category labels.

    Returns:
        numpy.ndarray: A 2D array where each row is the RGB color of a label.
    """
    colormap = np.zeros((256, 3), dtype=np.uint8)
    colormap[0] = [78, 82, 110]
    colormap[1] = [128, 64, 128]
    colormap[2] = [154, 156, 153]
    colormap[3] = [168, 167, 18]
    colormap[4] = [80, 108, 28]
    colormap[5] = [112, 164, 196]
    colormap[6] = [168, 28, 52]
    colormap[7] = [16, 18, 112]
    return colormap


# Build the colormap once at import time
cityscapes_colormap = create_cityscapes_label_colormap()


def colorize_mask(mask):
    """Map each label id in `mask` to its RGB color via NumPy fancy indexing."""
    return cityscapes_colormap[mask]


# ---- End of segmentation section


# ---- EDA section
def analyse_mask(real_mask, num_labels):
    """
    Analyse the class distribution in a ground-truth mask.

    Args:
        real_mask (numpy.ndarray): The ground-truth label mask.
        num_labels (int): Total number of classes.

    Returns:
        dict: A mapping from class id to its pixel proportion in the mask.
    """
    counts = np.bincount(real_mask.ravel(), minlength=num_labels)
    total_pixels = real_mask.size
    class_proportions = counts / total_pixels
    return dict(enumerate(class_proportions))
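
# Example on a tiny hypothetical mask:
#   analyse_mask(np.array([[0, 0], [1, 5]]), num_labels=8)
#   -> {0: 0.5, 1: 0.25, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.25, 6: 0.0, 7: 0.0}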


def show_eda(image_name):
    """
    Show an exploratory analysis of the class distribution for an image and
    its associated mask.

    Args:
        image_name (str): Name of the image to analyse.

    Returns:
        tuple: The original image, the colorized ground-truth mask and a
            Plotly figure of the class distribution.
    """
    original_image, true_mask, _ = load_and_prepare_images(image_name)
    class_proportions = analyse_mask(true_mask, num_labels)
    true_mask_colored = colorize_mask(true_mask)
    true_mask_colored = Image.fromarray(true_mask_colored.astype("uint8"))
    true_mask_colored = true_mask_colored.resize((1024, 512))
    # Sort the classes by increasing proportion
    sorted_classes = sorted(
        class_proportions.keys(), key=lambda x: class_proportions[x]
    )
    # Prepare the data for the bar plot
    categories = [id2label[i] for i in sorted_classes]
    values = [class_proportions[i] for i in sorted_classes]
    color_list = [
        f"rgb({cityscapes_colormap[i][0]}, {cityscapes_colormap[i][1]}, {cityscapes_colormap[i][2]})"
        for i in sorted_classes
    ]
    # Class distribution with the custom colormap
    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=categories,
            y=values,
            marker_color=color_list,
            text=[f"{v:.2f}" for v in values],
            textposition="outside",
        )
    )
    # Add a title and axis labels, tweak tick rotation and font size
    fig.update_layout(
        title={"text": "Class distribution", "font": {"size": 24}},
        xaxis_title={"text": "Categories", "font": {"size": 18}},
        yaxis_title={"text": "Proportion", "font": {"size": 18}},
        xaxis_tickangle=0,  # keep the category labels horizontal
        uniformtext_minsize=12,
        uniformtext_mode="hide",
        font=dict(size=14),
        autosize=True,
        bargap=0.2,
        height=600,
        margin=dict(l=20, r=20, t=50, b=20),
    )
    return original_image, true_mask_colored, fig


# ---- End of EDA section


# ---- GradCAM explanation section
class SegformerWrapper(nn.Module):
    """
    A wrapper around SegFormer that returns only the logits.

    Args:
        model (torch.nn.Module): The pretrained SegFormer model.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """
        Return the model's logits instead of its output dictionary.

        Args:
            x (torch.Tensor): The model inputs.

        Returns:
            torch.Tensor: The model logits.
        """
        output = self.model(x)
        return output.logits


class SemanticSegmentationTarget:
    """
    Target for semantic segmentation used by GradCAM: sums the logits of the
    target category over the pixels selected by a binary mask.

    Args:
        category (int): Index of the target category.
        mask (numpy.ndarray): Binary mask selecting the pixels of interest.
    """

    def __init__(self, category, mask):
        self.category = category
        self.mask = torch.from_numpy(mask)
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        if isinstance(
            model_output, (dict, transformers.modeling_outputs.SemanticSegmenterOutput)
        ):
            logits = (
                model_output["logits"]
                if isinstance(model_output, dict)
                else model_output.logits
            )
        elif isinstance(model_output, torch.Tensor):
            logits = model_output
        else:
            raise ValueError(f"Unexpected model_output type: {type(model_output)}")
        if logits.dim() == 4:  # [batch, classes, height, width]
            return (logits[0, self.category, :, :] * self.mask).sum()
        elif logits.dim() == 3:  # [classes, height, width]
            return (logits[self.category, :, :] * self.mask).sum()
        else:
            raise ValueError(f"Unexpected logits shape: {logits.shape}")


def segformer_reshape_transform_huggingface(tensor, width, height):
    """
    Reshape the encoder activations into the layout GradCAM expects.

    Args:
        tensor (torch.Tensor): The tensor to reshape.
        width (int): Feature-map width.
        height (int): Feature-map height.

    Returns:
        torch.Tensor: The reshaped tensor.
    """
    result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
    result = result.transpose(2, 3).transpose(1, 2)
    return result
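
# The SegFormer encoder emits token sequences of shape [batch, height * width,
# channels]; the reshape and two transposes above rearrange them into the
# [batch, channels, height, width] feature-map layout that GradCAM's hooks expect.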


def explain_model(image_name, category_name):
    """
    Explain the SegFormer predictions with GradCAM for a given image and category.

    Args:
        image_name (str): Name of the image to explain.
        category_name (str): Name of the target category.

    Returns:
        matplotlib.figure.Figure: A matplotlib figure with the GradCAM heatmap
            overlaid on the original image.
    """
    original_image, _, _ = load_and_prepare_images(image_name)
    rgb_img = np.float32(original_image) / 255
    img_tensor = transforms.ToTensor()(rgb_img)
    input_tensor = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )(img_tensor)
    input_tensor = input_tensor.unsqueeze(0).to(device)
    wrapped_model = SegformerWrapper(model).to(device)
    with torch.no_grad():
        output = wrapped_model(input_tensor)
    upsampled_logits = nn.functional.interpolate(
        output, size=input_tensor.shape[-2:], mode="bilinear", align_corners=False
    )
    normalized_masks = torch.nn.functional.softmax(upsampled_logits, dim=1).cpu()
    category = label2id[category_name]
    mask = normalized_masks[0].argmax(dim=0).numpy()
    mask_float = np.float32(mask == category)
    # The last encoder stage downsamples by a factor of 32
    reshape_transform = partial(
        segformer_reshape_transform_huggingface,
        width=img_tensor.shape[2] // 32,
        height=img_tensor.shape[1] // 32,
    )
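    # The CAM is computed on the final encoder LayerNorm, whose token activations
    # are rearranged by `reshape_transform` above; targeting an earlier stage
    # would likely trade semantic precision for spatial resolution.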
    target_layers = [wrapped_model.model.segformer.encoder.layer_norm[-1]]
    mask_float_resized = cv2.resize(mask_float, (output.shape[3], output.shape[2]))
    targets = [SemanticSegmentationTarget(category, mask_float_resized)]
    cam = GradCAM(
        model=wrapped_model,
        target_layers=target_layers,
        reshape_transform=reshape_transform,
    )
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
    threshold = 0.01  # 1% confidence threshold
    thresholded_cam = grayscale_cam.copy()
    thresholded_cam[grayscale_cam < threshold] = 0
    if np.max(thresholded_cam) > 0:
        thresholded_cam = thresholded_cam / np.max(thresholded_cam)
    # An all-zero map stays zero; the fallback below then shows the plain image
    resized_cam = cv2.resize(
        thresholded_cam[0], (input_tensor.shape[3], input_tensor.shape[2])
    )
    masked_cam = resized_cam * mask_float
    if np.max(masked_cam) > 0:
        cam_image = show_cam_on_image(rgb_img, masked_cam, use_rgb=True)
    else:
        cam_image = original_image
    fig, ax = plt.subplots(figsize=(15, 10))
    ax.imshow(cam_image)
    ax.axis("off")
    ax.set_title(f"GradCAM heatmap for {category_name}", color="white")
    margin = 0.02  # adjust this value to change the size of the margin
    margin_color = "#0a0f1e"
    fig.subplots_adjust(left=margin, right=1 - margin, top=1 - margin, bottom=margin)
    fig.patch.set_facecolor(margin_color)
    plt.close()
    return fig


# ---- End of GradCAM explanation section


# ---- Data augmentation section
def change_image():
    """
    Pick and load a random image from the sample folder.

    Returns:
        PIL.Image.Image: The selected image.
    """
    image_dir = "data_sample/images"  # replace with the path to your image folder
    image_list = [f for f in os.listdir(image_dir) if f.endswith(".png")]
    random_image = random.choice(image_list)
    return Image.open(os.path.join(image_dir, random_image))


def apply_augmentation(image, augmentation_names):
    """
    Apply one or more augmentations to an image.

    Args:
        image (PIL.Image.Image): The image to augment.
        augmentation_names (list of str): Names of the augmentations to apply.

    Returns:
        PIL.Image.Image: The augmented image.
    """
    augmentations = {
        "Horizontal Flip": A.HorizontalFlip(p=1),
        "Shift Scale Rotate": A.ShiftScaleRotate(p=1),
        "Random Brightness Contrast": A.RandomBrightnessContrast(p=1),
        "RGB Shift": A.RGBShift(p=1),
        "Blur": A.Blur(blur_limit=(5, 7), p=1),
        "Gaussian Noise": A.GaussNoise(p=1),
        "Grid Distortion": A.GridDistortion(p=1),
        "Random Sun": A.RandomSunFlare(p=1),
    }
    image_array = np.array(image)
    if augmentation_names:  # handles both None and an empty selection
        selected_augs = [
            augmentations[name] for name in augmentation_names if name in augmentations
        ]
        compose = A.Compose(selected_augs)
        # Apply the composed augmentations
        augmented = compose(image=image_array)
        return Image.fromarray(augmented["image"])
    return image
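
# Example (assuming a PIL image `img`):
#   apply_augmentation(img, ["Horizontal Flip", "Blur"])
# applies the flip then the blur, each with probability 1.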


# ---- End of data augmentation section
image_list = [
    f for f in os.listdir(os.path.join(data_folder, "images")) if f.endswith(".png")
]
category_list = list(id2label.values())
image_name = "dusseldorf_000012_000019_leftImg8bit.png"
default_image = os.path.join(data_folder, "images", image_name)
my_theme = gr.Theme.from_hub("gstaff/whiteboard")

with gr.Blocks(title="Proof of concept", theme=my_theme) as demo:
    gr.Markdown("# Project 10 - Developing a proof of concept")
    with gr.Tab("Distribution"):
        gr.Markdown("## Cityscapes class distribution")
        gr.Markdown(
            "### Visualize the distribution of each class for the selected image."
        )
        eda_image_input = gr.Dropdown(
            choices=image_list,
            label="Select an image",
        )
        with gr.Row():
            original_image_output = gr.Image(type="pil", label="Original image")
            original_mask_output = gr.Image(type="pil", label="Original mask")
        class_distribution_plot = gr.Plot(label="Class distribution")
        eda_image_input.change(
            fn=show_eda,
            inputs=eda_image_input,
            outputs=[
                original_image_output,
                original_mask_output,
                class_distribution_plot,
            ],
        )
| with gr.Tab("Data Augmentation"): | |
| gr.Markdown("## Visualisation de l'augmentation des données") | |
| gr.Markdown( | |
| "### Sélectionnez une ou plusieurs augmentations pour l'appliquer à l'image." | |
| ) | |
| gr.Markdown("### Vous pouvez également changer d'image.") | |
| with gr.Row(): | |
| image_display = gr.Image( | |
| value=default_image, | |
| label="Image", | |
| show_download_button=False, | |
| interactive=False, | |
| ) | |
| augmented_image = gr.Image(label="Image Augmentée") | |
| with gr.Row(): | |
| change_image_button = gr.Button("Changer image") | |
| augmentation_dropdown = gr.Dropdown( | |
| choices=[ | |
| "Horizontal Flip", | |
| "Shift Scale Rotate", | |
| "Random Brightness Contrast", | |
| "RGB Shift", | |
| "Blur", | |
| "Gaussian Noise", | |
| "Grid Distortion", | |
| "Random Sun", | |
| ], | |
| label="Sélectionnez une augmentation", | |
| multiselect=True, | |
| ) | |
| apply_button = gr.Button("Appliquer l'augmentation") | |
| change_image_button.click(fn=change_image, outputs=image_display) | |
| apply_button.click( | |
| fn=apply_augmentation, | |
| inputs=[image_display, augmentation_dropdown], | |
| outputs=augmented_image, | |
| ) | |
| with gr.Tab("Prédictions"): | |
| gr.Markdown("## Comparaison de segmentations d'images Cityscapes") | |
| gr.Markdown( | |
| "### Sélectionnez une image pour voir la comparaison entre le masque réel, la prédiction FPN (pré-enregistré) et la prédiction du modèle SegFormer." | |
| ) | |
| image_input = gr.Dropdown(choices=image_list, label="Sélectionnez une image") | |
| gallery_output = gr.Gallery( | |
| label="Résultats de segmentation", | |
| show_label=True, | |
| elem_id="gallery", | |
| columns=[2], | |
| rows=[2], | |
| object_fit="contain", | |
| height="512px", | |
| min_width="1024px", | |
| ) | |
| image_input.change(fn=process_image, inputs=image_input, outputs=gallery_output) | |
| with gr.Tab("Explication SegFormer"): | |
| gr.Markdown("## Explication du modèle SegFormer") | |
| gr.Markdown( | |
| "### La méthode Grad-CAM est une technique populaire de visualisation qui est utile pour comprendre comment un réseau neuronal convolutif a été conduit à prendre une décision de classification. Elle est spécifique à chaque classe, ce qui signifie qu’elle peut produire une visualisation distincte pour chaque classe présente dans l’image." | |
| ) | |
| gr.Markdown( | |
| "### NB: Si l'image s'affiche sans masque, c'est que le modèle ne trouve pas de zones significatives pour une catégorie donnée." | |
| ) | |
| with gr.Row(): | |
| explain_image_input = gr.Dropdown( | |
| choices=image_list, label="Sélectionnez une image" | |
| ) | |
| explain_category_input = gr.Dropdown( | |
| choices=category_list, label="Sélectionnez une catégorie" | |
| ) | |
| explain_button = gr.Button("Expliquer") | |
| explain_output = gr.Plot(label="Explication SegFormer", min_width=200) | |
| explain_button.click( | |
| fn=explain_model, | |
| inputs=[explain_image_input, explain_category_input], | |
| outputs=explain_output, | |
| ) | |
| # Lancer l'application | |
| demo.launch(favicon_path="favicon.ico") | |