Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| import base64 | |
| import requests | |
| import time | |
| import os | |
| from dotenv import load_dotenv | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
# Load the API key and the try-on service base URL from a .env file and/or
# the process environment.
load_dotenv()
API_KEY = os.getenv('API_KEY')
CURRENT_URL = os.getenv('CURRENT_URL')

# Fail fast with an actionable message. Otherwise the string concatenations
# below raise an opaque TypeError when either setting is unset (None).
_missing_settings = [
    name
    for name, value in (('API_KEY', API_KEY), ('CURRENT_URL', CURRENT_URL))
    if not value
]
if _missing_settings:
    raise RuntimeError(
        "Missing required setting(s): " + ", ".join(_missing_settings)
        + ". Define them in the environment or in a .env file."
    )

# API endpoints. Accept a base URL given with or without a trailing slash;
# plain concatenation would otherwise produce e.g. "https://hostapi/tryon/".
_base_url = CURRENT_URL if CURRENT_URL.endswith('/') else CURRENT_URL + '/'
TRYON_URL = _base_url + 'api/tryon/'
FETCH_URL = _base_url + 'api/tryon_state/'

# Headers for API requests: bearer-token auth and JSON request bodies.
headers = {
    'Authorization': 'Bearer ' + API_KEY,
    'Content-Type': 'application/json',
}
# Paths to the example images offered under each upload widget.
# The image files themselves must be added to the repository.
sample_garments = [
    "samples/garments/g1.jpg",
    "samples/garments/g2.jpg",
]
sample_humans = [
    "samples/humans/h1.jpg",
    "samples/humans/h2.jpg",
]

# Create the example directories if they don't exist. They are derived from
# the paths above so the created directories always match where the example
# images are actually read from.
for _sample_dir in sorted({os.path.dirname(p) for p in sample_garments + sample_humans}):
    os.makedirs(_sample_dir, exist_ok=True)
def preprocess_image(img, target_size=None):
    """Normalise an input image to a PIL image, optionally resizing it.

    Args:
        img: a PIL image, a numpy array (cast to uint8 before conversion),
            or None.
        target_size: optional (width, height) tuple. When given, the image
            is resized with Lanczos resampling; when None, its size is kept.

    Returns:
        The converted and/or resized image, or None when ``img`` is None.
    """
    if img is None:
        return None

    pil_img = Image.fromarray(img.astype('uint8')) if isinstance(img, np.ndarray) else img

    if target_size is None:
        return pil_img
    return pil_img.resize(target_size, Image.LANCZOS)
def _encode_jpeg_base64(img):
    """Return ``img`` encoded as JPEG, then base64, as a UTF-8 string.

    JPEG cannot store alpha or palette data, so non-RGB images (e.g. RGBA
    PNG uploads) are converted to RGB first; saving them directly raises
    OSError.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def _post_json(url, payload, request_timeout):
    """POST ``payload`` as a JSON body to ``url`` with the shared auth headers."""
    return requests.post(url, headers=headers, data=json.dumps(payload), timeout=request_timeout)


def virtual_tryon(garment_img, person_img, timeout=60, poll_interval=2, request_timeout=30):
    """Run a virtual try-on through the remote API and return the result image.

    Submits both images to TRYON_URL, then polls FETCH_URL for the job
    identified by the returned ``tryon_pk`` until its status is ``'done'``.
    Finally it downloads the image at the returned ``s3_url``.

    Args:
        garment_img: garment image (PIL image or numpy array).
        person_img: person image (PIL image or numpy array).
        timeout: wall-clock budget in seconds for polling, measured
            including request time (not just sleeps).
        poll_interval: seconds to wait between status polls.
        request_timeout: per-HTTP-request timeout in seconds, so a stalled
            connection cannot hang the UI indefinitely.

    Returns:
        The result as a PIL image. Returns None if either input is missing,
        the API answers with a non-200 status or a non-'success' message, or
        the job does not finish within ``timeout``.

    Raises:
        requests.RequestException: on network failure or request timeout.
        KeyError: if an API response lacks ``tryon_pk`` / ``s3_url``.
    """
    if person_img is None or garment_img is None:
        return None

    # Encode both images for the JSON payload (no resizing).
    data = {
        'human_image_base64': _encode_jpeg_base64(preprocess_image(person_img)),
        'garment_image_base64': _encode_jpeg_base64(preprocess_image(garment_img)),
    }

    # Start the try-on job.
    response = _post_json(TRYON_URL, data, request_timeout)
    if response.status_code != 200:
        return None
    tryon_pk = response.json()['tryon_pk']

    # Poll for the result until the deadline.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        fetch_response = _post_json(FETCH_URL, {'tryon_pk': tryon_pk}, request_timeout)
        if fetch_response.status_code != 200:
            return None
        state = fetch_response.json()
        if state.get('message') != 'success':
            return None
        if state.get('status') == 'done':
            img_response = requests.get(state['s3_url'], timeout=request_timeout)
            if img_response.status_code == 200:
                return Image.open(io.BytesIO(img_response.content))
            # Download failed: fall through and retry on the next poll.
        time.sleep(poll_interval)
    return None
# Custom CSS passed to gr.Blocks: dark background with white text, bounded
# and rounded image previews, an orange "Try On" button, and a hidden footer.
# Selectors that match components defined in this file:
#   .image-container img -> the garment and person upload images
#   button#try-on-button -> the "Try On" button
# NOTE(review): no component here is given the "container" or
# "examples-container" class, and the output image's "result-image" class
# has no rule. These rules may be dead or may target Gradio's own markup;
# confirm.
custom_css = """
body, .gradio-container {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
background-color: #121212;
color: white;
}
h1, h2, h3 {
color: white !important;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
.image-container img {
object-fit: contain;
max-height: 450px;
width: auto;
margin: 0 auto;
display: block;
border-radius: 8px;
}
.examples-container {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(150px, 1fr));
gap: 10px;
margin-top: 10px;
}
.examples-container img {
height: 120px;
object-fit: cover;
border-radius: 8px;
cursor: pointer;
transition: transform 0.2s;
}
.examples-container img:hover {
transform: scale(1.05);
}
button#try-on-button {
background-color: #FF6B00 !important;
color: white !important;
border: none !important;
padding: 12px 20px !important;
font-weight: 600 !important;
border-radius: 8px !important;
cursor: pointer !important;
transition: background-color 0.3s !important;
}
button#try-on-button:hover {
background-color: #FF8C33 !important;
}
footer {visibility: hidden}
"""
def validate_inputs(garment_img, person_img, garment_type=None,
                    sleeve_length=None, garment_length=None):
    """Check that both images are present, then run the virtual try-on.

    Only the garment and person images are wired to the button, so the
    remaining parameters are optional. They are accepted (and ignored) so
    callers that pass them keep working. When they were required, every
    button click failed with a TypeError.

    Returns:
        The try-on result image.

    Raises:
        gr.Error: if an image is missing, the try-on raises, or no result
            is returned (API error or timeout).
    """
    if garment_img is None:
        raise gr.Error("Please upload a garment image")
    if person_img is None:
        raise gr.Error("Please upload a person image")
    try:
        result = virtual_tryon(garment_img, person_img)
    except Exception as e:
        # Surface unexpected failures (network errors, malformed API
        # responses) in the UI, keeping the original cause chained.
        raise gr.Error(f"Error: {str(e)}") from e
    if result is None:
        # virtual_tryon reports API errors and timeouts by returning None;
        # without this check the user would only see an empty result.
        raise gr.Error("Try-on failed or timed out. Please try again.")
    return result


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
    gr.HTML("<h1 style='text-align: center; margin-bottom: 20px;'>AlphaBake Virtual Try-On</h1>")
    with gr.Row():
        # First column - Garment
        with gr.Column(scale=1):
            gr.Markdown("### Garment Image")
            garment_input = gr.Image(
                label="Upload a garment image",
                type="pil",
                elem_id="garment-image",
                elem_classes=["image-container"],
                height=350,
            )
            # Example garment images
            gr.Examples(
                examples=sample_garments,
                inputs=garment_input,
                label="Garment Examples",
                examples_per_page=4,
            )
        # Second column - Person
        with gr.Column(scale=1):
            gr.Markdown("### Person Image")
            person_input = gr.Image(
                label="Upload a person image",
                type="pil",
                elem_id="person-image",
                elem_classes=["image-container"],
                height=350,
            )
            # Example person images
            gr.Examples(
                examples=sample_humans,
                inputs=person_input,
                label="Person Examples",
                examples_per_page=4,
            )
        # Third column - Try-on button & result
        with gr.Column(scale=1):
            try_on_button = gr.Button("Try On", elem_id="try-on-button", variant="primary", size="lg")
            output_image = gr.Image(
                label="Result",
                type="pil",
                elem_classes=["result-image"],
                height=400,
            )
    # Connect the button to validation + try-on.
    try_on_button.click(
        fn=validate_inputs,
        inputs=[garment_input, person_input],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()