Spaces:
Running
on
T4
Running
on
T4
Commit
·
7f04f92
1
Parent(s):
9b92c1c
fastapi added
Browse files
app.py
CHANGED
|
@@ -20,9 +20,11 @@ from typing import List, Optional
|
|
| 20 |
# Writable directories (Spaces often allow /tmp). Override via env if needed.
|
| 21 |
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/tmp/face_upscale/output")
|
| 22 |
INPUT_DIR = os.environ.get("INPUT_DIR", "/tmp/face_upscale/input")
|
|
|
|
| 23 |
# Ensure required directories exist at import time (for API mode)
|
| 24 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 25 |
os.makedirs(INPUT_DIR, exist_ok=True)
|
|
|
|
| 26 |
|
| 27 |
# FastAPI imports
|
| 28 |
from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
|
|
@@ -369,14 +371,14 @@ class Upscale:
|
|
| 369 |
|
| 370 |
def initBGUpscaleModel(self, upscale_model):
|
| 371 |
upscale_type, upscale_model = upscale_model.split(", ", 1)
|
| 372 |
-
download_from_url(upscale_models[upscale_model][0], upscale_model, os.path.join(
|
| 373 |
self.modelInUse = f"_{os.path.splitext(upscale_model)[0]}"
|
| 374 |
netscale = 1 if any(sub in upscale_model.lower() for sub in ("x1", "1x")) else (2 if any(sub in upscale_model.lower() for sub in ("x2", "2x")) else 4)
|
| 375 |
model = None
|
| 376 |
half = True if torch.cuda.is_available() else False
|
| 377 |
if upscale_type:
|
| 378 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 379 |
-
loadnet = torch.load(os.path.join(
|
| 380 |
if 'params_ema' in loadnet or 'params' in loadnet:
|
| 381 |
loadnet = loadnet['params_ema'] if 'params_ema' in loadnet else loadnet['params']
|
| 382 |
if upscale_type == "SRVGG":
|
|
@@ -566,7 +568,7 @@ class Upscale:
|
|
| 566 |
model = SRFormer(img_size=img_size, in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
|
| 567 |
qkv_bias=qkv_bias, qk_scale=None, ape=ape, patch_norm=patch_norm, upscale=netscale, upsampler=upsampler, resi_connection=resi_connection)
|
| 568 |
if model:
|
| 569 |
-
self.realesrganer = RealESRGANer(scale=netscale, model_path=os.path.join(
|
| 570 |
elif upscale_model:
|
| 571 |
import PIL
|
| 572 |
from image_gen_aux import UpscaleWithModel
|
|
@@ -604,12 +606,12 @@ class Upscale:
|
|
| 604 |
), interpolation=interpolation)
|
| 605 |
return cv_image, None
|
| 606 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 607 |
-
upscaler = UpscaleWithModel.from_pretrained(os.path.join(
|
| 608 |
upscaler.__class__ = UpscaleWithModel_Gfpgan
|
| 609 |
self.realesrganer = upscaler
|
| 610 |
|
| 611 |
def initFaceEnhancerModel(self, face_restoration, face_detection):
|
| 612 |
-
model_rootpath = os.path.join(
|
| 613 |
model_path = os.path.join(model_rootpath, face_restoration)
|
| 614 |
download_from_url(face_models[face_restoration][0], face_restoration, model_rootpath)
|
| 615 |
self.modelInUse = f"_{os.path.splitext(face_restoration)[0]}" + self.modelInUse
|
|
|
|
| 20 |
# Writable directories (Spaces often allow /tmp). Override via env if needed.
|
| 21 |
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/tmp/face_upscale/output")
|
| 22 |
INPUT_DIR = os.environ.get("INPUT_DIR", "/tmp/face_upscale/input")
|
| 23 |
+
WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", "/tmp/face_upscale/weights")
|
| 24 |
# Ensure required directories exist at import time (for API mode)
|
| 25 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 26 |
os.makedirs(INPUT_DIR, exist_ok=True)
|
| 27 |
+
os.makedirs(WEIGHTS_DIR, exist_ok=True)
|
| 28 |
|
| 29 |
# FastAPI imports
|
| 30 |
from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
|
|
|
|
| 371 |
|
| 372 |
def initBGUpscaleModel(self, upscale_model):
|
| 373 |
upscale_type, upscale_model = upscale_model.split(", ", 1)
|
| 374 |
+
download_from_url(upscale_models[upscale_model][0], upscale_model, os.path.join(WEIGHTS_DIR, "upscale"))
|
| 375 |
self.modelInUse = f"_{os.path.splitext(upscale_model)[0]}"
|
| 376 |
netscale = 1 if any(sub in upscale_model.lower() for sub in ("x1", "1x")) else (2 if any(sub in upscale_model.lower() for sub in ("x2", "2x")) else 4)
|
| 377 |
model = None
|
| 378 |
half = True if torch.cuda.is_available() else False
|
| 379 |
if upscale_type:
|
| 380 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 381 |
+
loadnet = torch.load(os.path.join(WEIGHTS_DIR, "upscale", upscale_model), map_location=torch.device('cpu'), weights_only=True)
|
| 382 |
if 'params_ema' in loadnet or 'params' in loadnet:
|
| 383 |
loadnet = loadnet['params_ema'] if 'params_ema' in loadnet else loadnet['params']
|
| 384 |
if upscale_type == "SRVGG":
|
|
|
|
| 568 |
model = SRFormer(img_size=img_size, in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
|
| 569 |
qkv_bias=qkv_bias, qk_scale=None, ape=ape, patch_norm=patch_norm, upscale=netscale, upsampler=upsampler, resi_connection=resi_connection)
|
| 570 |
if model:
|
| 571 |
+
self.realesrganer = RealESRGANer(scale=netscale, model_path=os.path.join(WEIGHTS_DIR, "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
|
| 572 |
elif upscale_model:
|
| 573 |
import PIL
|
| 574 |
from image_gen_aux import UpscaleWithModel
|
|
|
|
| 606 |
), interpolation=interpolation)
|
| 607 |
return cv_image, None
|
| 608 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 609 |
+
upscaler = UpscaleWithModel.from_pretrained(os.path.join(WEIGHTS_DIR, "upscale", upscale_model)).to(device)
|
| 610 |
upscaler.__class__ = UpscaleWithModel_Gfpgan
|
| 611 |
self.realesrganer = upscaler
|
| 612 |
|
| 613 |
def initFaceEnhancerModel(self, face_restoration, face_detection):
|
| 614 |
+
model_rootpath = os.path.join(WEIGHTS_DIR, "face")
|
| 615 |
model_path = os.path.join(model_rootpath, face_restoration)
|
| 616 |
download_from_url(face_models[face_restoration][0], face_restoration, model_rootpath)
|
| 617 |
self.modelInUse = f"_{os.path.splitext(face_restoration)[0]}" + self.modelInUse
|