Spaces:
fantaxy
/
Running on Zero

flx-pulid / app.py
fantaxy's picture
Update app.py
bfb6f49 verified
import spaces
import time
import os
# Try the ONNX Runtime CUDA provider (harmless even if it has no effect).
os.environ.setdefault("INSIGHTFACE_ONNX_PROVIDERS", "CUDAExecutionProvider,CPUExecutionProvider")
os.environ.setdefault("ORT_LOG_severity_level", "3")  # keep ORT logging to a minimum
import gradio as gr
import torch
from einops import rearrange
from PIL import Image
import numpy as np
from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import load_ae, load_clip, load_flow_model, load_t5
from pulid.pipeline_flux import PuLIDPipeline
from pulid.utils import resize_numpy_image_long
# NOTE(review): not referenced anywhere in the visible part of this file —
# presumably left over from an NSFW filter; confirm before relying on it.
NSFW_THRESHOLD = 0.85
def get_models(name: str, device: torch.device, offload: bool):
    """Load the four networks used for generation.

    The T5 and CLIP text encoders are always loaded onto ``device``. The flow
    model and the autoencoder are loaded onto the CPU instead when ``offload``
    is true. The flow model is switched to eval mode before being returned.

    Returns:
        A ``(model, ae, t5, clip)`` tuple.
    """
    # Where the flow model and autoencoder go depends on the offload flag.
    weights_device = "cpu" if offload else device

    text_encoder_t5 = load_t5(device, max_length=128)
    text_encoder_clip = load_clip(device)

    flow_model = load_flow_model(name, device=weights_device)
    flow_model.eval()

    autoencoder = load_ae(name, device=weights_device)
    return flow_model, autoencoder, text_encoder_t5, text_encoder_clip
class FluxGenerator:
    """Holds the loaded FLUX models and the PuLID pipeline built on top of them.

    Attributes set by ``__init__``: ``device``, ``offload``, ``model_name``,
    ``model``, ``ae``, ``t5``, ``clip`` and ``pulid_model``.
    """

    def __init__(self):
        # Fixed configuration: CUDA when available, no offloading, flux-dev weights.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.offload = False
        self.model_name = "flux-dev"

        loaded = get_models(self.model_name, device=self.device, offload=self.offload)
        self.model, self.ae, self.t5, self.clip = loaded

        # The PuLID pipeline receives the device as a plain string and uses
        # bfloat16 weights on CUDA, float32 otherwise.
        if torch.cuda.is_available():
            pulid_device, pulid_dtype = "cuda", torch.bfloat16
        else:
            pulid_device, pulid_dtype = "cpu", torch.float32
        self.pulid_model = PuLIDPipeline(self.model, pulid_device, weight_dtype=pulid_dtype)
        self.pulid_model.load_pretrain()
# Built at import time: constructing FluxGenerator loads every model, and
# generate_image() below reads this module-level instance.
flux_generator = FluxGenerator()
def _save_pil(img: Image.Image, prefix: str = "out") -> str:
os.makedirs("/tmp", exist_ok=True)
ts = int(time.time() * 1000)
path = f"/tmp/{prefix}_{ts}.png"
img.save(path, format="PNG")
return path
def _debug_item_to_pil(item) -> "Image.Image":
    """Convert one debug entry (PIL image, tensor-like or array-like) to an RGB PIL image."""
    if isinstance(item, Image.Image):
        return item.convert("RGB")

    # Objects exposing .detach() are treated as tensors; everything else goes through numpy.
    arr = item.detach().cpu().numpy() if hasattr(item, "detach") else np.array(item)
    if arr.ndim == 3 and arr.shape[0] in (1, 3):  # C,H,W -> H,W,C
        arr = np.transpose(arr, (1, 2, 0))
    if arr.ndim == 3 and arr.shape[-1] == 1:
        # Image.fromarray rejects (H, W, 1); drop the channel axis so
        # single-channel inputs are converted instead of failing.
        arr = arr[..., 0]
    if arr.dtype != np.uint8:
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr).convert("RGB")


def _save_debug_thumbnails(items, max_width: int = 512) -> list:
    """Save each debug entry as a PNG at most ``max_width`` wide; return the file paths.

    Conversion is best-effort: an entry that cannot be converted or saved is
    reported on stdout and skipped, so the debug gallery never fails a request.
    """
    paths = []
    for item in items or []:
        try:
            pil = _debug_item_to_pil(item)
            w, h = pil.size
            if w > max_width:
                # Keep the aspect ratio; never let the height collapse to 0.
                nh = max(1, int(h * (max_width / w)))
                pil = pil.resize((max_width, nh), Image.BICUBIC)
            paths.append(_save_pil(pil, "debug"))
        except Exception as exc:  # deliberate best-effort boundary
            print(f"Skipping debug image: {exc!r}")
    return paths


@spaces.GPU
@torch.inference_mode()
def generate_image(
    width,
    height,
    num_steps,
    start_step,
    guidance,
    seed,
    prompt,
    id_image=None,
    id_weight=1.0,
    neg_prompt="",
    true_cfg=1.0,
    timestep_to_start_cfg=1,
    max_sequence_length=128,
):
    """Generate one image with the module-level ``flux_generator`` and save it to disk.

    Args:
        width, height, num_steps, guidance, prompt: Stored in ``SamplingOptions``.
        start_step, id_weight, true_cfg, timestep_to_start_cfg: Forwarded to ``denoise``.
        seed: Integer or integer-like string. ``-1``, ``None`` or a blank string
            means "pick a random seed".
        id_image: Optional identity image; when given it is resized and turned
            into ID embeddings by the PuLID pipeline.
        neg_prompt: Negative prompt; only encoded when ``true_cfg`` differs from
            1.0 by more than 0.01.
        max_sequence_length: Written to ``flux_generator.t5.max_length``.

    Returns:
        ``(image_path, seed_used_as_str, debug_image_paths)``.

    Raises:
        ValueError: If ``seed`` is a non-blank value that ``int()`` cannot parse.
    """
    flux_generator.t5.max_length = max_sequence_length

    # A cleared textbox arrives as "" (or None); treat it like -1 instead of
    # crashing in int().
    if seed is None or (isinstance(seed, str) and not seed.strip()):
        seed = -1
    seed = int(seed)
    if seed == -1:
        seed = None

    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )
    if opts.seed is None:
        opts.seed = torch.Generator(device="cpu").seed()
    print(f"Generating '{opts.prompt}' with seed {opts.seed}")
    t0 = time.perf_counter()

    use_true_cfg = abs(true_cfg - 1.0) > 1e-2
    has_id_image = id_image is not None
    if has_id_image:
        id_image = resize_numpy_image_long(id_image, 1024)
        id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(
            id_image, cal_uncond=use_true_cfg
        )
    else:
        id_embeddings = None
        uncond_id_embeddings = None

    device = flux_generator.device
    compute_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

    # prepare input
    x = get_noise(
        1,
        opts.height,
        opts.width,
        device=device,
        dtype=compute_dtype,
        seed=opts.seed,
    )
    timesteps = get_schedule(
        opts.num_steps,
        x.shape[-1] * x.shape[-2] // 4,
        shift=True,
    )

    if flux_generator.offload:
        flux_generator.t5, flux_generator.clip = (
            flux_generator.t5.to(device),
            flux_generator.clip.to(device),
        )
    inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)
    inp_neg = (
        prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=neg_prompt)
        if use_true_cfg
        else None
    )

    # With offloading: text encoders go back to the CPU, the flow model goes to the device.
    if flux_generator.offload:
        flux_generator.t5, flux_generator.clip = flux_generator.t5.cpu(), flux_generator.clip.cpu()
        torch.cuda.empty_cache()
        flux_generator.model = flux_generator.model.to(device)

    x = denoise(
        flux_generator.model,
        **inp,
        timesteps=timesteps,
        guidance=opts.guidance,
        id=id_embeddings,
        id_weight=id_weight,
        start_step=start_step,
        uncond_id=uncond_id_embeddings,
        true_cfg=true_cfg,
        timestep_to_start_cfg=timestep_to_start_cfg,
        neg_txt=inp_neg["txt"] if use_true_cfg else None,
        neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
        neg_vec=inp_neg["vec"] if use_true_cfg else None,
    )

    # With offloading: flow model back to the CPU, decoder onto the latents' device.
    if flux_generator.offload:
        flux_generator.model.cpu()
        torch.cuda.empty_cache()
        flux_generator.ae.decoder.to(x.device)

    x = unpack(x.float(), opts.height, opts.width)
    with torch.autocast(device_type=device.type, dtype=compute_dtype):
        x = flux_generator.ae.decode(x)

    if flux_generator.offload:
        flux_generator.ae.decoder.cpu()
        torch.cuda.empty_cache()

    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.1f}s.")

    # tensor [-1,1] -> uint8 HWC
    x = x.clamp(-1, 1)
    x = rearrange(x[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()).convert("RGB")

    # Return the main image as a file path (avoids large base64 transfers).
    out_path = _save_pil(img, "flux")

    # Debug gallery: downscaled thumbnails saved to files. Within this function
    # only get_id_embedding() can have refreshed debug_img_list, so without an
    # ID image anything found there is left over from an earlier request and
    # must not be returned.
    debug_paths = []
    if has_id_image:
        debug_paths = _save_debug_thumbnails(flux_generator.pulid_model.debug_img_list)

    return out_path, str(opts.seed), debug_paths
def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
    """Build the Gradio Blocks UI and wire the Generate button to ``generate_image``.

    Args:
        args: Object with a ``dev`` attribute; it controls whether the
            "timestep to start cfg" slider and the debug gallery are visible.
        model_name: Not referenced in the body.
        device: Not referenced in the body.
        offload: Not referenced in the body.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # Global CSS that aggressively mitigates the top of the page being hidden.
    custom_css = """
:root{
/* ๊ธฐ๋ณธ HF ์ƒ๋‹จ ํˆด๋ฐ” ๋†’์ด ์ถ”์ •์น˜ (ํ™˜๊ฒฝ์— ๋”ฐ๋ผ 56~84px) */
--hf-header-offset: 72px;
--safe-top: env(safe-area-inset-top, 0px);
--top-offset: calc(var(--hf-header-offset) + var(--safe-top));
}
html, body, #root, .gradio-container{
margin: 0 !important;
padding-top: var(--top-offset) !important; /* ๊ณ ์ • ํ—ค๋”์— ๊ฐ€๋ฆฌ์ง€ ์•Š๋„๋ก ์ƒ๋‹จ ์—ฌ๋ฐฑ */
overflow: visible !important;
position: relative; /* ์Œ“์ž„ ๋งฅ๋ฝ ๋ณด์žฅ */
z-index: 0;
}
/* ๋‚ด๋ถ€ ์•ต์ปค/์ž๋™ ์Šคํฌ๋กค ์‹œ์—๋„ ํ—ค๋”์— ๊ฐ€๋ ค์ง€์ง€ ์•Š๋„๋ก */
:root { scroll-margin-top: var(--top-offset); scroll-padding-top: var(--top-offset); }
/* ์ƒ๋‹จ ๋ฐฐ์ง€ ์˜์—ญ์ด ๋‹ค๋ฅธ ์š”์†Œ ๋’ค๋กœ ๊น”๋ฆฌ์ง€ ์•Š๋„๋ก */
#top-badges { position: relative; z-index: 2; margin-top: 0 !important; }
/* ๋ชจ๋ฐ”์ผ์—์„œ ํ—ค๋”๊ฐ€ ๋” ๋†’๊ฒŒ ์žกํžˆ๋Š” ๊ฒฝ์šฐ ์—ฌ์œ ๋ฅผ ๋” ์ค€๋‹ค */
@media (max-width: 768px){
:root{ --hf-header-offset: 82px; }
.gradio-container { padding-top: calc(var(--top-offset) + 6px) !important; }
}
"""
    with gr.Blocks(theme="soft", css=custom_css) as demo:
        # Spacer reserving room at the very top (for browser/device-specific fixed top bars).
        gr.HTML("<div id='top-spacer' style='height: 0;'></div>")
        gr.HTML(
            """
<div id="top-badges" class='container' style='display:flex; justify-content:center; gap:12px; margin-top:0;'>
<a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
<img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
</a>
<a href="https://discord.gg/openfreeai" target="_blank">
<img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
</a>
</div>
"""
        )
        with gr.Row():
            # Left column: every input control plus the Generate button.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
                id_image = gr.Image(label="ID Image", type="numpy")
                id_weight = gr.Slider(0.0, 3.0, 1, step=0.05, label="id weight")
                width = gr.Slider(256, 1536, 896, step=16, label="Width")
                height = gr.Slider(256, 1536, 1152, step=16, label="Height")
                num_steps = gr.Slider(1, 20, 20, step=1, label="Number of steps")
                start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
                guidance = gr.Slider(1.0, 10.0, 4, step=0.1, label="Guidance")
                seed = gr.Textbox(-1, label="Seed (-1 for random)")
                max_sequence_length = gr.Slider(128, 512, 128, step=128, label="max_sequence_length for prompt (T5), small will be faster")
                with gr.Accordion(
                    "Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG, if using true CFG, we recommend set the guidance scale to 1)",
                    open=False,
                ):
                    neg_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="bad quality, worst quality, text, signature, watermark, extra limbs",
                    )
                    true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale")
                    timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)
                generate_btn = gr.Button("Generate")
            # Right column: generated image, seed actually used, optional debug gallery.
            with gr.Column():
                # Sent in file-path mode -> more stable rendering in the browser.
                output_image = gr.Image(label="Generated Image", type="filepath", show_download_button=True)
                seed_output = gr.Textbox(label="Used Seed")
                intermediate_output = gr.Gallery(
                    label="Output (dev only)",
                    elem_id="gallery",
                    visible=args.dev,
                    allow_preview=True,
                )
        with gr.Row(), gr.Column():
            gr.Markdown("## Examples")
            # Each example row is ordered as:
            # prompt, id_image, start_step, guidance, seed, true_cfg.
            example_inps = [
                [
                    'a woman holding sign with glowing green text "PuLID for FLUX"',
                    "example_inputs/qw1.webp",
                    4,
                    4,
                    2680261499100305976,
                    1,
                ],
                [
                    "portrait, pixar",
                    "example_inputs/qw2.webp",
                    1,
                    4,
                    9445036702517583939,
                    1,
                ],
            ]
            gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg], label="fake CFG")
            example_inps = [
                [
                    "portrait, made of ice sculpture",
                    "example_inputs/qw3.webp",
                    1,
                    1,
                    3811899118709451814,
                    5,
                ],
            ]
            gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg], label="true CFG")
        # The input order below must match generate_image's positional parameters.
        generate_btn.click(
            fn=generate_image,
            inputs=[
                width,
                height,
                num_steps,
                start_step,
                guidance,
                seed,
                prompt,
                id_image,
                id_weight,
                neg_prompt,
                true_cfg,
                timestep_to_start_cfg,
                max_sequence_length,
            ],
            outputs=[output_image, seed_output, intermediate_output],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: parse CLI flags, optionally log in to the Hugging Face
    # Hub, then build and launch the Gradio demo.
    import argparse

    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
    cli.add_argument(
        "--name",
        type=str,
        default="flux-dev",
        choices=["flux-dev"],
        help="currently only support flux-dev",
    )
    cli.add_argument("--device", type=str, default=fallback_device, help="Device to use")
    cli.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    # NOTE: --port is parsed but not passed to launch() below.
    cli.add_argument("--port", type=int, default=8080, help="Port to use")
    cli.add_argument("--dev", action="store_true", help="Development mode")
    cli.add_argument("--pretrained_model", type=str, help="for development")
    args = cli.parse_args()

    import huggingface_hub

    # Log in only when a non-empty token is present in the environment.
    token = os.getenv("HF_TOKEN")
    if token:
        huggingface_hub.login(token)

    demo = create_demo(args, args.name, args.device, args.offload)
    demo.launch(ssr_mode=False)