import gradio as gr import json import base64 import requests import time import os from dotenv import load_dotenv import numpy as np from PIL import Image import io # Load API key from .env file load_dotenv() API_KEY = os.getenv('API_KEY') CURRENT_URL = os.getenv('CURRENT_URL') # API endpoints TRYON_URL = CURRENT_URL + 'api/tryon/' FETCH_URL = CURRENT_URL + 'api/tryon_state/' # Headers for API requests headers = { 'Authorization': 'Bearer ' + API_KEY, 'Content-Type': 'application/json', } # Create example directories if they don't exist os.makedirs("examples/garments", exist_ok=True) os.makedirs("examples/persons", exist_ok=True) # Paths to example images (you'll need to add these files) sample_garments = [ "samples/garments/g1.jpg", "samples/garments/g2.jpg", ] sample_humans = [ "samples/humans/h1.jpg", "samples/humans/h2.jpg", ] def preprocess_image(img, target_size=None): """Preprocess image without resizing if target_size is None""" if img is None: return None # Convert numpy array to PIL Image if it's a numpy array if isinstance(img, np.ndarray): img = Image.fromarray(img.astype('uint8')) # Only resize if target_size is specified if target_size is not None: img = img.resize(target_size, Image.LANCZOS) return img def virtual_tryon(garment_img, person_img): # Convert images to base64 if person_img is None or garment_img is None: return None # Preprocess images without resizing human_pil = preprocess_image(person_img) garment_pil = preprocess_image(garment_img) human_buffer = io.BytesIO() garment_buffer = io.BytesIO() human_pil.save(human_buffer, format="JPEG") garment_pil.save(garment_buffer, format="JPEG") human_base64_image = base64.b64encode(human_buffer.getvalue()).decode('utf-8') garment_base64_image = base64.b64encode(garment_buffer.getvalue()).decode('utf-8') # Prepare data for API request data = { 'human_image_base64': human_base64_image, 'garment_image_base64': garment_base64_image, } # Make API request to start tryon process response = requests.post(TRYON_URL, headers=headers, data=json.dumps(data)) if response.status_code != 200: return None json_response = response.json() tryon_pk = json_response['tryon_pk'] # Poll for result time_elapsed = 0 while time_elapsed < 60: # Timeout after 60 seconds fetch_response = requests.post(FETCH_URL, headers=headers, data=json.dumps({ 'tryon_pk': tryon_pk, })) if fetch_response.status_code != 200: return None json_response = fetch_response.json() if json_response.get('message') != 'success': return None if json_response.get('status') == 'done': # Download the result image result_url = json_response['s3_url'] img_response = requests.get(result_url) if img_response.status_code == 200: return Image.open(io.BytesIO(img_response.content)) time.sleep(2) time_elapsed += 2 return None custom_css = """ body, .gradio-container { font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; background-color: #121212; color: white; } h1, h2, h3 { color: white !important; } .container { max-width: 1200px; margin: 0 auto; } .image-container img { object-fit: contain; max-height: 450px; width: auto; margin: 0 auto; display: block; border-radius: 8px; } .examples-container { display: grid; grid-template-columns: repeat(auto-fill, minmax(150px, 1fr)); gap: 10px; margin-top: 10px; } .examples-container img { height: 120px; object-fit: cover; border-radius: 8px; cursor: pointer; transition: transform 0.2s; } .examples-container img:hover { transform: scale(1.05); } button#try-on-button { background-color: #FF6B00 !important; color: white !important; border: none !important; padding: 12px 20px !important; font-weight: 600 !important; border-radius: 8px !important; cursor: pointer !important; transition: background-color 0.3s !important; } button#try-on-button:hover { background-color: #FF8C33 !important; } footer {visibility: hidden} """ # Create Gradio interface with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo: gr.HTML("