# Ubuntu
# main working commit
# cda9a4d
import gradio as gr
import json
import base64
import requests
import time
import os
from dotenv import load_dotenv
import numpy as np
from PIL import Image
import io
def _require_env(name):
    """Return environment variable *name*, failing fast if it is unset or empty.

    Raises:
        RuntimeError: with a message naming the missing variable, instead of the
            opaque ``TypeError`` that string concatenation with ``None`` produces.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} is not set; "
            "add it to the environment or to a .env file."
        )
    return value


# Load the API key and the service base URL. load_dotenv() reads variables
# from a .env file, if one is found, before they are looked up.
load_dotenv()
API_KEY = _require_env('API_KEY')
CURRENT_URL = _require_env('CURRENT_URL')

# API endpoints. Trailing slashes on the base URL are stripped so the joined
# paths are correct whether or not CURRENT_URL ends with "/".
_BASE_URL = CURRENT_URL.rstrip('/')
TRYON_URL = _BASE_URL + '/api/tryon/'
FETCH_URL = _BASE_URL + '/api/tryon_state/'

# Headers sent with every API request.
headers = {
    'Authorization': f'Bearer {API_KEY}',
    'Content-Type': 'application/json',
}
# Paths to the example images offered under each input. The image files
# themselves are not created here; you need to add them.
sample_garments = [
    "samples/garments/g1.jpg",
    "samples/garments/g2.jpg",
]
sample_humans = [
    "samples/humans/h1.jpg",
    "samples/humans/h2.jpg",
]

# Create the directories those example paths point into (if missing), so the
# folders where the images must be placed actually exist.
for _sample_dir in sorted({os.path.dirname(p) for p in sample_garments + sample_humans}):
    os.makedirs(_sample_dir, exist_ok=True)
def preprocess_image(img, target_size=None):
    """Normalise an input image to a PIL ``Image``.

    Args:
        img: A PIL image, a NumPy array (cast to uint8 before conversion),
            or None.
        target_size: Optional ``(width, height)`` to resize to with Lanczos
            resampling. When None, the image keeps its original size.

    Returns:
        The (possibly resized) PIL image, or None when ``img`` is None.
    """
    if img is None:
        return None

    pil_img = Image.fromarray(img.astype('uint8')) if isinstance(img, np.ndarray) else img

    if target_size is None:
        return pil_img
    return pil_img.resize(target_size, Image.LANCZOS)
def virtual_tryon(garment_img, person_img):
# Convert images to base64
if person_img is None or garment_img is None:
return None
# Preprocess images without resizing
human_pil = preprocess_image(person_img)
garment_pil = preprocess_image(garment_img)
human_buffer = io.BytesIO()
garment_buffer = io.BytesIO()
human_pil.save(human_buffer, format="JPEG")
garment_pil.save(garment_buffer, format="JPEG")
human_base64_image = base64.b64encode(human_buffer.getvalue()).decode('utf-8')
garment_base64_image = base64.b64encode(garment_buffer.getvalue()).decode('utf-8')
# Prepare data for API request
data = {
'human_image_base64': human_base64_image,
'garment_image_base64': garment_base64_image,
}
# Make API request to start tryon process
response = requests.post(TRYON_URL, headers=headers, data=json.dumps(data))
if response.status_code != 200:
return None
json_response = response.json()
tryon_pk = json_response['tryon_pk']
# Poll for result
time_elapsed = 0
while time_elapsed < 60: # Timeout after 60 seconds
fetch_response = requests.post(FETCH_URL, headers=headers, data=json.dumps({
'tryon_pk': tryon_pk,
}))
if fetch_response.status_code != 200:
return None
json_response = fetch_response.json()
if json_response.get('message') != 'success':
return None
if json_response.get('status') == 'done':
# Download the result image
result_url = json_response['s3_url']
img_response = requests.get(result_url)
if img_response.status_code == 200:
return Image.open(io.BytesIO(img_response.content))
time.sleep(2)
time_elapsed += 2
return None
custom_css = """
body, .gradio-container {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
background-color: #121212;
color: white;
}
h1, h2, h3 {
color: white !important;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
.image-container img {
object-fit: contain;
max-height: 450px;
width: auto;
margin: 0 auto;
display: block;
border-radius: 8px;
}
.examples-container {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(150px, 1fr));
gap: 10px;
margin-top: 10px;
}
.examples-container img {
height: 120px;
object-fit: cover;
border-radius: 8px;
cursor: pointer;
transition: transform 0.2s;
}
.examples-container img:hover {
transform: scale(1.05);
}
button#try-on-button {
background-color: #FF6B00 !important;
color: white !important;
border: none !important;
padding: 12px 20px !important;
font-weight: 600 !important;
border-radius: 8px !important;
cursor: pointer !important;
transition: background-color 0.3s !important;
}
button#try-on-button:hover {
background-color: #FF8C33 !important;
}
footer {visibility: hidden}
"""
# Create Gradio interface: three columns (garment input, person input, result).
with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
    gr.HTML("<h1 style='text-align: center; margin-bottom: 20px;'>AlphaBake Virtual Try-On</h1>")
    with gr.Row():
        # First column - garment upload plus clickable examples
        with gr.Column(scale=1):
            gr.Markdown("### Garment Image")
            garment_input = gr.Image(
                label="Upload a garment image",
                type="pil",
                elem_id="garment-image",
                elem_classes=["image-container"],
                height=350,
            )
            gr.Examples(
                examples=sample_garments,
                inputs=garment_input,
                label="Garment Examples",
                examples_per_page=4,
            )
        # Second column - person upload plus clickable examples
        with gr.Column(scale=1):
            gr.Markdown("### Person Image")
            person_input = gr.Image(
                label="Upload a person image",
                type="pil",
                elem_id="person-image",
                elem_classes=["image-container"],
                height=350,
            )
            gr.Examples(
                examples=sample_humans,
                inputs=person_input,
                label="Person Examples",
                examples_per_page=4,
            )
        # Third column - try-on button and result
        with gr.Column(scale=1):
            try_on_button = gr.Button("Try On", elem_id="try-on-button", variant="primary", size="lg")
            output_image = gr.Image(
                label="Result",
                type="pil",
                elem_classes=["result-image"],
                height=400,
            )

    def validate_inputs(garment_img, person_img, garment_type=None,
                        sleeve_length=None, garment_length=None):
        """Validate the uploaded images and run the try-on.

        The click handler wires only the two image inputs, so the garment
        option parameters must have defaults. They are accepted for
        compatibility but currently unused.

        Returns:
            The result image produced by ``virtual_tryon``.

        Raises:
            gr.Error: if an image is missing, if the try-on raises, or if it
                returns no result.
        """
        if garment_img is None:
            raise gr.Error("Please upload a garment image")
        if person_img is None:
            raise gr.Error("Please upload a person image")
        try:
            result = virtual_tryon(garment_img, person_img)
        except Exception as e:
            # Surface any failure (network, API, decoding) in the UI.
            raise gr.Error(f"Error: {str(e)}") from e
        if result is None:
            # virtual_tryon signals API failures and timeouts by returning None;
            # report that instead of silently showing an empty result.
            raise gr.Error("The try-on service did not return a result. Please try again.")
        return result

    # Connect the button to validation + try-on.
    try_on_button.click(
        fn=validate_inputs,
        inputs=[garment_input, person_input],
        outputs=output_image,
    )
if __name__ == "__main__":
demo.launch()