# app.py (Hugging Face Space) — last change: "Update app.py" by gk2291, commit 0698e90 (verified)
import gradio as gr
import spaces
from PIL import Image
from gradio_client import Client, handle_file
import logging
import tempfile
import os
# ==========================================================
# Setup
# ==========================================================
# Root logging at INFO so the pipeline's progress messages are emitted;
# everything in this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Shared remote try-on client. None means "not connected": it is only set
# by a successful init_client() call and reset to None when that fails.
vton_client: "Client | None" = None
# ==========================================================
# Init Remote Try-On Model (API Endpoint)
# ==========================================================
# gradio_client.Client accepts either a Space id ("owner/space") or a full URL
# that includes the scheme; a bare "*.hf.space" hostname is neither, so the
# scheme is spelled out here.
# NOTE(review): confirm this Space actually exists and serves /tryon — the
# widely used public IDM-VTON Space is published as "yisol/IDM-VTON".
VTON_API_SRC = "https://yisol-idm-vton-api.hf.space"


def init_client() -> bool:
    """Connect to the remote IDM-VTON Space and store the client globally.

    On success the module-level ``vton_client`` holds a connected
    ``gradio_client.Client``; on any failure it is reset to ``None``.

    Returns:
        True if the connection succeeded, False otherwise.
    """
    global vton_client
    try:
        logger.info("πŸ”— Connecting to IDM-VTON API...")
        vton_client = Client(VTON_API_SRC)
    except Exception as e:
        # Deliberate best-effort: the app still starts, and virtual_tryon()
        # calls init_client() again when no client is available.
        logger.error("❌ Connection failed: %s", e)
        vton_client = None
        return False
    else:
        logger.info("βœ… Connected to yisol-idm-vton-api.hf.space")
        return True


# Connect once at import time; the UI uses this flag to show a status line.
model_ready = init_client()
# ==========================================================
# GPU Accelerated Try-On Function
# ==========================================================
# NOTE(review): the body only calls a remote Space over HTTP and does no local
# GPU work, so the ZeroGPU allocation requested here looks unnecessary. ZeroGPU
# Spaces are expected to expose at least one @spaces.GPU function — confirm
# before removing the decorator.
@spaces.GPU(duration=180)
def virtual_tryon(person_image, garment_image, progress=gr.Progress()):
    """Dress *person_image* in *garment_image* using the remote try-on API.

    Args:
        person_image: PIL image of the person, or None if nothing was uploaded.
        garment_image: PIL image of the garment, or None if nothing was uploaded.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The try-on result as a fully loaded ``PIL.Image.Image``.

    Raises:
        gr.Error: if an input is missing, the API cannot be reached, the GPU
            quota is exhausted, or the remote call fails for any other reason.
    """
    global vton_client

    # Validate before entering the try block so these user-facing messages
    # reach the UI as-is instead of being re-wrapped as "Failed: ..." below.
    if person_image is None or garment_image is None:
        raise gr.Error("Please upload both images!")
    if vton_client is None and not init_client():
        raise gr.Error("⚠️ Model API unavailable. Try again later.")

    try:
        logger.info("🎯 Starting try-on pipeline...")
        progress(0.2, desc="πŸ“Έ Preparing images...")
        # The directory and both JPEGs inside it are removed when the block
        # exits, on the success path and on the error path alike.
        with tempfile.TemporaryDirectory() as tmp_dir:
            person_path = os.path.join(tmp_dir, "person.jpg")
            garment_path = os.path.join(tmp_dir, "garment.jpg")
            person_image.convert('RGB').save(person_path, 'JPEG', quality=95)
            garment_image.convert('RGB').save(garment_path, 'JPEG', quality=95)

            progress(0.5, desc="πŸͺ„ Running AI model (60-90s)...")
            # NOTE(review): these positional arguments must match the remote
            # /tryon endpoint's parameter order; check with
            # vton_client.view_api(). The meaning of True, True, 30, 42 is not
            # visible from here (possibly mask/crop flags, steps, seed).
            result = vton_client.predict(
                handle_file(person_path),
                handle_file(garment_path),
                "universal garment",  # presumably the garment text description
                True,
                True,
                30,
                42,
                api_name="/tryon",
            )

        progress(0.9, desc="🎨 Generating result...")
        # A tuple/list result is treated as (image_path, ...).
        result_path = result[0] if isinstance(result, (tuple, list)) else result
        # copy() loads the pixel data, so the file opened here is closed by
        # the with-block instead of staying open until garbage collection.
        with Image.open(result_path) as opened:
            result_image = opened.copy()
    except Exception as e:
        # NOTE(review): assumes quota failures surface with "ZeroGPU" in the
        # message text — confirm against the actual error.
        if "ZeroGPU" in str(e):
            raise gr.Error(
                "⚠️ GPU quota exceeded. Log in to Hugging Face or retry later."
            ) from e
        logger.exception("❌ Error: %s", e)
        raise gr.Error(f"Failed: {str(e)}") from e

    progress(1.0, desc="βœ… Done!")
    logger.info("βœ… Try-on complete!")
    return result_image
# ==========================================================
# Gradio Interface
# ==========================================================
# Markdown for the collapsible tips panel.
_TIPS_MD = """
- Clear, front-facing full-body person photo
- Plain background recommended
- Any garment type: saree, kurta, dress, jeans, jacket, etc.
- Processing takes ~60-90 seconds on ZeroGPU
"""

# Markdown/HTML for the page footer.
_FOOTER_MD = """
---
<div style="text-align:center; color:#666; padding:15px;">
<p>πŸ†“ Powered by Hugging Face ZeroGPU + IDM-VTON API</p>
<p style="font-size:0.9em;">Supports all garment categories β€” production ready for iOS integration</p>
</div>
"""

# Layout, top to bottom: header, connection status, a two-column row
# (inputs | result), then the footer. In gr.Blocks the order in which
# components are created inside each container defines the layout.
# NOTE(review): the original indentation was lost, so the tips accordion's
# parent container (result column vs. page level) is reconstructed — confirm.
with gr.Blocks(title="Mirro Virtual Try-On", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸ‘— Mirro Virtual Try-On\n### πŸ†“ AI-Powered Outfit Fitting (All Garments)")
    # model_ready holds the outcome of the connection attempt made at import.
    gr.Markdown(
        "🟒 **Model ready!** Upload your photos below."
        if model_ready
        else "🟑 **Connecting...** Try again in a moment."
    )

    with gr.Row():
        with gr.Column():  # left column: inputs
            gr.Markdown("### πŸ“Έ Upload Images")
            person_input = gr.Image(
                label="πŸ‘€ Person Photo",
                type="pil",
                sources=["upload", "webcam"],
            )
            garment_input = gr.Image(
                label="πŸ‘” Garment Photo",
                type="pil",
                sources=["upload", "webcam"],
            )
            btn = gr.Button("✨ Generate Try-On", variant="primary")

        with gr.Column():  # right column: output
            gr.Markdown("### 🎯 Result")
            output_image = gr.Image(label="Virtual Try-On Result", type="pil")
            with gr.Accordion("πŸ’‘ Tips for Best Results", open=False):
                gr.Markdown(_TIPS_MD)

    # api_name also exposes this event to API clients under "predict".
    btn.click(
        fn=virtual_tryon,
        inputs=[person_input, garment_input],
        outputs=output_image,
        api_name="predict",
    )

    gr.Markdown(_FOOTER_MD)
if __name__ == "__main__":
    # Blocks.queue() returns the Blocks instance itself, so launch() chains.
    # max_size=20 caps how many events may wait in the queue at once; the
    # server binds all interfaces on port 7860 (presumably for the Spaces
    # container runtime).
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)