# Page residue from the Hugging Face file view, kept as a comment so the file parses:
# author: Wiuhh · commit cdda156 "Update app.py" (verified)
import os
import sys
import tempfile
import cv2
import torch
import gradio as gr
from torchvision.transforms import functional
# --- PATCH FOR COMPATIBILITY ---
# Register `torchvision.transforms.functional` under the legacy module name
# `torchvision.transforms.functional_tensor`, so any later import of that
# name resolves to `functional`.
# NOTE(review): presumably needed because a library imported further down
# (basicsr) still imports the legacy module, which newer torchvision releases
# no longer ship — confirm against the pinned torchvision version.
sys.modules.update({"torchvision.transforms.functional_tensor": functional})
# --- EMBEDDED CSS FOR STYLING ---
# Custom stylesheet handed to gr.Blocks(css=...) in the UI layout below.
# - :root declares the colour palette as CSS custom properties and re-points
#   the --input-* variables (presumably Gradio theme variables — confirm) at it.
# - #main-title, #main-subtitle and #submit-button target the ids used by the
#   title, subtitle and "Enhance Image" button in the layout.
# NOTE(review): `.gr-image` / `.gr-radio` look like Gradio 3.x class names;
# confirm they still match the installed Gradio version.
CSS_STYLING = """
:root {
--primary: hsl(265, 100%, 61%); /* Accent Purple */
--secondary: hsl(327, 100%, 72%); /* Accent Pink */
--blue: hsl(204, 100%, 72%); /* Accent Blue */
--background-darker: hsl(240, 14%, 3%);
--background-dark: hsl(240, 14%, 5%);
--card-background: hsl(240, 10%, 7%);
--light-text: hsl(240, 5%, 90%);
--muted-text: hsl(240, 4%, 65%);
--error-text: hsl(0, 100%, 74%);
--card-border: hsl(253, 100%, 72%, 0.15);
--input-background-fill: var(--card-background) !important;
--input-border-color: var(--card-border) !important;
--input-label-color: var(--light-text) !important;
}
.gradio-container {
background: var(--background-dark);
font-family: 'Inter', sans-serif;
}
#main-title {
color: var(--light-text);
text-align: center;
font-size: 2.5rem !important;
font-weight: 900;
}
#main-subtitle {
color: var(--muted-text);
text-align: center;
font-size: 1rem !important;
margin-top: -15px;
margin-bottom: 20px;
}
#submit-button {
background: linear-gradient(135deg, var(--primary), var(--secondary));
color: white;
font-weight: bold;
border-radius: 8px !important;
transition: all 0.3s ease;
}
#submit-button:hover {
box-shadow: 0px 4px 15px rgba(124, 58, 237, 0.4); /* Subtle purple shadow */
transform: translateY(-2px);
}
.gr-image {
border: 1px solid var(--card-border) !important;
border-radius: 12px !important;
min-height: 300px;
}
input[type="range"]::-webkit-slider-thumb {
background: var(--primary) !important;
}
input[type="range"]::-moz-range-thumb {
background: var(--primary) !important;
}
.gr-radio > div {
color: var(--light-text) !important;
}
"""
# --- DOWNLOAD HELPER FUNCTIONS ---
def download_file(url, dir_path, file_name, timeout=60):
    """Download ``url`` to ``dir_path/file_name`` unless that file already exists.

    The payload is streamed into a temporary file inside ``dir_path`` and only
    moved to its final name once the transfer has completed. A failed or
    interrupted download therefore never leaves a truncated file behind that
    later runs would mistake for a finished one and skip.

    Args:
        url: Source URL (http(s) or file).
        dir_path: Destination directory; created if missing.
        file_name: Name of the file inside ``dir_path``.
        timeout: Socket timeout in seconds passed to ``urlopen``.

    Returns:
        The destination path. It is returned even when the download fails
        (the error is printed, not raised), so callers must not assume the
        file exists.
    """
    import shutil
    import urllib.request

    os.makedirs(dir_path, exist_ok=True)
    file_path = os.path.join(dir_path, file_name)
    if os.path.exists(file_path):
        return file_path

    print(f"Downloading {file_name}...")
    fd, tmp_path = tempfile.mkstemp(dir=dir_path, prefix=f"{file_name}.", suffix=".part")
    try:
        # No shell involved: the URL is never interpreted by /bin/sh, and
        # failures surface as exceptions instead of an ignored exit status.
        with os.fdopen(fd, "wb") as out, urllib.request.urlopen(url, timeout=timeout) as response:
            shutil.copyfileobj(response, out)
        # Atomic on the same filesystem: the final name only ever holds a
        # complete file.
        os.replace(tmp_path, file_path)
        print("Download complete.")
    except Exception as e:
        # Best effort, as before: report and let the app continue starting up.
        print(f"Error downloading {file_name}: {e}")
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return file_path
# --- DOWNLOAD MODELS AND EXAMPLES ---
# Everything here is fetched at import time; download_file() leaves files that
# already exist on disk untouched.
print("Checking for required files...")
models_dir = 'models'
# (release URL, local file name) for each checkpoint read from models_dir later.
for _url, _file_name in (
    ('https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth', 'realesr-general-x4v3.pth'),
    ('https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', 'GFPGANv1.4.pth'),
    ('https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth', 'RestoreFormer.pth'),
):
    download_file(_url, models_dir, _file_name)
del _url, _file_name  # keep the loop variables out of the module namespace

examples_dir = 'examples'
# Sample inputs offered through gr.Examples in the UI.
example1_path, example2_path = [
    download_file(url, examples_dir, name)
    for url, name in (
        ('https://raw.githubusercontent.com/TencentARC/GFPGAN/master/inputs/whole_imgs/10045.png', 'example1.png'),
        ('https://raw.githubusercontent.com/TencentARC/GFPGAN/master/inputs/whole_imgs/Blake_Lively.jpg', 'example2.jpg'),
    )
]
# --- LOAD MODELS INTO MEMORY ---
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
# Real-ESRGAN background upsampler, built once and shared by every request.
# It stays None when loading fails; upscale_image() checks for that.
bg_upsampler = None
try:
    model_path = os.path.join(models_dir, 'realesr-general-x4v3.pth')
    # Network whose weights are the realesr-general-x4v3 checkpoint above.
    model = SRVGGNetCompact(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_conv=32,
        upscale=4,
        act_type='prelu',
    )
    # Request half precision only when a CUDA device is available.
    half = torch.cuda.is_available()
    bg_upsampler = RealESRGANer(
        scale=4,
        model_path=model_path,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=half,
    )
except Exception as e:
    print(f"Error loading background upsampler: {e}. The app may not work correctly.")
else:
    print("Background Upsampler (Real-ESRGAN) loaded for 4x enhancement.")
# --- CORE IMAGE PROCESSING FUNCTION ---
# Final enlargement applied to every image.
# NOTE(review): GFPGANer's `upscale` sets the output scale (it is forwarded to
# the background upsampler as `outscale`, and faces are pasted back at that
# size). The previous value of 2 therefore produced 2x output despite the 4x
# Real-ESRGAN model and the "fixed 4x" promise in the UI. Confirm against the
# installed gfpgan version.
_UPSCALE_FACTOR = 4

# Face-restoration checkpoints offered in the UI, mapped to the GFPGANer `arch`
# each one needs. This also acts as an allow-list: `version` comes from the
# client and is spliced into a filesystem path, so nothing else is accepted.
_FACE_MODEL_ARCHS = {'GFPGANv1.4': 'clean', 'RestoreFormer': 'RestoreFormer'}

# GFPGANer instances built so far, keyed by version, so weights load only once.
_face_enhancers = {}


def _get_face_enhancer(version):
    """Return the cached GFPGANer for ``version``, building it on first use."""
    enhancer = _face_enhancers.get(version)
    if enhancer is None:
        enhancer = GFPGANer(
            model_path=os.path.join(models_dir, f'{version}.pth'),
            upscale=_UPSCALE_FACTOR,
            arch=_FACE_MODEL_ARCHS[version],
            channel_multiplier=2,
            bg_upsampler=bg_upsampler,  # Real-ESRGAN handles the non-face background
        )
        _face_enhancers[version] = enhancer
    return enhancer


def _to_display_rgb(image):
    """Convert an OpenCV-ordered array (gray, BGR or BGRA) to RGB/RGBA for display."""
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        # Keep transparency instead of silently dropping it.
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def upscale_image(img_path, version):
    """Enhance an image using GFPGAN and Real-ESRGAN with a fixed 4x upscale.

    Args:
        img_path: Path of the uploaded image (Gradio ``type="filepath"``), or a
            falsy value when nothing has been uploaded.
        version: Face-restoration model name; must be a key of
            ``_FACE_MODEL_ARCHS``.

    Returns:
        ``(preview, file_path)``: the enhanced image as an RGB/RGBA array for
        display, and the path of a temporary copy for download. The copy is
        PNG when the input has an alpha channel, JPEG otherwise.

    Raises:
        gr.Error: on missing input, unknown model, unavailable background
            upsampler, or any failure while processing.
    """
    if not img_path:
        raise gr.Error("Please upload an image.")
    if not bg_upsampler:
        raise gr.Error("Background upsampler not loaded. Cannot proceed.")
    if version not in _FACE_MODEL_ARCHS:
        raise gr.Error(f"Unknown AI model: {version}")
    try:
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise RuntimeError("Failed to read image.")
        # Grayscale files load as 2-D arrays, so check ndim before shape[2].
        has_alpha = img.ndim == 3 and img.shape[2] == 4
        face_enhancer = _get_face_enhancer(version)
        _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        if output is None:
            raise RuntimeError("Face enhancer returned no image.")
        output_rgb = _to_display_rgb(output)
        ext = 'png' if has_alpha else 'jpg'
        # Create and close the temp file first: an open NamedTemporaryFile
        # cannot be reopened by name on Windows, which cv2.imwrite needs.
        fd, out_path = tempfile.mkstemp(suffix=f'.{ext}')
        os.close(fd)
        # `output` is still in OpenCV channel order, so it can be written as-is.
        if not cv2.imwrite(out_path, output):
            raise RuntimeError("Failed to write output image.")
        return output_rgb, out_path
    except Exception as error:
        print(f"Error processing image: {error}")
        raise gr.Error(f"An error occurred: {error}") from error
# --- GRADIO UI LAYOUT ---
# Two columns: upload, model choice and examples on the left; result preview
# and downloadable file on the right.
with gr.Blocks(css=CSS_STYLING, theme=gr.themes.Base()) as demo:
    # CSS_STYLING targets #main-title / #main-subtitle through the ids on the
    # inline HTML. They are intentionally not repeated as elem_id on the
    # Markdown wrapper: that created two DOM elements with the same id
    # (invalid HTML) and made the rules match both the wrapper and the inner tag.
    gr.Markdown("<h1 id='main-title'>NeuraVision AI Image Upscaler</h1>")
    gr.Markdown("<p id='main-subtitle'>Enhance old, blurry, and low-resolution photos with AI (Fixed 4x Upscale).</p>")
    with gr.Row(variant="panel"):
        # LEFT COLUMN (INPUT & CONTROLS)
        with gr.Column(scale=1):
            input_image = gr.Image(type="filepath", label="Upload Image")
            version = gr.Radio(
                ['GFPGANv1.4', 'RestoreFormer'], value='GFPGANv1.4',
                label='AI Model', info="v1.4 for general use. RestoreFormer for old photos."
            )
            submit_btn = gr.Button("Enhance Image", variant="primary", elem_id="submit-button")
            gr.Examples(
                examples=[[example1_path, "RestoreFormer"], [example2_path, "GFPGANv1.4"]],
                inputs=[input_image, version],
                label="Click an example to start"
            )
        # RIGHT COLUMN (OUTPUT)
        with gr.Column(scale=1):
            output_image = gr.Image(type="numpy", label="Enhanced Result", interactive=False)
            download_button = gr.File(label="Download Image", interactive=False)
    # --- BUTTON & EVENT HANDLING ---
    submit_btn.click(
        fn=upscale_image,
        inputs=[input_image, version],
        outputs=[output_image, download_button]
    )
    # Clearing the upload also clears the previous result and its download.
    input_image.clear(lambda: (None, None), None, [output_image, download_button])
if __name__ == "__main__":
    # Enable the request queue, then serve the app; queue() hands back the same
    # Blocks object, so the two calls chain.
    demo.queue().launch(share=True)