HariLogicgo commited on
Commit
7f04f92
·
1 Parent(s): 9b92c1c

fastapi added

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -20,9 +20,11 @@ from typing import List, Optional
20
  # Writable directories (Spaces often allow /tmp). Override via env if needed.
21
  OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/tmp/face_upscale/output")
22
  INPUT_DIR = os.environ.get("INPUT_DIR", "/tmp/face_upscale/input")
 
23
  # Ensure required directories exist at import time (for API mode)
24
  os.makedirs(OUTPUT_DIR, exist_ok=True)
25
  os.makedirs(INPUT_DIR, exist_ok=True)
 
26
 
27
  # FastAPI imports
28
  from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
@@ -369,14 +371,14 @@ class Upscale:
369
 
370
  def initBGUpscaleModel(self, upscale_model):
371
  upscale_type, upscale_model = upscale_model.split(", ", 1)
372
- download_from_url(upscale_models[upscale_model][0], upscale_model, os.path.join("weights", "upscale"))
373
  self.modelInUse = f"_{os.path.splitext(upscale_model)[0]}"
374
  netscale = 1 if any(sub in upscale_model.lower() for sub in ("x1", "1x")) else (2 if any(sub in upscale_model.lower() for sub in ("x2", "2x")) else 4)
375
  model = None
376
  half = True if torch.cuda.is_available() else False
377
  if upscale_type:
378
  from basicsr.archs.rrdbnet_arch import RRDBNet
379
- loadnet = torch.load(os.path.join("weights", "upscale", upscale_model), map_location=torch.device('cpu'), weights_only=True)
380
  if 'params_ema' in loadnet or 'params' in loadnet:
381
  loadnet = loadnet['params_ema'] if 'params_ema' in loadnet else loadnet['params']
382
  if upscale_type == "SRVGG":
@@ -566,7 +568,7 @@ class Upscale:
566
  model = SRFormer(img_size=img_size, in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
567
  qkv_bias=qkv_bias, qk_scale=None, ape=ape, patch_norm=patch_norm, upscale=netscale, upsampler=upsampler, resi_connection=resi_connection)
568
  if model:
569
- self.realesrganer = RealESRGANer(scale=netscale, model_path=os.path.join("weights", "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
570
  elif upscale_model:
571
  import PIL
572
  from image_gen_aux import UpscaleWithModel
@@ -604,12 +606,12 @@ class Upscale:
604
  ), interpolation=interpolation)
605
  return cv_image, None
606
  device = "cuda" if torch.cuda.is_available() else "cpu"
607
- upscaler = UpscaleWithModel.from_pretrained(os.path.join("weights", "upscale", upscale_model)).to(device)
608
  upscaler.__class__ = UpscaleWithModel_Gfpgan
609
  self.realesrganer = upscaler
610
 
611
  def initFaceEnhancerModel(self, face_restoration, face_detection):
612
- model_rootpath = os.path.join("weights", "face")
613
  model_path = os.path.join(model_rootpath, face_restoration)
614
  download_from_url(face_models[face_restoration][0], face_restoration, model_rootpath)
615
  self.modelInUse = f"_{os.path.splitext(face_restoration)[0]}" + self.modelInUse
 
20
  # Writable directories (Spaces often allow /tmp). Override via env if needed.
21
  OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/tmp/face_upscale/output")
22
  INPUT_DIR = os.environ.get("INPUT_DIR", "/tmp/face_upscale/input")
23
+ WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", "/tmp/face_upscale/weights")
24
  # Ensure required directories exist at import time (for API mode)
25
  os.makedirs(OUTPUT_DIR, exist_ok=True)
26
  os.makedirs(INPUT_DIR, exist_ok=True)
27
+ os.makedirs(WEIGHTS_DIR, exist_ok=True)
28
 
29
  # FastAPI imports
30
  from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
 
371
 
372
  def initBGUpscaleModel(self, upscale_model):
373
  upscale_type, upscale_model = upscale_model.split(", ", 1)
374
+ download_from_url(upscale_models[upscale_model][0], upscale_model, os.path.join(WEIGHTS_DIR, "upscale"))
375
  self.modelInUse = f"_{os.path.splitext(upscale_model)[0]}"
376
  netscale = 1 if any(sub in upscale_model.lower() for sub in ("x1", "1x")) else (2 if any(sub in upscale_model.lower() for sub in ("x2", "2x")) else 4)
377
  model = None
378
  half = True if torch.cuda.is_available() else False
379
  if upscale_type:
380
  from basicsr.archs.rrdbnet_arch import RRDBNet
381
+ loadnet = torch.load(os.path.join(WEIGHTS_DIR, "upscale", upscale_model), map_location=torch.device('cpu'), weights_only=True)
382
  if 'params_ema' in loadnet or 'params' in loadnet:
383
  loadnet = loadnet['params_ema'] if 'params_ema' in loadnet else loadnet['params']
384
  if upscale_type == "SRVGG":
 
568
  model = SRFormer(img_size=img_size, in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
569
  qkv_bias=qkv_bias, qk_scale=None, ape=ape, patch_norm=patch_norm, upscale=netscale, upsampler=upsampler, resi_connection=resi_connection)
570
  if model:
571
+ self.realesrganer = RealESRGANer(scale=netscale, model_path=os.path.join(WEIGHTS_DIR, "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
572
  elif upscale_model:
573
  import PIL
574
  from image_gen_aux import UpscaleWithModel
 
606
  ), interpolation=interpolation)
607
  return cv_image, None
608
  device = "cuda" if torch.cuda.is_available() else "cpu"
609
+ upscaler = UpscaleWithModel.from_pretrained(os.path.join(WEIGHTS_DIR, "upscale", upscale_model)).to(device)
610
  upscaler.__class__ = UpscaleWithModel_Gfpgan
611
  self.realesrganer = upscaler
612
 
613
  def initFaceEnhancerModel(self, face_restoration, face_detection):
614
+ model_rootpath = os.path.join(WEIGHTS_DIR, "face")
615
  model_path = os.path.join(model_rootpath, face_restoration)
616
  download_from_url(face_models[face_restoration][0], face_restoration, model_rootpath)
617
  self.modelInUse = f"_{os.path.splitext(face_restoration)[0]}" + self.modelInUse