# Hugging Face Space: semantic segmentation demo.
# (The "Spaces: Runtime error" lines here were page-scrape residue from the
# Space's status banner, not part of the program source.)
import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from PIL import Image

# Model checkpoint pinned to a specific Hub revision for reproducibility.
MODEL_CKPT = "chansung/segmentation-training-pipeline@v1667722548"
MODEL = from_pretrained_keras(MODEL_CKPT)

# NOTE(review): "RESOLTUION" is a typo for "RESOLUTION"; kept as-is because
# other code in this file references the constant by this name.
RESOLTUION = 128

# Class-index -> RGB color mapping, one "[r, g, b]" entry per line of palette.txt.
PETS_PALETTE = []
with open(r"./palette.txt", "r") as fp:
    for line in fp:
        line = line.strip()
        # Skip blank lines and comment lines. The original used `line[:-1]`
        # to drop the newline, which corrupts a final line that has no
        # trailing newline and crashes on blank lines; strip() handles both.
        if not line or "#" in line:
            continue
        # Parse a line such as "[12, 34, 56]" into [12, 34, 56].
        PETS_PALETTE.append([int(v) for v in line.strip("][").split(", ")])
def preprocess_input(image: Image) -> tf.Tensor:
    """Resize *image* to the model resolution, scale to [0, 1], add a batch axis."""
    tensor = tf.convert_to_tensor(np.array(image))
    resized = tf.image.resize(tensor, (RESOLTUION, RESOLTUION))
    scaled = resized / 255
    return tf.expand_dims(scaled, 0)
# The utility below is adapted from:
# https://github.com/deep-diver/semantic-segmentation-ml-pipeline/blob/main/notebooks/inference_from_SavedModel.ipynb
def get_seg_overlay(image, seg, palette=None):
    """Blend a per-pixel segmentation mask over an image.

    Args:
        image: (H, W, 3) array with values in [0, 1] (the preprocessed input).
        seg: (H, W) integer array of per-pixel class labels.
        palette: optional list of [r, g, b] colors indexed by label; defaults
            to the module-level PETS_PALETTE.

    Returns:
        (H, W, 3) uint8 overlay image.
    """
    color_seg = np.zeros(
        (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
    )  # height, width, 3
    colors = np.array(PETS_PALETTE if palette is None else palette)
    for label, color in enumerate(colors):
        color_seg[seg == label, :] = color

    # BUG FIX: the original computed `image * 0.5 + color_seg * 0.5` and then
    # multiplied the WHOLE blend by 255. Since `image` is in [0, 1] here (it
    # comes from preprocess_input, which divides by 255) while `color_seg` is
    # already 0-255, the palette contribution was scaled by 255 a second time
    # and clipped to white. Scale the image to 0-255 first, then blend 50/50.
    img = np.array(image) * 255 * 0.5 + color_seg * 0.5
    img = np.clip(img, 0, 255)
    return img.astype(np.uint8)
def run_model(image: Image) -> tf.Tensor:
    """Run the segmentation model and return the (H, W) predicted label mask."""
    batch = preprocess_input(image)
    logits = MODEL.predict(batch)
    # Per-pixel class = channel with the highest score; drop the batch axis.
    return tf.squeeze(tf.math.argmax(logits, -1))
def get_predictions(image: Image):
    """Return a PIL image of the input blended with its predicted segmentation."""
    mask = run_model(image)
    model_input = tf.squeeze(preprocess_input(image), 0)
    overlay = get_seg_overlay(model_input.numpy(), mask.numpy())
    return Image.fromarray(overlay)
# UI copy shown on the Gradio page.
title = "Simple demo for a semantic segmentation model trained on the PETS dataset."
description = """
Note that the outputs obtained in this demo won't be state-of-the-art. The underlying project has a different objective focusing more on the ops side of
deploying a semantic segmentation model. For more details, check out the repository: https://github.com/deep-diver/semantic-segmentation-ml-pipeline/.
"""
# Gradio UI. BUG FIX: `gr.inputs.Image` is the deprecated pre-3.0 component
# namespace and was removed in Gradio 4.x, which makes the app crash at import
# time on current Spaces runtimes ("Runtime error"). The modern equivalent is
# the top-level `gr.Image` component.
demo = gr.Interface(
    get_predictions,
    gr.Image(type="pil"),
    "pil",
    allow_flagging="never",
    title=title,
    description=description,
    examples=[
        ["test-image1.png"],
        ["test-image2.png"],
        ["test-image3.png"],
        ["test-image4.png"],
        ["test-image5.png"],
    ],
)
demo.launch()