SkyCLIP-ViT-L-14

This repository provides a Hugging Face-format vision encoder (ViT-L-14) port of the original SkyCLIP model.

Original Repository and Links

Description

SkyCLIP is a vision-language foundation model for remote sensing, trained on a large-scale image-text dataset collected for aerial and satellite imagery. This repository only provides the ViT-L-14 vision encoder, converted from the original timm checkpoint to Hugging Face CLIPVisionModel format.

Preprocessing

The recommended image transforms are as follows:

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True),
    CenterCrop(size=(224, 224)),
    ToTensor(),
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

How to use

With transformers

from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

# Load model and processor
model = CLIPModel.from_pretrained("BiliSakura/SkyCLIP-ViT-L-14")
processor = CLIPProcessor.from_pretrained("BiliSakura/SkyCLIP-ViT-L-14")

# Load and process image
image = Image.open("path/to/your/image.jpg")
inputs = processor(
    text=["a photo of a building", "a photo of vegetation", "a photo of water"],
    images=image,
    return_tensors="pt",
    padding=True
)

# Get image-text similarity scores
with torch.inference_mode():
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

print(f"Similarity scores: {probs}")
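In `transformers`' `CLIPModel`, `logits_per_image` is the cosine similarity between the L2-normalized image and text embeddings, multiplied by the learned temperature `logit_scale.exp()`. The sketch below reproduces that computation on plain tensors so the softmax step is easier to follow. The random 8-dimensional embeddings and the `logit_scale` value are placeholders, not values read from this checkpoint. `clip_logits_per_image` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def clip_logits_per_image(image_embeds: torch.Tensor,
                          text_embeds: torch.Tensor,
                          logit_scale: torch.Tensor) -> torch.Tensor:
    """Scaled cosine similarity with shape (num_images, num_texts)."""
    # L2-normalize each embedding so the dot product becomes a cosine similarity
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    # logit_scale is stored in log space; exp() recovers the multiplier
    return logit_scale.exp() * image_embeds @ text_embeds.T


torch.manual_seed(0)
image_embeds = torch.randn(1, 8)    # placeholder for one image embedding
text_embeds = torch.randn(3, 8)     # placeholder for three prompt embeddings
logit_scale = torch.tensor(4.6052)  # placeholder (exp ≈ 100); use model.logit_scale with a real model

logits_per_image = clip_logits_per_image(image_embeds, text_embeds, logit_scale)
probs = logits_per_image.softmax(dim=1)  # one probability distribution over prompts per image
```

With a real model, `clip_logits_per_image(outputs.image_embeds, outputs.text_embeds, model.logit_scale)` should agree with `outputs.logits_per_image` up to floating-point error.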

Zero-shot image classification:

from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

model = CLIPModel.from_pretrained("BiliSakura/SkyCLIP-ViT-L-14")
processor = CLIPProcessor.from_pretrained("BiliSakura/SkyCLIP-ViT-L-14")

# Define candidate labels
candidate_labels = [
    "a satellite image of urban area",
    "a satellite image of forest",
    "a satellite image of agricultural land",
    "a satellite image of water body"
]

image = Image.open("path/to/your/image.jpg")
inputs = processor(
    text=candidate_labels,
    images=image,
    return_tensors="pt",
    padding=True
)

with torch.inference_mode():
    outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)

# Get the predicted label (probs has shape (1, num_labels) for a single image)
predicted_idx = probs[0].argmax().item()
print(f"Predicted label: {candidate_labels[predicted_idx]}")
print(f"Confidence: {probs[0][predicted_idx]:.4f}")
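Reporting only the single best label can hide close calls between similar land-cover classes. The sketch below lists the top-k labels instead. `top_k_labels` is a hypothetical helper, and the probability values are made up for illustration. With a real model, you would pass the `probs` tensor computed above.

```python
import torch


def top_k_labels(probs: torch.Tensor, labels: list[str], k: int = 3) -> list[tuple[str, float]]:
    """Return the k most probable (label, probability) pairs for a single image.

    `probs` should have shape (1, num_labels) or (num_labels,),
    e.g. outputs.logits_per_image.softmax(dim=1) for one image.
    """
    scores = probs.reshape(-1)      # drop the batch dimension of a single image
    k = min(k, scores.numel())      # never ask for more labels than exist
    values, indices = scores.topk(k)
    return [(labels[i], v) for i, v in zip(indices.tolist(), values.tolist())]


candidate_labels = [
    "a satellite image of urban area",
    "a satellite image of forest",
    "a satellite image of agricultural land",
    "a satellite image of water body",
]
example_probs = torch.tensor([[0.10, 0.60, 0.25, 0.05]])  # illustrative values only
top = top_k_labels(example_probs, candidate_labels, k=2)
for label, score in top:
    print(f"{label}: {score:.4f}")
```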

Extracting individual features:

from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

model = CLIPModel.from_pretrained("BiliSakura/SkyCLIP-ViT-L-14")
processor = CLIPProcessor.from_pretrained("BiliSakura/SkyCLIP-ViT-L-14")

# Get image features only
image = Image.open("path/to/your/image.jpg")
image_inputs = processor(images=image, return_tensors="pt")

with torch.inference_mode():
    image_features = model.get_image_features(**image_inputs)

# Get text features only
text_inputs = processor(
    text=["a satellite image of urban area"],
    return_tensors="pt",
    padding=True,
    truncation=True
)

with torch.inference_mode():
    text_features = model.get_text_features(**text_inputs)

print(f"Image features shape: {image_features.shape}")
print(f"Text features shape: {text_features.shape}")
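In `transformers`, `get_image_features` and `get_text_features` return projected embeddings that are not L2-normalized, so normalize them before comparing. The sketch below shows minimal text-to-image retrieval. It assumes you have already stacked the features of several images into one tensor. The helper name `rank_images` and the toy 4-dimensional vectors are illustrative only.

```python
import torch
import torch.nn.functional as F


def rank_images(text_features: torch.Tensor,
                image_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Rank a gallery of images against one text query by cosine similarity.

    text_features:  (1, dim), e.g. from model.get_text_features
    image_features: (num_images, dim), e.g. stacked model.get_image_features outputs
    Returns (order, scores): image indices sorted best-first and their similarities.
    """
    text = F.normalize(text_features, dim=-1)
    images = F.normalize(image_features, dim=-1)
    scores = (images @ text.T).squeeze(-1)  # (num_images,)
    order = scores.argsort(descending=True)
    return order, scores[order]


# Toy features: the third image (index 2) points in the same direction as the query
text_features = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
image_features = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 1.0, 0.0, 0.0],
    [2.0, 0.0, 2.0, 0.0],
])
order, scores = rank_images(text_features, image_features)
print(order.tolist(), scores.tolist())
```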

With diffusers

This model's text encoder can also serve as the text encoder for diffusion pipelines built around a ViT-L/14-sized CLIP text encoder, such as Stable Diffusion v1.x:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

device = "cuda"  # float16 weights are meant for GPU inference

# Load the text encoder and tokenizer.
# Hub repo IDs have the form "namespace/name"; nested folders go in `subfolder`.
text_encoder = CLIPTextModel.from_pretrained(
    "BiliSakura/SkyCLIP-ViT-L-14",
    subfolder="diffusers/text_encoder",
    torch_dtype=torch.float16
).to(device)
tokenizer = CLIPTokenizer.from_pretrained(
    "BiliSakura/SkyCLIP-ViT-L-14"
)

# Encode text prompt (CLIP's context length is 77 tokens)
prompt = "a satellite image of a city with buildings and roads"
text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=77,
    truncation=True,
    return_tensors="pt"
)

with torch.inference_mode():
    text_outputs = text_encoder(text_inputs.input_ids.to(device))
    text_embeddings = text_outputs.last_hidden_state  # (batch, 77, hidden_size)

print(f"Text embeddings shape: {text_embeddings.shape}")

Using with Stable Diffusion:

from diffusers import StableDiffusionPipeline
import torch

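# Load the pipeline with the SkyCLIP text encoder and tokenizer from the previous snippet
# (stable-diffusion-v1-5/stable-diffusion-v1-5 mirrors the former runwayml/stable-diffusion-v1-5 repo)
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    torch_dtype=torch.float16
)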
pipe = pipe.to("cuda")

# Generate image
prompt = "a high-resolution satellite image of urban area"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("generated_image.png")
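Swapping a text encoder into a Stable Diffusion pipeline only works if the encoder's hidden size matches the UNet's cross-attention dimension. The sketch below is a small sanity check. It assumes the standard config attributes `hidden_size` (on the text encoder's config) and `cross_attention_dim` (on the UNet's config). The helper name and the `SimpleNamespace` stand-ins are hypothetical. Note that matching shapes only guarantee that the pipeline runs: the Stable Diffusion v1.x UNet was trained with OpenAI's CLIP text encoder, so outputs conditioned on SkyCLIP embeddings may differ from the stock pipeline's.

```python
from types import SimpleNamespace


def is_text_encoder_compatible(text_encoder_config, unet_config) -> bool:
    """True if the text encoder's hidden states fit the UNet's cross-attention layers."""
    return text_encoder_config.hidden_size == unet_config.cross_attention_dim


# Illustrative stand-ins; with real objects from the snippets above, use
# is_text_encoder_compatible(text_encoder.config, pipe.unet.config)
example_text_encoder_config = SimpleNamespace(hidden_size=768)
example_unet_config = SimpleNamespace(cross_attention_dim=768)
compatible = is_text_encoder_compatible(example_text_encoder_config, example_unet_config)
print(f"Compatible: {compatible}")
```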

Citation

If you use the SkyCLIP model or the SkyScript dataset, please cite the original work:

@article{wangSkyScriptLargeSemantically2024,
  title = {{{SkyScript}}: {{A Large}} and {{Semantically Diverse Vision-Language Dataset}} for {{Remote Sensing}}},
  shorttitle = {{{SkyScript}}},
  author = {Wang, Zhecheng and Prabha, Rajanie and Huang, Tianyuan and Wu, Jiajun and Rajagopal, Ram},
  year = 2024,
  month = mar,
  journal = {Proceedings of the AAAI Conference on Artificial Intelligence},
  volume = {38},
  number = {6},
  pages = {5805--5813},
  issn = {2374-3468},
  doi = {10.1609/aaai.v38i6.28393},
  urldate = {2024-07-06},
  copyright = {Copyright (c) 2024 Association for the Advancement of Artificial Intelligence},
  keywords = {ML: Multimodal Learning},
  annotation = {CCF: A}
}