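"""Mirro Virtual Try-On: a Gradio ZeroGPU Space.

Forwards a person photo and a garment photo to a remote IDM-VTON API Space
through gradio_client and returns the generated try-on image.
"""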
import gradio as gr
import spaces
from PIL import Image
from gradio_client import Client, handle_file
import logging
import tempfile
import os

# ==========================================================
# Setup
# ==========================================================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

vton_client = None


# ==========================================================
# Init Remote Try-On Model (API Endpoint)
# ==========================================================
def init_client():
    global vton_client
    try:
        logger.info("🔗 Connecting to IDM-VTON API...")
        # ✅ Use the maintained API Space (not the model repo); gradio_client
        # expects either a full URL or a "user/space" id, so keep the https:// scheme.
        vton_client = Client("https://yisol-idm-vton-api.hf.space")
        logger.info("✅ Connected to yisol-idm-vton-api.hf.space")
        return True
    except Exception as e:
        logger.error(f"❌ Connection failed: {e}")
        vton_client = None
        return False


model_ready = init_client()


# ==========================================================
# GPU Accelerated Try-On Function
# ==========================================================
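# Note: @spaces.GPU requests a ZeroGPU allocation for each call; `duration`
# is the expected upper bound in seconds used for scheduling.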
@spaces.GPU(duration=180)
def virtual_tryon(person_image, garment_image, progress=gr.Progress()):
    global vton_client

    try:
        if not person_image or not garment_image:
            raise gr.Error("Please upload both images!")

        if not vton_client:
            if not init_client():
                raise gr.Error("⚠️ Model API unavailable. Try again later.")

        logger.info("🎯 Starting try-on pipeline...")
        progress(0.2, desc="📸 Preparing images...")

        # Save temporary input files
        person_tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
        garment_tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')

        person_image.convert('RGB').save(person_tmp.name, 'JPEG', quality=95)
        garment_image.convert('RGB').save(garment_tmp.name, 'JPEG', quality=95)

        progress(0.5, desc="🪄 Running AI model (60-90s)...")

        # Call remote API Space
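        # Positional args below are assumed to mirror IDM-VTON's try-on endpoint:
        # person image, garment image, garment description, auto-mask, auto-crop,
        # denoising steps, random seed.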
        result = vton_client.predict(
            handle_file(person_tmp.name),
            handle_file(garment_tmp.name),
            "universal garment",   # πŸ‘— supports saree, kurta, dress, etc.
            True,
            True,
            30,
            42,
            api_name="/tryon"
        )

        progress(0.9, desc="🎨 Generating result...")

        result_path = result[0] if isinstance(result, (tuple, list)) else result
        result_image = Image.open(result_path)

        # Cleanup temp files
        for tmp in [person_tmp.name, garment_tmp.name]:
            try:
                os.unlink(tmp)
            except OSError:
                pass

        progress(1.0, desc="✅ Done!")
        logger.info("✅ Try-on complete!")
        return result_image

    except gr.Error:
        # Re-raise user-facing errors (e.g. missing images) without re-wrapping
        raise
    except Exception as e:
        if "ZeroGPU" in str(e):
            raise gr.Error(
                "⚠️ GPU quota exceeded. Log in to Hugging Face or retry later."
            )
        logger.error(f"❌ Error: {e}")
        raise gr.Error(f"Failed: {e}")


# ==========================================================
# Gradio Interface
# ==========================================================
with gr.Blocks(title="Mirro Virtual Try-On", theme=gr.themes.Soft()) as demo:

    gr.Markdown("# πŸ‘— Mirro Virtual Try-On\n### πŸ†“ AI-Powered Outfit Fitting (All Garments)")

    if model_ready:
        gr.Markdown("🟢 **Model ready!** Upload your photos below.")
    else:
        gr.Markdown("🟡 **Connecting...** Try again in a moment.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### πŸ“Έ Upload Images")
            person_input = gr.Image(label="πŸ‘€ Person Photo", type="pil", sources=["upload", "webcam"])
            garment_input = gr.Image(label="πŸ‘” Garment Photo", type="pil", sources=["upload", "webcam"])
            btn = gr.Button("✨ Generate Try-On", variant="primary")

        with gr.Column():
            gr.Markdown("### 🎯 Result")
            output_image = gr.Image(label="Virtual Try-On Result", type="pil")

    with gr.Accordion("💡 Tips for Best Results", open=False):
        gr.Markdown("""
        - Clear, front-facing full-body person photo  
        - Plain background recommended  
        - Any garment type: saree, kurta, dress, jeans, jacket, etc.  
        - Processing takes ~60-90 seconds on ZeroGPU  
        """)

    btn.click(
        fn=virtual_tryon,
        inputs=[person_input, garment_input],
        outputs=output_image,
        api_name="predict"
    )

    gr.Markdown("""
    ---
    <div style="text-align:center; color:#666; padding:15px;">
        <p>🆓 Powered by Hugging Face ZeroGPU + IDM-VTON API</p>
        <p style="font-size:0.9em;">Supports all garment categories; production-ready for iOS integration</p>
    </div>
    """)

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(server_name="0.0.0.0", server_port=7860)
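
# ==========================================================
# Example client call (sketch)
# ==========================================================
# Because btn.click() sets api_name="predict", other clients (e.g. an iOS app
# backend) can call this Space over the Gradio API. A minimal sketch with
# gradio_client, assuming the Space URL below is a placeholder:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("https://<your-space>.hf.space")  # placeholder URL
#   result = client.predict(
#       handle_file("person.jpg"),
#       handle_file("garment.jpg"),
#       api_name="/predict",
#   )
#   print(result)  # local path to the generated try-on image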