Avanish11 commited on
Commit
9baa8f4
·
verified ·
1 Parent(s): 6cbbaaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -87
app.py CHANGED
@@ -1,114 +1,68 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline
4
  from safetensors.torch import load_file
5
  from PIL import Image
6
- import os
7
 
8
- # Your LoRA files
9
- LORA_FILES = [
10
- "gh1bli-style.safetensors",
11
- "ghibli_landscape_lora.safetensors",
12
- ]
13
 
14
- # Function to check LoRA type
15
- def detect_lora_type(lora_path):
16
- try:
17
- keys = load_file(lora_path).keys()
18
- # SDXL LoRAs have transformer_blocks, SD1.5 ones don’t
19
- if any("transformer_blocks" in k for k in keys):
20
- return "SDXL"
21
- return "SD1.5"
22
- except Exception as e:
23
- print(f"⚠️ Could not read {lora_path}: {e}")
24
- return "UNKNOWN"
25
-
26
- # Detect which model type to use
27
- detected_type = None
28
- for lora in LORA_FILES:
29
- if os.path.exists(lora):
30
- t = detect_lora_type(lora)
31
- print(f"🔍 Detected {lora} → {t}")
32
- if t != "UNKNOWN":
33
- detected_type = t
34
- break
35
-
36
- # Fallback if nothing detected
37
- if detected_type is None:
38
- detected_type = "SD1.5"
39
-
40
- # Choose model accordingly
41
- if detected_type == "SDXL":
42
- BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
43
- print("✅ Using SDXL base model")
44
- else:
45
- BASE_MODEL = "runwayml/stable-diffusion-v1-5"
46
- print("✅ Using SD1.5 base model")
47
-
48
- # Load base model
49
  device = "cuda" if torch.cuda.is_available() else "cpu"
50
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
51
 
52
- pipe_txt2img = DiffusionPipeline.from_pretrained(
 
53
  BASE_MODEL,
54
  torch_dtype=dtype,
55
  use_safetensors=True,
 
56
  ).to(device)
57
 
58
- # Apply LoRAs safely
59
- for lora in LORA_FILES:
60
- if os.path.exists(lora):
61
- try:
62
- print(f"🎨 Loading LoRA: {lora}")
63
- pipe_txt2img.load_lora_weights(lora)
64
- except Exception as e:
65
- print(f"⚠️ Skipped {lora}: {e}")
66
-
67
- # Image-to-Image (if supported)
68
- if detected_type == "SDXL":
69
- from diffusers import StableDiffusionXLImg2ImgPipeline
70
- pipe_img2img = StableDiffusionXLImg2ImgPipeline(**pipe_txt2img.components)
71
- else:
72
- pipe_img2img = StableDiffusionImg2ImgPipeline(**pipe_txt2img.components)
73
 
74
- # Generate function
75
- def generate(prompt, steps=30, guidance=7.5, seed=42, strength=0.6, image=None):
76
- generator = torch.Generator(device=device).manual_seed(int(seed))
77
-
78
- if image is not None:
79
- init_image = Image.open(image).convert("RGB").resize((768, 768))
80
- result = pipe_img2img(
81
- prompt=prompt,
82
- image=init_image,
83
- strength=float(strength),
84
- num_inference_steps=int(steps),
85
- guidance_scale=float(guidance),
86
- generator=generator,
87
- ).images[0]
88
- else:
89
- result = pipe_txt2img(
90
- prompt=prompt,
91
- num_inference_steps=int(steps),
92
- guidance_scale=float(guidance),
93
- generator=generator,
94
- ).images[0]
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  return result
97
 
98
- # Gradio UI
99
  demo = gr.Interface(
100
- fn=generate,
101
  inputs=[
102
- gr.Textbox(label="Prompt", placeholder="A Ghibli-style mountain village at sunset"),
103
  gr.Slider(10, 50, 30, step=1, label="Inference Steps"),
104
  gr.Slider(1, 15, 7.5, step=0.5, label="Guidance Scale"),
 
105
  gr.Number(label="Seed", value=42),
106
- gr.Slider(0.1, 1.0, 0.6, step=0.1, label="Strength (for image input)"),
107
- gr.Image(label="Upload Image (optional)", type="filepath"),
108
  ],
109
- outputs=gr.Image(label="Generated Image"),
110
- title="Ghibli Style Maker – Auto Model Switch",
111
- description="Automatically detects if your LoRA is for SD1.5 or SDXL and generates Studio Ghibli–style art from text or image.",
112
  )
113
 
114
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusionXLImg2ImgPipeline
4
  from safetensors.torch import load_file
5
  from PIL import Image
 
6
 
7
+ # --- Base SDXL model ---
8
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
9
+ LORA_PATH = "studioghibli_flux_r32-v2.safetensors"
 
 
10
 
11
+ # --- Setup device & dtype ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
14
 
15
+ print("🔹 Loading SDXL base model...")
16
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
17
  BASE_MODEL,
18
  torch_dtype=dtype,
19
  use_safetensors=True,
20
+ variant="fp16" if torch.cuda.is_available() else None,
21
  ).to(device)
22
 
23
+ # --- Apply LoRA weights ---
24
+ print("🎨 Applying Ghibli-style LoRA...")
25
+ try:
26
+ lora_weights = load_file(LORA_PATH)
27
+ pipe.unet.load_state_dict(lora_weights, strict=False)
28
+ print("✅ LoRA loaded successfully.")
29
+ except Exception as e:
30
+ print(f"⚠️ Failed to load LoRA: {e}")
 
 
 
 
 
 
 
31
 
32
+ # --- Ghibli-style conversion ---
33
+ def ghibli_style(image, steps=30, guidance=7.5, strength=0.6, seed=42):
34
+ if image is None:
35
+ raise gr.Error("Please upload an image to convert.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ generator = torch.Generator(device=device).manual_seed(int(seed))
38
+ init_image = Image.open(image).convert("RGB").resize((1024, 1024))
39
+
40
+ prompt = "Ghibli-style art, soft lighting, painterly textures, cinematic color palette"
41
+
42
+ result = pipe(
43
+ prompt=prompt,
44
+ image=init_image,
45
+ strength=float(strength),
46
+ num_inference_steps=int(steps),
47
+ guidance_scale=float(guidance),
48
+ generator=generator,
49
+ ).images[0]
50
+
51
  return result
52
 
53
+ # --- Gradio Interface ---
54
  demo = gr.Interface(
55
+ fn=ghibli_style,
56
  inputs=[
57
+ gr.Image(label="Upload Image", type="filepath"),
58
  gr.Slider(10, 50, 30, step=1, label="Inference Steps"),
59
  gr.Slider(1, 15, 7.5, step=0.5, label="Guidance Scale"),
60
+ gr.Slider(0.1, 1.0, 0.6, step=0.1, label="Style Strength"),
61
  gr.Number(label="Seed", value=42),
 
 
62
  ],
63
+ outputs=gr.Image(label="Ghibli Style Output"),
64
+ title="Ghibli Style Image Converter",
65
+ description="Upload any image and transform it into a Studio Ghibli-style artwork using the Flux LoRA and SDXL model.",
66
  )
67
 
68
  if __name__ == "__main__":