"""Gradio demo that sends a garment and a person image to a try-on API.

The script reads ``API_KEY`` and ``CURRENT_URL`` from the environment (or a
``.env`` file), submits both images to the ``api/tryon/`` endpoint, polls
``api/tryon_state/`` until the job reports ``done`` and shows the resulting
image in the UI.
"""
import base64
import io
import json
import os
import time

import gradio as gr
import numpy as np
import requests
from dotenv import load_dotenv
from PIL import Image

# Load API key from .env file
load_dotenv()
API_KEY = os.getenv('API_KEY')
CURRENT_URL = os.getenv('CURRENT_URL')

# Fail with an explicit message instead of the opaque
# "can only concatenate str (not NoneType)" TypeError that the string
# concatenations below would otherwise raise.
if API_KEY is None or CURRENT_URL is None:
    raise RuntimeError(
        "API_KEY and CURRENT_URL must be set in the environment or .env file"
    )

# API endpoints.
# NOTE(review): plain concatenation, so CURRENT_URL presumably has to end
# with a trailing slash - confirm against the deployment configuration.
TRYON_URL = CURRENT_URL + 'api/tryon/'
FETCH_URL = CURRENT_URL + 'api/tryon_state/'

# Headers for API requests
headers = {
    'Authorization': 'Bearer ' + API_KEY,
    'Content-Type': 'application/json',
}

# Seconds to wait for any single HTTP request before giving up.
REQUEST_TIMEOUT = 30
# Total seconds to keep polling for a finished try-on job.
POLL_TIMEOUT = 60
# Seconds to sleep between two polls.
POLL_INTERVAL = 2

# Create example directories if they don't exist.
# NOTE(review): these are "examples/..." while the sample paths below point
# at "samples/..."; one of the two is probably a leftover - confirm.
os.makedirs("examples/garments", exist_ok=True)
os.makedirs("examples/persons", exist_ok=True)

# Paths to example images (you'll need to add these files)
sample_garments = [
    "samples/garments/g1.jpg",
    "samples/garments/g2.jpg",
]
sample_humans = [
    "samples/humans/h1.jpg",
    "samples/humans/h2.jpg",
]


def preprocess_image(img, target_size=None):
    """Return *img* as a PIL image, optionally resized.

    Args:
        img: A PIL image, a numpy array, or ``None``.
        target_size: Optional ``(width, height)`` passed to
            ``Image.resize``. When ``None`` the image is not resized.

    Returns:
        The PIL image, or ``None`` when *img* is ``None``.
    """
    if img is None:
        return None

    # Convert numpy array to PIL Image if it's a numpy array
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img.astype('uint8'))

    # Only resize if target_size is specified
    if target_size is not None:
        img = img.resize(target_size, Image.LANCZOS)

    return img


def _encode_jpeg_base64(img):
    """Return *img* encoded as JPEG and then base64, as a UTF-8 string."""
    # JPEG cannot store alpha or palette data; saving such an image raises
    # OSError, so normalise everything that is not already RGB.
    if img.mode != "RGB":
        img = img.convert("RGB")
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def virtual_tryon(garment_img, person_img):
    """Run a try-on job for the given images and return the result image.

    Args:
        garment_img: Garment image (PIL image or numpy array), or ``None``.
        person_img: Person image (PIL image or numpy array), or ``None``.

    Returns:
        The result as a PIL image, or ``None`` when an input is missing,
        the API answers with a non-200 status or a non-success message, or
        no result is available before ``POLL_TIMEOUT`` seconds have passed.

    Raises:
        requests.RequestException: On connection errors or when a single
            request exceeds ``REQUEST_TIMEOUT``.
        KeyError: When an API response lacks an expected field.
    """
    if person_img is None or garment_img is None:
        return None

    # Preprocess images without resizing, then convert them to base64
    data = {
        'human_image_base64': _encode_jpeg_base64(preprocess_image(person_img)),
        'garment_image_base64': _encode_jpeg_base64(preprocess_image(garment_img)),
    }

    # Make API request to start tryon process
    response = requests.post(
        TRYON_URL,
        headers=headers,
        data=json.dumps(data),
        timeout=REQUEST_TIMEOUT,
    )
    if response.status_code != 200:
        return None

    tryon_pk = response.json()['tryon_pk']

    # Poll for result. A monotonic deadline bounds the real elapsed time,
    # including the time spent inside the HTTP requests themselves.
    deadline = time.monotonic() + POLL_TIMEOUT
    while time.monotonic() < deadline:
        fetch_response = requests.post(
            FETCH_URL,
            headers=headers,
            data=json.dumps({'tryon_pk': tryon_pk}),
            timeout=REQUEST_TIMEOUT,
        )
        if fetch_response.status_code != 200:
            return None

        json_response = fetch_response.json()
        if json_response.get('message') != 'success':
            return None

        if json_response.get('status') == 'done':
            # Download the result image
            result_url = json_response['s3_url']
            img_response = requests.get(result_url, timeout=REQUEST_TIMEOUT)
            if img_response.status_code == 200:
                return Image.open(io.BytesIO(img_response.content))
            # A failed download falls through and is retried on the next poll.

        time.sleep(POLL_INTERVAL)

    return None


custom_css = """
body, .gradio-container {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
    background-color: #121212;
    color: white;
}
h1, h2, h3 {
    color: white !important;
}
.container {
    max-width: 1200px;
    margin: 0 auto;
}
.image-container img {
    object-fit: contain;
    max-height: 450px;
    width: auto;
    margin: 0 auto;
    display: block;
    border-radius: 8px;
}
.examples-container {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(150px, 1fr));
    gap: 10px;
    margin-top: 10px;
}
.examples-container img {
    height: 120px;
    object-fit: cover;
    border-radius: 8px;
    cursor: pointer;
    transition: transform 0.2s;
}
.examples-container img:hover {
    transform: scale(1.05);
}
button#try-on-button {
    background-color: #FF6B00 !important;
    color: white !important;
    border: none !important;
    padding: 12px 20px !important;
    font-weight: 600 !important;
    border-radius: 8px !important;
    cursor: pointer !important;
    transition: background-color 0.3s !important;
}
button#try-on-button:hover {
    background-color: #FF8C33 !important;
}
footer {visibility: hidden}
"""

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
    # NOTE(review): the original markup of this header was lost (only the
    # text "AlphaBake" / "Virtual Try-On" survived); the heading tags below
    # are a reconstruction - confirm the intended HTML.
    gr.HTML(
        """
        <h1>AlphaBake</h1>
        <h2>Virtual Try-On</h2>
        """
    )

    with gr.Row():
        # First column - Garment
        with gr.Column(scale=1):
            gr.Markdown("### Garment Image")
            garment_input = gr.Image(
                label="Upload a garment image",
                type="pil",
                elem_id="garment-image",
                elem_classes=["image-container"],
                height=350,
            )
            # Add example garment images
            gr.Examples(
                examples=sample_garments,
                inputs=garment_input,
                label="Garment Examples",
                examples_per_page=4,
            )

        # Second column - Person
        with gr.Column(scale=1):
            gr.Markdown("### Person Image")
            person_input = gr.Image(
                label="Upload a person image",
                type="pil",
                elem_id="person-image",
                elem_classes=["image-container"],
                height=350,
            )
            # Add example person images
            gr.Examples(
                examples=sample_humans,
                inputs=person_input,
                label="Person Examples",
                examples_per_page=4,
            )

        # Third column - Garment options & result
        with gr.Column(scale=1):
            # Try-on button
            try_on_button = gr.Button(
                "Try On",
                elem_id="try-on-button",
                variant="primary",
                size="lg",
            )
            # Result image
            output_image = gr.Image(
                label="Result",
                type="pil",
                elem_classes=["result-image"],
                height=400,
            )

    # Validation function
    def validate_inputs(garment_img, person_img, garment_type=None,
                        sleeve_length=None, garment_length=None):
        """Validate the uploaded images and run the try-on.

        Args:
            garment_img: Garment image from the UI, or ``None``.
            person_img: Person image from the UI, or ``None``.
            garment_type: Unused; kept for signature compatibility.
            sleeve_length: Unused; kept for signature compatibility.
            garment_length: Unused; kept for signature compatibility.

        Returns:
            The try-on result image.

        Raises:
            gr.Error: When an image is missing, the try-on raises, or it
                produces no result.
        """
        # The three trailing parameters default to None because the click
        # handler below only supplies the two image inputs; without the
        # defaults every click failed with a TypeError.
        if garment_img is None:
            raise gr.Error("Please upload a garment image")
        if person_img is None:
            raise gr.Error("Please upload a person image")

        # If all validations pass, proceed with try-on
        try:
            result = virtual_tryon(garment_img, person_img)
        except Exception as e:
            raise gr.Error(f"Error: {str(e)}") from e

        # virtual_tryon signals API failures and timeouts with None; report
        # that to the user instead of silently showing an empty result.
        if result is None:
            raise gr.Error("Try-on failed or timed out. Please try again.")
        return result

    # Connect button to validation and try-on functions
    try_on_button.click(
        fn=validate_inputs,
        inputs=[garment_input, person_input],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()