Spaces:
Runtime error
Runtime error
Commit
·
252e766
1
Parent(s):
762cf51
Upload 75 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +44 -0
- app.py +316 -0
- app/big-lama.pt +3 -0
- app/u2net.onnx +3 -0
- app/yolov8x-seg.pt +3 -0
- lama_cleaner/__init__.py +11 -0
- lama_cleaner/__pycache__/__init__.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/const.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/helper.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/interactive_seg.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/model_manager.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/parse_args.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/runtime.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/schema.cpython-38.pyc +0 -0
- lama_cleaner/__pycache__/server2.cpython-38.pyc +0 -0
- lama_cleaner/benchmark.py +109 -0
- lama_cleaner/const.py +68 -0
- lama_cleaner/file_manager/__init__.py +1 -0
- lama_cleaner/file_manager/__pycache__/__init__.cpython-38.pyc +0 -0
- lama_cleaner/file_manager/__pycache__/file_manager.cpython-38.pyc +0 -0
- lama_cleaner/file_manager/__pycache__/storage_backends.cpython-38.pyc +0 -0
- lama_cleaner/file_manager/__pycache__/utils.cpython-38.pyc +0 -0
- lama_cleaner/file_manager/file_manager.py +252 -0
- lama_cleaner/file_manager/storage_backends.py +46 -0
- lama_cleaner/file_manager/utils.py +66 -0
- lama_cleaner/helper.py +218 -0
- lama_cleaner/interactive_seg.py +202 -0
- lama_cleaner/model/__init__.py +0 -0
- lama_cleaner/model/__pycache__/__init__.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/base.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/ddim_sampler.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/fcf.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/lama.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/ldm.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/manga.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/mat.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/opencv2.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/paint_by_example.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/plms_sampler.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/sd.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/utils.cpython-38.pyc +0 -0
- lama_cleaner/model/__pycache__/zits.cpython-38.pyc +0 -0
- lama_cleaner/model/base.py +247 -0
- lama_cleaner/model/ddim_sampler.py +192 -0
- lama_cleaner/model/fcf.py +1212 -0
- lama_cleaner/model/lama.py +61 -0
- lama_cleaner/model/ldm.py +310 -0
- lama_cleaner/model/manga.py +130 -0
- lama_cleaner/model/mat.py +1444 -0
- lama_cleaner/model/opencv2.py +25 -0
Dockerfile
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.8
|
| 2 |
+
|
| 3 |
+
RUN mkdir /app
|
| 4 |
+
RUN mkdir /.cache/
|
| 5 |
+
RUN mkdir /.cache/matplotlib
|
| 6 |
+
RUN mkdir /.cache/huggingface
|
| 7 |
+
RUN mkdir /.cache/huggingface/hub/
|
| 8 |
+
RUN mkdir /.cache/torch/
|
| 9 |
+
RUN mkdir /.config
|
| 10 |
+
RUN mkdir /.config/matplotlib/
|
| 11 |
+
|
| 12 |
+
RUN chmod -R 777 /.cache
|
| 13 |
+
RUN chmod -R 777 /.cache/matplotlib
|
| 14 |
+
RUN chmod -R 777 /.cache/huggingface/hub
|
| 15 |
+
RUN chmod -R 777 /.cache/torch
|
| 16 |
+
RUN chmod -R 777 /.config/
|
| 17 |
+
RUN chmod -R 777 /.config/matplotlib
|
| 18 |
+
RUN chmod -R 777 /app
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
COPY lama_cleaner ./lama_cleaner
|
| 22 |
+
COPY ./app.py ./app.py
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
COPY app/yolov8x-seg.pt /app
|
| 26 |
+
COPY big-lama.pt /app
|
| 27 |
+
# COPY clickseg_pplnet.pt /app
|
| 28 |
+
COPY u2net.onnx /app
|
| 29 |
+
COPY u2net.onnx /tmp
|
| 30 |
+
|
| 31 |
+
RUN chmod -R a+r /app/yolov8x-seg.pt
|
| 32 |
+
RUN chmod -R a+r /app/big-lama.pt
|
| 33 |
+
#RUN chmod -R a+r /app/clickseg_pplnet.pt
|
| 34 |
+
RUN chmod -R a+r /app/u2net.onnx
|
| 35 |
+
RUN chmod -R a+r /tmp/u2net.onnx
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
COPY ./requirements.txt ./requirements.txt
|
| 39 |
+
RUN pip install -r ./requirements.txt
|
| 40 |
+
|
| 41 |
+
RUN --mount=type=secret,id=SECRET,mode=0444,required=true \
|
| 42 |
+
git clone $(cat /run/secrets/SECRET)
|
| 43 |
+
|
| 44 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import imghdr
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from ultralytics import YOLO
|
| 9 |
+
from ultralytics.yolo.utils.ops import scale_image
|
| 10 |
+
import asyncio
|
| 11 |
+
from fastapi import FastAPI, File, UploadFile, Request, Response
|
| 12 |
+
from fastapi.responses import JSONResponse
|
| 13 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
+
import uvicorn
|
| 15 |
+
# from mangum import Mangum
|
| 16 |
+
from argparse import ArgumentParser
|
| 17 |
+
|
| 18 |
+
import lama_cleaner.server2 as server
|
| 19 |
+
from lama_cleaner.helper import (
|
| 20 |
+
load_img,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
# os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/directory"
|
| 24 |
+
|
| 25 |
+
app = FastAPI()
|
| 26 |
+
|
| 27 |
+
# handler = Mangum(app)
|
| 28 |
+
origins = ["*"]
|
| 29 |
+
|
| 30 |
+
app.add_middleware(
|
| 31 |
+
CORSMiddleware,
|
| 32 |
+
allow_origins=origins,
|
| 33 |
+
allow_credentials=True,
|
| 34 |
+
allow_methods=["*"],
|
| 35 |
+
allow_headers=["*"],
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
| 40 |
+
"""
|
| 41 |
+
Args:
|
| 42 |
+
image_numpy: numpy image
|
| 43 |
+
ext: image extension
|
| 44 |
+
Returns:
|
| 45 |
+
image bytes
|
| 46 |
+
"""
|
| 47 |
+
data = cv2.imencode(
|
| 48 |
+
f".{ext}",
|
| 49 |
+
image_numpy,
|
| 50 |
+
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
| 51 |
+
)[1].tobytes()
|
| 52 |
+
return data
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_image_ext(img_bytes):
|
| 56 |
+
"""
|
| 57 |
+
Args:
|
| 58 |
+
img_bytes: image bytes
|
| 59 |
+
Returns:
|
| 60 |
+
image extension
|
| 61 |
+
"""
|
| 62 |
+
if not img_bytes:
|
| 63 |
+
raise ValueError("Empty input")
|
| 64 |
+
header = img_bytes[:32]
|
| 65 |
+
w = imghdr.what("", header)
|
| 66 |
+
if w is None:
|
| 67 |
+
w = "jpeg"
|
| 68 |
+
return w
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def predict_on_image(model, img, conf, retina_masks):
|
| 72 |
+
"""
|
| 73 |
+
Args:
|
| 74 |
+
model: YOLOv8 model
|
| 75 |
+
img: image (C, H, W)
|
| 76 |
+
conf: confidence threshold
|
| 77 |
+
retina_masks: use retina masks or not
|
| 78 |
+
Returns:
|
| 79 |
+
boxes: box with xyxy format, (N, 4)
|
| 80 |
+
masks: masks, (N, H, W)
|
| 81 |
+
cls: class of masks, (N, )
|
| 82 |
+
probs: confidence score, (N, 1)
|
| 83 |
+
"""
|
| 84 |
+
with torch.no_grad():
|
| 85 |
+
result = model(img, conf=conf, retina_masks=retina_masks, scale=1)[0]
|
| 86 |
+
|
| 87 |
+
boxes, masks, cls, probs = None, None, None, None
|
| 88 |
+
|
| 89 |
+
if result.boxes.cls.size(0) > 0:
|
| 90 |
+
# detection
|
| 91 |
+
cls = result.boxes.cls.cpu().numpy().astype(np.int32)
|
| 92 |
+
probs = result.boxes.conf.cpu().numpy() # confidence score, (N, 1)
|
| 93 |
+
boxes = result.boxes.xyxy.cpu().numpy() # box with xyxy format, (N, 4)
|
| 94 |
+
|
| 95 |
+
# segmentation
|
| 96 |
+
masks = result.masks.masks.cpu().numpy() # masks, (N, H, W)
|
| 97 |
+
masks = np.transpose(masks, (1, 2, 0)) # masks, (H, W, N)
|
| 98 |
+
# rescale masks to original image
|
| 99 |
+
masks = scale_image(masks.shape[:2], masks, result.masks.orig_shape)
|
| 100 |
+
masks = np.transpose(masks, (2, 0, 1)) # masks, (N, H, W)
|
| 101 |
+
|
| 102 |
+
return boxes, masks, cls, probs
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def overlay(image, mask, color, alpha, id, resize=None):
|
| 106 |
+
"""Overlays a binary mask on an image.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
image: Image to be overlayed on.
|
| 110 |
+
mask: Binary mask to overlay.
|
| 111 |
+
color: Color to use for the mask.
|
| 112 |
+
alpha: Opacity of the mask.
|
| 113 |
+
id: id of the mask
|
| 114 |
+
resize: Resize the image to this size. If None, no resizing is performed.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
The overlayed image.
|
| 118 |
+
"""
|
| 119 |
+
color = color[::-1]
|
| 120 |
+
colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
|
| 121 |
+
colored_mask = np.moveaxis(colored_mask, 0, -1)
|
| 122 |
+
masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
|
| 123 |
+
image_overlay = masked.filled()
|
| 124 |
+
|
| 125 |
+
imgray = cv2.cvtColor(image_overlay, cv2.COLOR_BGR2GRAY)
|
| 126 |
+
|
| 127 |
+
contour_thickness = 8
|
| 128 |
+
_, thresh = cv2.threshold(imgray, 255, 255, 255)
|
| 129 |
+
contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
|
| 130 |
+
imgray = cv2.cvtColor(imgray, cv2.COLOR_GRAY2BGR)
|
| 131 |
+
imgray = cv2.drawContours(imgray, contours, -1, (255, 255, 255), contour_thickness)
|
| 132 |
+
|
| 133 |
+
imgray = np.where(imgray.any(-1, keepdims=True), (46, 36, 225), 0)
|
| 134 |
+
|
| 135 |
+
if resize is not None:
|
| 136 |
+
image = cv2.resize(image.transpose(1, 2, 0), resize)
|
| 137 |
+
image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize)
|
| 138 |
+
|
| 139 |
+
return imgray
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
async def process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls):
|
| 143 |
+
"""Process the mask of the image.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
idx: index of the mask
|
| 147 |
+
mask_i: mask of the image
|
| 148 |
+
boxes: box with xyxy format, (N, 4)
|
| 149 |
+
probs: confidence score, (N, 1)
|
| 150 |
+
yolo_model: YOLOv8 model
|
| 151 |
+
blank_image: blank image
|
| 152 |
+
cls: class of masks, (N, )
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
dictionary_seg: dictionary of the mask of the image
|
| 156 |
+
"""
|
| 157 |
+
dictionary_seg = {}
|
| 158 |
+
maskwith_back = overlay(blank_image, mask_i, color=(255, 155, 155), alpha=0.5, id=idx)
|
| 159 |
+
|
| 160 |
+
alpha = np.sum(maskwith_back, axis=-1) > 0
|
| 161 |
+
alpha = np.uint8(alpha * 255)
|
| 162 |
+
maskwith_back = np.dstack((maskwith_back, alpha))
|
| 163 |
+
|
| 164 |
+
imgencode = await asyncio.get_running_loop().run_in_executor(None, cv2.imencode, '.png', maskwith_back)
|
| 165 |
+
mask = base64.b64encode(imgencode[1]).decode('utf-8')
|
| 166 |
+
|
| 167 |
+
dictionary_seg["confi"] = f'{probs[idx] * 100:.2f}'
|
| 168 |
+
dictionary_seg["boxe"] = [int(item) for item in list(boxes[idx])]
|
| 169 |
+
dictionary_seg["mask"] = mask
|
| 170 |
+
dictionary_seg["cls"] = str(yolo_model.names[cls[idx]])
|
| 171 |
+
|
| 172 |
+
return dictionary_seg
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@app.middleware("http")
|
| 176 |
+
async def check_auth_header(request: Request, call_next):
|
| 177 |
+
token = request.headers.get('Authorization')
|
| 178 |
+
if token != os.environ.get("SECRET"):
|
| 179 |
+
return JSONResponse(content={'error': 'Authorization header missing or incorrect.'}, status_code=403)
|
| 180 |
+
else:
|
| 181 |
+
response = await call_next(request)
|
| 182 |
+
return response
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@app.post("/api/mask")
|
| 186 |
+
async def detect_mask(file: UploadFile = File()):
|
| 187 |
+
"""
|
| 188 |
+
Detects masks in an image uploaded via a POST request and returns a JSON response containing the details of the detected masks.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
None
|
| 192 |
+
|
| 193 |
+
Parameters:
|
| 194 |
+
- file: a file object containing the input image
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
A JSON response containing the details of the detected masks:
|
| 198 |
+
- code: 200 if objects were detected, 500 if no objects were detected
|
| 199 |
+
- msg: a message indicating whether objects were detected or not
|
| 200 |
+
- data: a list of dictionaries, where each dictionary contains the following keys:
|
| 201 |
+
- confi: the confidence level of the detected object
|
| 202 |
+
- boxe: a list containing the coordinates of the bounding box of the detected object
|
| 203 |
+
- mask: the mask of the detected object encoded in base64
|
| 204 |
+
- cls: the class of the detected object
|
| 205 |
+
|
| 206 |
+
Raises:
|
| 207 |
+
500: No objects detected
|
| 208 |
+
"""
|
| 209 |
+
file = await file.read()
|
| 210 |
+
|
| 211 |
+
img, _ = load_img(file)
|
| 212 |
+
|
| 213 |
+
# predict by YOLOv8
|
| 214 |
+
boxes, masks, cls, probs = predict_on_image(yolo_model, img, conf=0.55, retina_masks=True)
|
| 215 |
+
|
| 216 |
+
if boxes is None:
|
| 217 |
+
return {'code': 500, 'msg': 'No objects detected'}
|
| 218 |
+
|
| 219 |
+
# overlay masks on original image
|
| 220 |
+
blank_image = np.zeros(img.shape, dtype=np.uint8)
|
| 221 |
+
|
| 222 |
+
data = []
|
| 223 |
+
|
| 224 |
+
coroutines = [process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls) for idx, mask_i in
|
| 225 |
+
enumerate(masks)]
|
| 226 |
+
results = await asyncio.gather(*coroutines)
|
| 227 |
+
|
| 228 |
+
for result in results:
|
| 229 |
+
data.append(result)
|
| 230 |
+
|
| 231 |
+
return {'code': 200, 'msg': "object detected", 'data': data}
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@app.post("/api/lama/paint")
|
| 235 |
+
async def paint(img: UploadFile = File(), mask: UploadFile = File()):
|
| 236 |
+
"""
|
| 237 |
+
Endpoint to process an image with a given mask using the server's process function.
|
| 238 |
+
|
| 239 |
+
Route: '/api/lama/paint'
|
| 240 |
+
Method: POST
|
| 241 |
+
|
| 242 |
+
Parameters:
|
| 243 |
+
img: The input image file (JPEG or PNG format).
|
| 244 |
+
mask: The mask file (JPEG or PNG format).
|
| 245 |
+
Returns:
|
| 246 |
+
A JSON object containing the processed image in base64 format under the "image" key.
|
| 247 |
+
"""
|
| 248 |
+
img = await img.read()
|
| 249 |
+
mask = await mask.read()
|
| 250 |
+
return {"image": server.process(img, mask)}
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
@app.post("/api/remove")
|
| 254 |
+
async def remove(img: UploadFile = File()):
|
| 255 |
+
x = await img.read()
|
| 256 |
+
return {"image": server.remove(x)}
|
| 257 |
+
|
| 258 |
+
@app.post("/api/lama/model")
|
| 259 |
+
def switch_model(new_name: str):
|
| 260 |
+
return server.switch_model(new_name)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@app.get("/api/lama/model")
|
| 264 |
+
def current_model():
|
| 265 |
+
return server.current_model()
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@app.get("/api/lama/switchmode")
|
| 269 |
+
def get_is_disable_model_switch():
|
| 270 |
+
return server.get_is_disable_model_switch()
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
@app.on_event("startup")
|
| 274 |
+
def init_data():
|
| 275 |
+
model_device = "cpu"
|
| 276 |
+
global yolo_model
|
| 277 |
+
# TODO Update for local development
|
| 278 |
+
yolo_model = YOLO('yolov8x-seg.pt')
|
| 279 |
+
# yolo_model = YOLO('/app/yolov8x-seg.pt')
|
| 280 |
+
yolo_model.to(model_device)
|
| 281 |
+
print(f"YOLO model yolov8x-seg.pt loaded.")
|
| 282 |
+
server.initModel()
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def create_app(args):
|
| 286 |
+
"""
|
| 287 |
+
Creates the FastAPI app and adds the endpoints.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
args: The arguments.
|
| 291 |
+
"""
|
| 292 |
+
uvicorn.run("app:app", host=args.host, port=args.port, reload=args.reload)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
if __name__ == "__main__":
|
| 296 |
+
parser = ArgumentParser()
|
| 297 |
+
parser.add_argument('--model_name', type=str, default='lama', help='Model name')
|
| 298 |
+
parser.add_argument('--host', type=str, default="0.0.0.0")
|
| 299 |
+
parser.add_argument('--port', type=int, default=5000)
|
| 300 |
+
parser.add_argument('--reload', type=bool, default=True)
|
| 301 |
+
parser.add_argument('--model_device', type=str, default='cpu', help='Model device')
|
| 302 |
+
parser.add_argument('--disable_model_switch', type=bool, default=False, help='Disable model switch')
|
| 303 |
+
parser.add_argument('--gui', type=bool, default=False, help='Enable GUI')
|
| 304 |
+
parser.add_argument('--cpu_offload', type=bool, default=False, help='Enable CPU offload')
|
| 305 |
+
parser.add_argument('--disable_nsfw', type=bool, default=False, help='Disable NSFW')
|
| 306 |
+
parser.add_argument('--enable_xformers', type=bool, default=False, help='Enable xformers')
|
| 307 |
+
parser.add_argument('--hf_access_token', type=str, default='', help='Hugging Face access token')
|
| 308 |
+
parser.add_argument('--local_files_only', type=bool, default=False, help='Enable local files only')
|
| 309 |
+
parser.add_argument('--no_half', type=bool, default=False, help='Disable half')
|
| 310 |
+
parser.add_argument('--sd_cpu_textencoder', type=bool, default=False, help='Enable CPU text encoder')
|
| 311 |
+
parser.add_argument('--sd_disable_nsfw', type=bool, default=False, help='Disable NSFW')
|
| 312 |
+
parser.add_argument('--sd_enable_xformers', type=bool, default=False, help='Enable xformers')
|
| 313 |
+
parser.add_argument('--sd_run_local', type=bool, default=False, help='Enable local files only')
|
| 314 |
+
|
| 315 |
+
args = parser.parse_args()
|
| 316 |
+
create_app(args)
|
app/big-lama.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:344c77bbcb158f17dd143070d1e789f38a66c04202311ae3a258ef66667a9ea9
|
| 3 |
+
size 205669692
|
app/u2net.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8d10d2f3bb75ae3b6d527c77944fc5e7dcd94b29809d47a739a7a728a912b491
|
| 3 |
+
size 175997641
|
app/yolov8x-seg.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d63cbfa5764867c0066bedfa43cf2dcd90a412a1de44b2e238c43978a9d28ea6
|
| 3 |
+
size 144076467
|
lama_cleaner/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
warnings.simplefilter("ignore", UserWarning)
|
| 3 |
+
|
| 4 |
+
from lama_cleaner.parse_args import parse_args
|
| 5 |
+
|
| 6 |
+
def entry_point():
|
| 7 |
+
args = parse_args()
|
| 8 |
+
# To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
|
| 9 |
+
# https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
|
| 10 |
+
from lama_cleaner.server import main
|
| 11 |
+
main(args)
|
lama_cleaner/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (478 Bytes). View file
|
|
|
lama_cleaner/__pycache__/const.cpython-38.pyc
ADDED
|
Binary file (1.79 kB). View file
|
|
|
lama_cleaner/__pycache__/helper.cpython-38.pyc
ADDED
|
Binary file (5.43 kB). View file
|
|
|
lama_cleaner/__pycache__/interactive_seg.cpython-38.pyc
ADDED
|
Binary file (6.77 kB). View file
|
|
|
lama_cleaner/__pycache__/model_manager.cpython-38.pyc
ADDED
|
Binary file (2.27 kB). View file
|
|
|
lama_cleaner/__pycache__/parse_args.cpython-38.pyc
ADDED
|
Binary file (4.28 kB). View file
|
|
|
lama_cleaner/__pycache__/runtime.cpython-38.pyc
ADDED
|
Binary file (1.35 kB). View file
|
|
|
lama_cleaner/__pycache__/schema.cpython-38.pyc
ADDED
|
Binary file (2.42 kB). View file
|
|
|
lama_cleaner/__pycache__/server2.cpython-38.pyc
ADDED
|
Binary file (6.31 kB). View file
|
|
|
lama_cleaner/benchmark.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import nvidia_smi
|
| 9 |
+
import psutil
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from lama_cleaner.model_manager import ModelManager
|
| 13 |
+
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
torch._C._jit_override_can_fuse_on_cpu(False)
|
| 17 |
+
torch._C._jit_override_can_fuse_on_gpu(False)
|
| 18 |
+
torch._C._jit_set_texpr_fuser_enabled(False)
|
| 19 |
+
torch._C._jit_set_nvfuser_enabled(False)
|
| 20 |
+
except:
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
NUM_THREADS = str(4)
|
| 24 |
+
|
| 25 |
+
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
| 26 |
+
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
| 27 |
+
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
| 28 |
+
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
| 29 |
+
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
| 30 |
+
if os.environ.get("CACHE_DIR"):
|
| 31 |
+
os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def run_model(model, size):
|
| 35 |
+
# RGB
|
| 36 |
+
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
|
| 37 |
+
mask = np.random.randint(0, 255, size).astype(np.uint8)
|
| 38 |
+
|
| 39 |
+
config = Config(
|
| 40 |
+
ldm_steps=2,
|
| 41 |
+
hd_strategy=HDStrategy.ORIGINAL,
|
| 42 |
+
hd_strategy_crop_margin=128,
|
| 43 |
+
hd_strategy_crop_trigger_size=128,
|
| 44 |
+
hd_strategy_resize_limit=128,
|
| 45 |
+
prompt="a fox is sitting on a bench",
|
| 46 |
+
sd_steps=5,
|
| 47 |
+
sd_sampler=SDSampler.ddim
|
| 48 |
+
)
|
| 49 |
+
model(image, mask, config)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def benchmark(model, times: int, empty_cache: bool):
|
| 53 |
+
sizes = [(512, 512)]
|
| 54 |
+
|
| 55 |
+
nvidia_smi.nvmlInit()
|
| 56 |
+
device_id = 0
|
| 57 |
+
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
|
| 58 |
+
|
| 59 |
+
def format(metrics):
|
| 60 |
+
return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
|
| 61 |
+
|
| 62 |
+
process = psutil.Process(os.getpid())
|
| 63 |
+
# 每个 size 给出显存和内存占用的指标
|
| 64 |
+
for size in sizes:
|
| 65 |
+
torch.cuda.empty_cache()
|
| 66 |
+
time_metrics = []
|
| 67 |
+
cpu_metrics = []
|
| 68 |
+
memory_metrics = []
|
| 69 |
+
gpu_memory_metrics = []
|
| 70 |
+
for _ in range(times):
|
| 71 |
+
start = time.time()
|
| 72 |
+
run_model(model, size)
|
| 73 |
+
torch.cuda.synchronize()
|
| 74 |
+
|
| 75 |
+
# cpu_metrics.append(process.cpu_percent())
|
| 76 |
+
time_metrics.append((time.time() - start) * 1000)
|
| 77 |
+
memory_metrics.append(process.memory_info().rss / 1024 / 1024)
|
| 78 |
+
gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024)
|
| 79 |
+
|
| 80 |
+
print(f"size: {size}".center(80, "-"))
|
| 81 |
+
# print(f"cpu: {format(cpu_metrics)}")
|
| 82 |
+
print(f"latency: {format(time_metrics)}ms")
|
| 83 |
+
print(f"memory: {format(memory_metrics)} MB")
|
| 84 |
+
print(f"gpu memory: {format(gpu_memory_metrics)} MB")
|
| 85 |
+
|
| 86 |
+
nvidia_smi.nvmlShutdown()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_args_parser():
|
| 90 |
+
parser = argparse.ArgumentParser()
|
| 91 |
+
parser.add_argument("--name")
|
| 92 |
+
parser.add_argument("--device", default="cuda", type=str)
|
| 93 |
+
parser.add_argument("--times", default=10, type=int)
|
| 94 |
+
parser.add_argument("--empty-cache", action="store_true")
|
| 95 |
+
return parser.parse_args()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
args = get_args_parser()
|
| 100 |
+
device = torch.device(args.device)
|
| 101 |
+
model = ModelManager(
|
| 102 |
+
name=args.name,
|
| 103 |
+
device=device,
|
| 104 |
+
sd_run_local=True,
|
| 105 |
+
disable_nsfw=True,
|
| 106 |
+
sd_cpu_textencoder=True,
|
| 107 |
+
hf_access_token="123"
|
| 108 |
+
)
|
| 109 |
+
benchmark(model, args.times, args.empty_cache)
|
lama_cleaner/const.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
DEFAULT_MODEL = "lama"
|
| 4 |
+
AVAILABLE_MODELS = [
|
| 5 |
+
"lama",
|
| 6 |
+
"ldm",
|
| 7 |
+
"zits",
|
| 8 |
+
"mat",
|
| 9 |
+
"fcf",
|
| 10 |
+
"sd1.5",
|
| 11 |
+
"cv2",
|
| 12 |
+
"manga",
|
| 13 |
+
"sd2",
|
| 14 |
+
"paint_by_example"
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
| 18 |
+
DEFAULT_DEVICE = 'cuda'
|
| 19 |
+
|
| 20 |
+
NO_HALF_HELP = """
|
| 21 |
+
Using full precision model.
|
| 22 |
+
If your generate result is always black or green, use this argument. (sd/paint_by_exmaple)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
CPU_OFFLOAD_HELP = """
|
| 26 |
+
Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example)
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
DISABLE_NSFW_HELP = """
|
| 30 |
+
Disable NSFW checker. (sd/paint_by_example)
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
SD_CPU_TEXTENCODER_HELP = """
|
| 34 |
+
Run Stable Diffusion text encoder model on CPU to save GPU memory.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
LOCAL_FILES_ONLY_HELP = """
|
| 38 |
+
Use local files only, not connect to Hugging Face server. (sd/paint_by_example)
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
ENABLE_XFORMERS_HELP = """
|
| 42 |
+
Enable xFormers optimizations. Requires xformers package has been installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example)
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
DEFAULT_MODEL_DIR = os.getenv(
|
| 46 |
+
"XDG_CACHE_HOME",
|
| 47 |
+
os.path.join(os.path.expanduser("~"), ".cache")
|
| 48 |
+
)
|
| 49 |
+
MODEL_DIR_HELP = """
|
| 50 |
+
Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
OUTPUT_DIR_HELP = """
|
| 54 |
+
Only required when --input is directory. Result images will be saved to output directory automatically.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
INPUT_HELP = """
|
| 58 |
+
If input is image, it will be loaded by default.
|
| 59 |
+
If input is directory, you can browse and select image in file manager.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
GUI_HELP = """
|
| 63 |
+
Launch Lama Cleaner as desktop app
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
NO_GUI_AUTO_CLOSE_HELP = """
|
| 67 |
+
Prevent backend auto close after the GUI window closed.
|
| 68 |
+
"""
|
lama_cleaner/file_manager/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .file_manager import FileManager
|
lama_cleaner/file_manager/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (227 Bytes). View file
|
|
|
lama_cleaner/file_manager/__pycache__/file_manager.cpython-38.pyc
ADDED
|
Binary file (7.68 kB). View file
|
|
|
lama_cleaner/file_manager/__pycache__/storage_backends.cpython-38.pyc
ADDED
|
Binary file (2.01 kB). View file
|
|
|
lama_cleaner/file_manager/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (1.64 kB). View file
|
|
|
lama_cleaner/file_manager/file_manager.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image, ImageOps, PngImagePlugin
|
| 11 |
+
from loguru import logger
|
| 12 |
+
from watchdog.events import FileSystemEventHandler
|
| 13 |
+
from watchdog.observers import Observer
|
| 14 |
+
|
| 15 |
+
LARGE_ENOUGH_NUMBER = 100
|
| 16 |
+
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024 ** 2)
|
| 17 |
+
from .storage_backends import FilesystemStorageBackend
|
| 18 |
+
from .utils import aspect_to_string, generate_filename, glob_img
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class FileManager(FileSystemEventHandler):
|
| 22 |
+
def __init__(self, app=None):
|
| 23 |
+
self.app = app
|
| 24 |
+
self._default_root_directory = "media"
|
| 25 |
+
self._default_thumbnail_directory = "media"
|
| 26 |
+
self._default_root_url = "/"
|
| 27 |
+
self._default_thumbnail_root_url = "/"
|
| 28 |
+
self._default_format = "JPEG"
|
| 29 |
+
self.output_dir: Path = None
|
| 30 |
+
|
| 31 |
+
if app is not None:
|
| 32 |
+
self.init_app(app)
|
| 33 |
+
|
| 34 |
+
self.image_dir_filenames = []
|
| 35 |
+
self.output_dir_filenames = []
|
| 36 |
+
|
| 37 |
+
self.image_dir_observer = None
|
| 38 |
+
self.output_dir_observer = None
|
| 39 |
+
|
| 40 |
+
self.modified_time = {
|
| 41 |
+
"image": datetime.utcnow(),
|
| 42 |
+
"output": datetime.utcnow(),
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
def start(self):
|
| 46 |
+
self.image_dir_filenames = self._media_names(self.root_directory)
|
| 47 |
+
self.output_dir_filenames = self._media_names(self.output_dir)
|
| 48 |
+
|
| 49 |
+
logger.info(f"Start watching image directory: {self.root_directory}")
|
| 50 |
+
self.image_dir_observer = Observer()
|
| 51 |
+
self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
|
| 52 |
+
self.image_dir_observer.start()
|
| 53 |
+
|
| 54 |
+
logger.info(f"Start watching output directory: {self.output_dir}")
|
| 55 |
+
self.output_dir_observer = Observer()
|
| 56 |
+
self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
|
| 57 |
+
self.output_dir_observer.start()
|
| 58 |
+
|
| 59 |
+
def on_modified(self, event):
|
| 60 |
+
if not os.path.isdir(event.src_path):
|
| 61 |
+
return
|
| 62 |
+
if event.src_path == str(self.root_directory):
|
| 63 |
+
logger.info(f"Image directory {event.src_path} modified")
|
| 64 |
+
self.image_dir_filenames = self._media_names(self.root_directory)
|
| 65 |
+
self.modified_time['image'] = datetime.utcnow()
|
| 66 |
+
elif event.src_path == str(self.output_dir):
|
| 67 |
+
logger.info(f"Output directory {event.src_path} modified")
|
| 68 |
+
self.output_dir_filenames = self._media_names(self.output_dir)
|
| 69 |
+
self.modified_time['output'] = datetime.utcnow()
|
| 70 |
+
|
| 71 |
+
def init_app(self, app):
|
| 72 |
+
if self.app is None:
|
| 73 |
+
self.app = app
|
| 74 |
+
app.thumbnail_instance = self
|
| 75 |
+
|
| 76 |
+
if not hasattr(app, "extensions"):
|
| 77 |
+
app.extensions = {}
|
| 78 |
+
|
| 79 |
+
if "thumbnail" in app.extensions:
|
| 80 |
+
raise RuntimeError("Flask-thumbnail extension already initialized")
|
| 81 |
+
|
| 82 |
+
app.extensions["thumbnail"] = self
|
| 83 |
+
|
| 84 |
+
app.config.setdefault("THUMBNAIL_MEDIA_ROOT", self._default_root_directory)
|
| 85 |
+
app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_ROOT", self._default_thumbnail_directory)
|
| 86 |
+
app.config.setdefault("THUMBNAIL_MEDIA_URL", self._default_root_url)
|
| 87 |
+
app.config.setdefault("THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url)
|
| 88 |
+
app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)
|
| 89 |
+
|
| 90 |
+
def save_to_output_directory(self, image: np.ndarray, filename: str):
|
| 91 |
+
fp = Path(filename)
|
| 92 |
+
new_name = fp.stem + f"_{int(time.time())}" + fp.suffix
|
| 93 |
+
if image.shape[2] == 3:
|
| 94 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
| 95 |
+
elif image.shape[2] == 4:
|
| 96 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)
|
| 97 |
+
|
| 98 |
+
cv2.imwrite(str(self.output_dir / new_name), image)
|
| 99 |
+
|
| 100 |
+
@property
|
| 101 |
+
def root_directory(self):
|
| 102 |
+
path = self.app.config["THUMBNAIL_MEDIA_ROOT"]
|
| 103 |
+
|
| 104 |
+
if os.path.isabs(path):
|
| 105 |
+
return path
|
| 106 |
+
else:
|
| 107 |
+
return os.path.join(self.app.root_path, path)
|
| 108 |
+
|
| 109 |
+
@property
|
| 110 |
+
def thumbnail_directory(self):
|
| 111 |
+
path = self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"]
|
| 112 |
+
|
| 113 |
+
if os.path.isabs(path):
|
| 114 |
+
return path
|
| 115 |
+
else:
|
| 116 |
+
return os.path.join(self.app.root_path, path)
|
| 117 |
+
|
| 118 |
+
@property
|
| 119 |
+
def root_url(self):
|
| 120 |
+
return self.app.config["THUMBNAIL_MEDIA_URL"]
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def media_names(self):
|
| 124 |
+
# return self.image_dir_filenames
|
| 125 |
+
return self._media_names(self.root_directory)
|
| 126 |
+
|
| 127 |
+
@property
|
| 128 |
+
def output_media_names(self):
|
| 129 |
+
return self._media_names(self.output_dir)
|
| 130 |
+
# return self.output_dir_filenames
|
| 131 |
+
|
| 132 |
+
@staticmethod
|
| 133 |
+
def _media_names(directory: Path):
|
| 134 |
+
names = sorted([it.name for it in glob_img(directory)])
|
| 135 |
+
res = []
|
| 136 |
+
for name in names:
|
| 137 |
+
path = os.path.join(directory, name)
|
| 138 |
+
img = Image.open(path)
|
| 139 |
+
res.append({"name": name, "height": img.height, "width": img.width, "ctime": os.path.getctime(path)})
|
| 140 |
+
return res
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def thumbnail_url(self):
|
| 144 |
+
return self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_URL"]
|
| 145 |
+
|
| 146 |
+
def get_thumbnail(self, directory: Path, original_filename: str, width, height, **options):
|
| 147 |
+
storage = FilesystemStorageBackend(self.app)
|
| 148 |
+
crop = options.get("crop", "fit")
|
| 149 |
+
background = options.get("background")
|
| 150 |
+
quality = options.get("quality", 90)
|
| 151 |
+
|
| 152 |
+
original_path, original_filename = os.path.split(original_filename)
|
| 153 |
+
original_filepath = os.path.join(directory, original_path, original_filename)
|
| 154 |
+
image = Image.open(BytesIO(storage.read(original_filepath)))
|
| 155 |
+
|
| 156 |
+
# keep ratio resize
|
| 157 |
+
if width is not None:
|
| 158 |
+
height = int(image.height * width / image.width)
|
| 159 |
+
else:
|
| 160 |
+
width = int(image.width * height / image.height)
|
| 161 |
+
|
| 162 |
+
thumbnail_size = (width, height)
|
| 163 |
+
|
| 164 |
+
thumbnail_filename = generate_filename(
|
| 165 |
+
original_filename, aspect_to_string(thumbnail_size), crop, background, quality
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
thumbnail_filepath = os.path.join(
|
| 169 |
+
self.thumbnail_directory, original_path, thumbnail_filename
|
| 170 |
+
)
|
| 171 |
+
thumbnail_url = os.path.join(self.thumbnail_url, original_path, thumbnail_filename)
|
| 172 |
+
|
| 173 |
+
if storage.exists(thumbnail_filepath):
|
| 174 |
+
return thumbnail_url, (width, height)
|
| 175 |
+
|
| 176 |
+
try:
|
| 177 |
+
image.load()
|
| 178 |
+
except (IOError, OSError):
|
| 179 |
+
self.app.logger.warning("Thumbnail not load image: %s", original_filepath)
|
| 180 |
+
return thumbnail_url, (width, height)
|
| 181 |
+
|
| 182 |
+
# get original image format
|
| 183 |
+
options["format"] = options.get("format", image.format)
|
| 184 |
+
|
| 185 |
+
image = self._create_thumbnail(image, thumbnail_size, crop, background=background)
|
| 186 |
+
|
| 187 |
+
raw_data = self.get_raw_data(image, **options)
|
| 188 |
+
storage.save(thumbnail_filepath, raw_data)
|
| 189 |
+
|
| 190 |
+
return thumbnail_url, (width, height)
|
| 191 |
+
|
| 192 |
+
def get_raw_data(self, image, **options):
|
| 193 |
+
data = {
|
| 194 |
+
"format": self._get_format(image, **options),
|
| 195 |
+
"quality": options.get("quality", 90),
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
_file = BytesIO()
|
| 199 |
+
image.save(_file, **data)
|
| 200 |
+
return _file.getvalue()
|
| 201 |
+
|
| 202 |
+
@staticmethod
|
| 203 |
+
def colormode(image, colormode="RGB"):
|
| 204 |
+
if colormode == "RGB" or colormode == "RGBA":
|
| 205 |
+
if image.mode == "RGBA":
|
| 206 |
+
return image
|
| 207 |
+
if image.mode == "LA":
|
| 208 |
+
return image.convert("RGBA")
|
| 209 |
+
return image.convert(colormode)
|
| 210 |
+
|
| 211 |
+
if colormode == "GRAY":
|
| 212 |
+
return image.convert("L")
|
| 213 |
+
|
| 214 |
+
return image.convert(colormode)
|
| 215 |
+
|
| 216 |
+
@staticmethod
|
| 217 |
+
def background(original_image, color=0xFF):
|
| 218 |
+
size = (max(original_image.size),) * 2
|
| 219 |
+
image = Image.new("L", size, color)
|
| 220 |
+
image.paste(
|
| 221 |
+
original_image,
|
| 222 |
+
tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))),
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
return image
|
| 226 |
+
|
| 227 |
+
def _get_format(self, image, **options):
|
| 228 |
+
if options.get("format"):
|
| 229 |
+
return options.get("format")
|
| 230 |
+
if image.format:
|
| 231 |
+
return image.format
|
| 232 |
+
|
| 233 |
+
return self.app.config["THUMBNAIL_DEFAULT_FORMAT"]
|
| 234 |
+
|
| 235 |
+
def _create_thumbnail(self, image, size, crop="fit", background=None):
|
| 236 |
+
try:
|
| 237 |
+
resample = Image.Resampling.LANCZOS
|
| 238 |
+
except AttributeError: # pylint: disable=raise-missing-from
|
| 239 |
+
resample = Image.ANTIALIAS
|
| 240 |
+
|
| 241 |
+
if crop == "fit":
|
| 242 |
+
image = ImageOps.fit(image, size, resample)
|
| 243 |
+
else:
|
| 244 |
+
image = image.copy()
|
| 245 |
+
image.thumbnail(size, resample=resample)
|
| 246 |
+
|
| 247 |
+
if background is not None:
|
| 248 |
+
image = self.background(image)
|
| 249 |
+
|
| 250 |
+
image = self.colormode(image)
|
| 251 |
+
|
| 252 |
+
return image
|
lama_cleaner/file_manager/storage_backends.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
|
| 2 |
+
import errno
|
| 3 |
+
import os
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BaseStorageBackend(ABC):
|
| 8 |
+
def __init__(self, app=None):
|
| 9 |
+
self.app = app
|
| 10 |
+
|
| 11 |
+
@abstractmethod
|
| 12 |
+
def read(self, filepath, mode="rb", **kwargs):
|
| 13 |
+
raise NotImplementedError
|
| 14 |
+
|
| 15 |
+
@abstractmethod
|
| 16 |
+
def exists(self, filepath):
|
| 17 |
+
raise NotImplementedError
|
| 18 |
+
|
| 19 |
+
@abstractmethod
|
| 20 |
+
def save(self, filepath, data):
|
| 21 |
+
raise NotImplementedError
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class FilesystemStorageBackend(BaseStorageBackend):
|
| 25 |
+
def read(self, filepath, mode="rb", **kwargs):
|
| 26 |
+
with open(filepath, mode) as f: # pylint: disable=unspecified-encoding
|
| 27 |
+
return f.read()
|
| 28 |
+
|
| 29 |
+
def exists(self, filepath):
|
| 30 |
+
return os.path.exists(filepath)
|
| 31 |
+
|
| 32 |
+
def save(self, filepath, data):
|
| 33 |
+
directory = os.path.dirname(filepath)
|
| 34 |
+
|
| 35 |
+
if not os.path.exists(directory):
|
| 36 |
+
try:
|
| 37 |
+
os.makedirs(directory)
|
| 38 |
+
except OSError as e:
|
| 39 |
+
if e.errno != errno.EEXIST:
|
| 40 |
+
raise
|
| 41 |
+
|
| 42 |
+
if not os.path.isdir(directory):
|
| 43 |
+
raise IOError("{} is not a directory".format(directory))
|
| 44 |
+
|
| 45 |
+
with open(filepath, "wb") as f:
|
| 46 |
+
f.write(data)
|
lama_cleaner/file_manager/utils.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from typing import Union
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def generate_filename(original_filename, *options):
|
| 9 |
+
name, ext = os.path.splitext(original_filename)
|
| 10 |
+
for v in options:
|
| 11 |
+
if v:
|
| 12 |
+
name += "_%s" % v
|
| 13 |
+
name += ext
|
| 14 |
+
|
| 15 |
+
return name
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def parse_size(size):
|
| 19 |
+
if isinstance(size, int):
|
| 20 |
+
# If the size parameter is a single number, assume square aspect.
|
| 21 |
+
return [size, size]
|
| 22 |
+
|
| 23 |
+
if isinstance(size, (tuple, list)):
|
| 24 |
+
if len(size) == 1:
|
| 25 |
+
# If single value tuple/list is provided, exand it to two elements
|
| 26 |
+
return size + type(size)(size)
|
| 27 |
+
return size
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
|
| 31 |
+
except ValueError:
|
| 32 |
+
raise ValueError( # pylint: disable=raise-missing-from
|
| 33 |
+
"Bad thumbnail size format. Valid format is INTxINT."
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
if len(thumbnail_size) == 1:
|
| 37 |
+
# If the size parameter only contains a single integer, assume square aspect.
|
| 38 |
+
thumbnail_size.append(thumbnail_size[0])
|
| 39 |
+
|
| 40 |
+
return thumbnail_size
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def aspect_to_string(size):
|
| 44 |
+
if isinstance(size, str):
|
| 45 |
+
return size
|
| 46 |
+
|
| 47 |
+
return "x".join(map(str, size))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def glob_img(p: Union[Path, str], recursive: bool = False):
|
| 54 |
+
p = Path(p)
|
| 55 |
+
if p.is_file() and p.suffix in IMG_SUFFIX:
|
| 56 |
+
yield p
|
| 57 |
+
else:
|
| 58 |
+
if recursive:
|
| 59 |
+
files = Path(p).glob("**/*.*")
|
| 60 |
+
else:
|
| 61 |
+
files = Path(p).glob("*.*")
|
| 62 |
+
|
| 63 |
+
for it in files:
|
| 64 |
+
if it.suffix not in IMG_SUFFIX:
|
| 65 |
+
continue
|
| 66 |
+
yield it
|
lama_cleaner/helper.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
from urllib.parse import urlparse
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image, ImageOps
|
| 11 |
+
from loguru import logger
|
| 12 |
+
from torch.hub import download_url_to_file, get_dir
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_cache_path_by_url(url):
|
| 16 |
+
parts = urlparse(url)
|
| 17 |
+
hub_dir = get_dir()
|
| 18 |
+
model_dir = os.path.join(hub_dir, "checkpoints")
|
| 19 |
+
if not os.path.isdir(model_dir):
|
| 20 |
+
os.makedirs(model_dir)
|
| 21 |
+
filename = os.path.basename(parts.path)
|
| 22 |
+
cached_file = os.path.join(model_dir, filename)
|
| 23 |
+
return cached_file
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def download_model(url):
|
| 27 |
+
cached_file = get_cache_path_by_url(url)
|
| 28 |
+
if not os.path.exists(cached_file):
|
| 29 |
+
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
| 30 |
+
hash_prefix = None
|
| 31 |
+
download_url_to_file(url, cached_file, hash_prefix, progress=True)
|
| 32 |
+
return cached_file
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def ceil_modulo(x, mod):
|
| 36 |
+
if x % mod == 0:
|
| 37 |
+
return x
|
| 38 |
+
return (x // mod + 1) * mod
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_jit_model(url_or_path, device):
|
| 42 |
+
# if os.path.exists(url_or_path):
|
| 43 |
+
# model_path = url_or_path
|
| 44 |
+
# else:
|
| 45 |
+
# model_path = download_model(url_or_path)
|
| 46 |
+
model_path = os.getcwd()
|
| 47 |
+
logger.info(f"Load model from: {model_path}")
|
| 48 |
+
try:
|
| 49 |
+
model = torch.jit.load(model_path).to(device)
|
| 50 |
+
except:
|
| 51 |
+
logger.error(
|
| 52 |
+
f"Failed to load {model_path}, delete model and restart lama-cleaner"
|
| 53 |
+
)
|
| 54 |
+
exit(-1)
|
| 55 |
+
model.eval()
|
| 56 |
+
return model
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_model(model: torch.nn.Module, url_or_path, device):
|
| 60 |
+
if os.path.exists(url_or_path):
|
| 61 |
+
model_path = url_or_path
|
| 62 |
+
else:
|
| 63 |
+
model_path = download_model(url_or_path)
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
state_dict = torch.load(model_path, map_location='cpu')
|
| 67 |
+
model.load_state_dict(state_dict, strict=True)
|
| 68 |
+
model.to(device)
|
| 69 |
+
logger.info(f"Load model from: {model_path}")
|
| 70 |
+
except:
|
| 71 |
+
logger.error(
|
| 72 |
+
f"Failed to load {model_path}, delete model and restart lama-cleaner"
|
| 73 |
+
)
|
| 74 |
+
exit(-1)
|
| 75 |
+
model.eval()
|
| 76 |
+
return model
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
| 80 |
+
data = cv2.imencode(
|
| 81 |
+
f".{ext}",
|
| 82 |
+
image_numpy,
|
| 83 |
+
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
| 84 |
+
)[1]
|
| 85 |
+
image_bytes = data.tobytes()
|
| 86 |
+
return image_bytes
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def load_img(img_bytes, gray: bool = False):
|
| 90 |
+
alpha_channel = None
|
| 91 |
+
image = Image.open(io.BytesIO(img_bytes))
|
| 92 |
+
try:
|
| 93 |
+
image = ImageOps.exif_transpose(image)
|
| 94 |
+
except:
|
| 95 |
+
pass
|
| 96 |
+
|
| 97 |
+
if gray:
|
| 98 |
+
image = image.convert('L')
|
| 99 |
+
np_img = np.array(image)
|
| 100 |
+
else:
|
| 101 |
+
if image.mode == 'RGBA':
|
| 102 |
+
np_img = np.array(image)
|
| 103 |
+
alpha_channel = np_img[:, :, -1]
|
| 104 |
+
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
|
| 105 |
+
else:
|
| 106 |
+
image = image.convert('RGB')
|
| 107 |
+
np_img = np.array(image)
|
| 108 |
+
|
| 109 |
+
return np_img, alpha_channel
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def norm_img(np_img):
|
| 113 |
+
if len(np_img.shape) == 2:
|
| 114 |
+
np_img = np_img[:, :, np.newaxis]
|
| 115 |
+
np_img = np.transpose(np_img, (2, 0, 1))
|
| 116 |
+
np_img = np_img.astype("float32") / 255
|
| 117 |
+
return np_img
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def resize_max_size(
|
| 121 |
+
np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
|
| 122 |
+
) -> np.ndarray:
|
| 123 |
+
# Resize image's longer size to size_limit if longer size larger than size_limit
|
| 124 |
+
h, w = np_img.shape[:2]
|
| 125 |
+
if max(h, w) > size_limit:
|
| 126 |
+
ratio = size_limit / max(h, w)
|
| 127 |
+
new_w = int(w * ratio + 0.5)
|
| 128 |
+
new_h = int(h * ratio + 0.5)
|
| 129 |
+
return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
|
| 130 |
+
else:
|
| 131 |
+
return np_img
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def pad_img_to_modulo(
|
| 135 |
+
img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
|
| 136 |
+
):
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
img: [H, W, C]
|
| 141 |
+
mod:
|
| 142 |
+
square: 是否为正方形
|
| 143 |
+
min_size:
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
|
| 147 |
+
"""
|
| 148 |
+
if len(img.shape) == 2:
|
| 149 |
+
img = img[:, :, np.newaxis]
|
| 150 |
+
height, width = img.shape[:2]
|
| 151 |
+
out_height = ceil_modulo(height, mod)
|
| 152 |
+
out_width = ceil_modulo(width, mod)
|
| 153 |
+
|
| 154 |
+
if min_size is not None:
|
| 155 |
+
assert min_size % mod == 0
|
| 156 |
+
out_width = max(min_size, out_width)
|
| 157 |
+
out_height = max(min_size, out_height)
|
| 158 |
+
|
| 159 |
+
if square:
|
| 160 |
+
max_size = max(out_height, out_width)
|
| 161 |
+
out_height = max_size
|
| 162 |
+
out_width = max_size
|
| 163 |
+
|
| 164 |
+
return np.pad(
|
| 165 |
+
img,
|
| 166 |
+
((0, out_height - height), (0, out_width - width), (0, 0)),
|
| 167 |
+
mode="symmetric",
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
|
| 172 |
+
"""
|
| 173 |
+
Args:
|
| 174 |
+
mask: (h, w, 1) 0~255
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
|
| 178 |
+
"""
|
| 179 |
+
height, width = mask.shape[:2]
|
| 180 |
+
_, thresh = cv2.threshold(mask, 127, 255, 0)
|
| 181 |
+
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 182 |
+
|
| 183 |
+
boxes = []
|
| 184 |
+
for cnt in contours:
|
| 185 |
+
x, y, w, h = cv2.boundingRect(cnt)
|
| 186 |
+
box = np.array([x, y, x + w, y + h]).astype(int)
|
| 187 |
+
|
| 188 |
+
box[::2] = np.clip(box[::2], 0, width)
|
| 189 |
+
box[1::2] = np.clip(box[1::2], 0, height)
|
| 190 |
+
boxes.append(box)
|
| 191 |
+
|
| 192 |
+
return boxes
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
|
| 196 |
+
"""
|
| 197 |
+
Args:
|
| 198 |
+
mask: (h, w) 0~255
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
|
| 202 |
+
"""
|
| 203 |
+
_, thresh = cv2.threshold(mask, 127, 255, 0)
|
| 204 |
+
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 205 |
+
|
| 206 |
+
max_area = 0
|
| 207 |
+
max_index = -1
|
| 208 |
+
for i, cnt in enumerate(contours):
|
| 209 |
+
area = cv2.contourArea(cnt)
|
| 210 |
+
if area > max_area:
|
| 211 |
+
max_area = area
|
| 212 |
+
max_index = i
|
| 213 |
+
|
| 214 |
+
if max_index != -1:
|
| 215 |
+
new_mask = np.zeros_like(mask)
|
| 216 |
+
return cv2.drawContours(new_mask, contours, max_index, 255, -1)
|
| 217 |
+
else:
|
| 218 |
+
return mask
|
lama_cleaner/interactive_seg.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Tuple, List
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from loguru import logger
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
|
| 11 |
+
from lama_cleaner.helper import load_jit_model
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Click(BaseModel):
|
| 15 |
+
# [y, x]
|
| 16 |
+
coords: Tuple[float, float]
|
| 17 |
+
is_positive: bool
|
| 18 |
+
indx: int
|
| 19 |
+
|
| 20 |
+
@property
|
| 21 |
+
def coords_and_indx(self):
|
| 22 |
+
return (*self.coords, self.indx)
|
| 23 |
+
|
| 24 |
+
def scale(self, x_ratio: float, y_ratio: float) -> 'Click':
|
| 25 |
+
return Click(
|
| 26 |
+
coords=(self.coords[0] * x_ratio, self.coords[1] * y_ratio),
|
| 27 |
+
is_positive=self.is_positive,
|
| 28 |
+
indx=self.indx
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ResizeTrans:
|
| 33 |
+
def __init__(self, size=480):
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.crop_height = size
|
| 36 |
+
self.crop_width = size
|
| 37 |
+
|
| 38 |
+
def transform(self, image_nd, clicks_lists):
|
| 39 |
+
assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
|
| 40 |
+
image_height, image_width = image_nd.shape[2:4]
|
| 41 |
+
self.image_height = image_height
|
| 42 |
+
self.image_width = image_width
|
| 43 |
+
image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode='bilinear', align_corners=True)
|
| 44 |
+
|
| 45 |
+
y_ratio = self.crop_height / image_height
|
| 46 |
+
x_ratio = self.crop_width / image_width
|
| 47 |
+
|
| 48 |
+
clicks_lists_resized = []
|
| 49 |
+
for clicks_list in clicks_lists:
|
| 50 |
+
clicks_list_resized = [click.scale(y_ratio, x_ratio) for click in clicks_list]
|
| 51 |
+
clicks_lists_resized.append(clicks_list_resized)
|
| 52 |
+
|
| 53 |
+
return image_nd_r, clicks_lists_resized
|
| 54 |
+
|
| 55 |
+
def inv_transform(self, prob_map):
|
| 56 |
+
new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear',
|
| 57 |
+
align_corners=True)
|
| 58 |
+
|
| 59 |
+
return new_prob_map
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ISPredictor(object):
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
model,
|
| 66 |
+
device,
|
| 67 |
+
open_kernel_size: int,
|
| 68 |
+
dilate_kernel_size: int,
|
| 69 |
+
net_clicks_limit=None,
|
| 70 |
+
zoom_in=None,
|
| 71 |
+
infer_size=384,
|
| 72 |
+
):
|
| 73 |
+
self.model = model
|
| 74 |
+
self.open_kernel_size = open_kernel_size
|
| 75 |
+
self.dilate_kernel_size = dilate_kernel_size
|
| 76 |
+
self.net_clicks_limit = net_clicks_limit
|
| 77 |
+
self.device = device
|
| 78 |
+
self.zoom_in = zoom_in
|
| 79 |
+
self.infer_size = infer_size
|
| 80 |
+
|
| 81 |
+
# self.transforms = [zoom_in] if zoom_in is not None else []
|
| 82 |
+
|
| 83 |
+
def __call__(self, input_image: torch.Tensor, clicks: List[Click], prev_mask):
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
input_image: [1, 3, H, W] [0~1]
|
| 88 |
+
clicks: List[Click]
|
| 89 |
+
prev_mask: [1, 1, H, W]
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
|
| 93 |
+
"""
|
| 94 |
+
transforms = [ResizeTrans(self.infer_size)]
|
| 95 |
+
input_image = torch.cat((input_image, prev_mask), dim=1)
|
| 96 |
+
|
| 97 |
+
# image_nd resized to infer_size
|
| 98 |
+
for t in transforms:
|
| 99 |
+
image_nd, clicks_lists = t.transform(input_image, [clicks])
|
| 100 |
+
|
| 101 |
+
# image_nd.shape = [1, 4, 256, 256]
|
| 102 |
+
# points_nd.sha[e = [1, 2, 3]
|
| 103 |
+
# clicks_lists[0][0] Click 类
|
| 104 |
+
points_nd = self.get_points_nd(clicks_lists)
|
| 105 |
+
pred_logits = self.model(image_nd, points_nd)
|
| 106 |
+
pred = torch.sigmoid(pred_logits)
|
| 107 |
+
pred = self.post_process(pred)
|
| 108 |
+
|
| 109 |
+
prediction = F.interpolate(pred, mode='bilinear', align_corners=True,
|
| 110 |
+
size=image_nd.size()[2:])
|
| 111 |
+
|
| 112 |
+
for t in reversed(transforms):
|
| 113 |
+
prediction = t.inv_transform(prediction)
|
| 114 |
+
|
| 115 |
+
# if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
|
| 116 |
+
# return self.get_prediction(clicker)
|
| 117 |
+
|
| 118 |
+
return prediction.cpu().numpy()[0, 0]
|
| 119 |
+
|
| 120 |
+
def post_process(self, pred: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
pred_mask = pred.cpu().numpy()[0][0]
|
| 122 |
+
# morph_open to remove small noise
|
| 123 |
+
kernel_size = self.open_kernel_size
|
| 124 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
| 125 |
+
pred_mask = cv2.morphologyEx(pred_mask, cv2.MORPH_OPEN, kernel, iterations=1)
|
| 126 |
+
|
| 127 |
+
# Why dilate: make region slightly larger to avoid missing some pixels, this generally works better
|
| 128 |
+
dilate_kernel_size = self.dilate_kernel_size
|
| 129 |
+
if dilate_kernel_size > 1:
|
| 130 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_DILATE, (dilate_kernel_size, dilate_kernel_size))
|
| 131 |
+
pred_mask = cv2.dilate(pred_mask, kernel, 1)
|
| 132 |
+
return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)
|
| 133 |
+
|
| 134 |
+
def get_points_nd(self, clicks_lists):
|
| 135 |
+
total_clicks = []
|
| 136 |
+
num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
|
| 137 |
+
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
|
| 138 |
+
num_max_points = max(num_pos_clicks + num_neg_clicks)
|
| 139 |
+
if self.net_clicks_limit is not None:
|
| 140 |
+
num_max_points = min(self.net_clicks_limit, num_max_points)
|
| 141 |
+
num_max_points = max(1, num_max_points)
|
| 142 |
+
|
| 143 |
+
for clicks_list in clicks_lists:
|
| 144 |
+
clicks_list = clicks_list[:self.net_clicks_limit]
|
| 145 |
+
pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
|
| 146 |
+
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
|
| 147 |
+
|
| 148 |
+
neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
|
| 149 |
+
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
|
| 150 |
+
total_clicks.append(pos_clicks + neg_clicks)
|
| 151 |
+
|
| 152 |
+
return torch.tensor(total_clicks, device=self.device)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
INTERACTIVE_SEG_MODEL_URL = os.environ.get(
|
| 156 |
+
"INTERACTIVE_SEG_MODEL_URL",
|
| 157 |
+
"https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt",
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class InteractiveSeg:
|
| 162 |
+
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
|
| 163 |
+
device = torch.device('cpu')
|
| 164 |
+
model = load_jit_model(INTERACTIVE_SEG_MODEL_URL, device).eval()
|
| 165 |
+
self.predictor = ISPredictor(model, device,
|
| 166 |
+
infer_size=infer_size,
|
| 167 |
+
open_kernel_size=open_kernel_size,
|
| 168 |
+
dilate_kernel_size=dilate_kernel_size)
|
| 169 |
+
|
| 170 |
+
def __call__(self, image, clicks, prev_mask=None):
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
image: [H,W,C] RGB
|
| 175 |
+
clicks:
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
|
| 179 |
+
"""
|
| 180 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
| 181 |
+
image = torch.from_numpy((image / 255).transpose(2, 0, 1)).unsqueeze(0).float()
|
| 182 |
+
if prev_mask is None:
|
| 183 |
+
mask = torch.zeros_like(image[:, :1, :, :])
|
| 184 |
+
else:
|
| 185 |
+
logger.info('InteractiveSeg run with prev_mask')
|
| 186 |
+
mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()
|
| 187 |
+
|
| 188 |
+
pred_probs = self.predictor(image, clicks, mask)
|
| 189 |
+
pred_mask = pred_probs > 0.5
|
| 190 |
+
pred_mask = (pred_mask * 255).astype(np.uint8)
|
| 191 |
+
|
| 192 |
+
# Find largest contour
|
| 193 |
+
# pred_mask = only_keep_largest_contour(pred_mask)
|
| 194 |
+
# To simplify frontend process, add mask brush color here
|
| 195 |
+
fg = pred_mask == 255
|
| 196 |
+
bg = pred_mask != 255
|
| 197 |
+
pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_GRAY2BGRA)
|
| 198 |
+
# frontend brush color "ffcc00bb"
|
| 199 |
+
pred_mask[bg] = 0
|
| 200 |
+
pred_mask[fg] = [255, 203, 0, int(255 * 0.73)]
|
| 201 |
+
pred_mask = cv2.cvtColor(pred_mask, cv2.COLOR_BGRA2RGBA)
|
| 202 |
+
return pred_mask
|
lama_cleaner/model/__init__.py
ADDED
|
File without changes
|
lama_cleaner/model/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (172 Bytes). View file
|
|
|
lama_cleaner/model/__pycache__/base.cpython-38.pyc
ADDED
|
Binary file (6.7 kB). View file
|
|
|
lama_cleaner/model/__pycache__/ddim_sampler.cpython-38.pyc
ADDED
|
Binary file (4.74 kB). View file
|
|
|
lama_cleaner/model/__pycache__/fcf.cpython-38.pyc
ADDED
|
Binary file (33.4 kB). View file
|
|
|
lama_cleaner/model/__pycache__/lama.cpython-38.pyc
ADDED
|
Binary file (2.12 kB). View file
|
|
|
lama_cleaner/model/__pycache__/ldm.cpython-38.pyc
ADDED
|
Binary file (7.79 kB). View file
|
|
|
lama_cleaner/model/__pycache__/manga.cpython-38.pyc
ADDED
|
Binary file (2.72 kB). View file
|
|
|
lama_cleaner/model/__pycache__/mat.cpython-38.pyc
ADDED
|
Binary file (38.8 kB). View file
|
|
|
lama_cleaner/model/__pycache__/opencv2.cpython-38.pyc
ADDED
|
Binary file (1.13 kB). View file
|
|
|
lama_cleaner/model/__pycache__/paint_by_example.cpython-38.pyc
ADDED
|
Binary file (4.25 kB). View file
|
|
|
lama_cleaner/model/__pycache__/plms_sampler.cpython-38.pyc
ADDED
|
Binary file (7.09 kB). View file
|
|
|
lama_cleaner/model/__pycache__/sd.cpython-38.pyc
ADDED
|
Binary file (6.26 kB). View file
|
|
|
lama_cleaner/model/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (26.3 kB). View file
|
|
|
lama_cleaner/model/__pycache__/zits.cpython-38.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
lama_cleaner/model/base.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
|
| 10 |
+
from lama_cleaner.schema import Config, HDStrategy
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class InpaintModel:
|
| 14 |
+
min_size: Optional[int] = None
|
| 15 |
+
pad_mod = 8
|
| 16 |
+
pad_to_square = False
|
| 17 |
+
|
| 18 |
+
def __init__(self, device, **kwargs):
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
device:
|
| 23 |
+
"""
|
| 24 |
+
self.device = device
|
| 25 |
+
self.init_model(device, **kwargs)
|
| 26 |
+
|
| 27 |
+
@abc.abstractmethod
|
| 28 |
+
def init_model(self, device, **kwargs):
|
| 29 |
+
...
|
| 30 |
+
|
| 31 |
+
@staticmethod
|
| 32 |
+
@abc.abstractmethod
|
| 33 |
+
def is_downloaded() -> bool:
|
| 34 |
+
...
|
| 35 |
+
|
| 36 |
+
@abc.abstractmethod
|
| 37 |
+
def forward(self, image, mask, config: Config):
|
| 38 |
+
"""Input images and output images have same size
|
| 39 |
+
images: [H, W, C] RGB
|
| 40 |
+
masks: [H, W, 1] 255 为 masks 区域
|
| 41 |
+
return: BGR IMAGE
|
| 42 |
+
"""
|
| 43 |
+
...
|
| 44 |
+
|
| 45 |
+
def _pad_forward(self, image, mask, config: Config):
|
| 46 |
+
origin_height, origin_width = image.shape[:2]
|
| 47 |
+
pad_image = pad_img_to_modulo(
|
| 48 |
+
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
| 49 |
+
)
|
| 50 |
+
pad_mask = pad_img_to_modulo(
|
| 51 |
+
mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
logger.info(f"final forward pad size: {pad_image.shape}")
|
| 55 |
+
|
| 56 |
+
result = self.forward(pad_image, pad_mask, config)
|
| 57 |
+
result = result[0:origin_height, 0:origin_width, :]
|
| 58 |
+
|
| 59 |
+
result, image, mask = self.forward_post_process(result, image, mask, config)
|
| 60 |
+
|
| 61 |
+
mask = mask[:, :, np.newaxis]
|
| 62 |
+
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
| 63 |
+
return result
|
| 64 |
+
|
| 65 |
+
def forward_post_process(self, result, image, mask, config):
|
| 66 |
+
return result, image, mask
|
| 67 |
+
|
| 68 |
+
@torch.no_grad()
|
| 69 |
+
def __call__(self, image, mask, config: Config):
|
| 70 |
+
"""
|
| 71 |
+
images: [H, W, C] RGB, not normalized
|
| 72 |
+
masks: [H, W]
|
| 73 |
+
return: BGR IMAGE
|
| 74 |
+
"""
|
| 75 |
+
inpaint_result = None
|
| 76 |
+
logger.info(f"hd_strategy: {config.hd_strategy}")
|
| 77 |
+
if config.hd_strategy == HDStrategy.CROP:
|
| 78 |
+
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
| 79 |
+
logger.info(f"Run crop strategy")
|
| 80 |
+
boxes = boxes_from_mask(mask)
|
| 81 |
+
crop_result = []
|
| 82 |
+
for box in boxes:
|
| 83 |
+
crop_image, crop_box = self._run_box(image, mask, box, config)
|
| 84 |
+
crop_result.append((crop_image, crop_box))
|
| 85 |
+
|
| 86 |
+
inpaint_result = image[:, :, ::-1]
|
| 87 |
+
for crop_image, crop_box in crop_result:
|
| 88 |
+
x1, y1, x2, y2 = crop_box
|
| 89 |
+
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
| 90 |
+
|
| 91 |
+
elif config.hd_strategy == HDStrategy.RESIZE:
|
| 92 |
+
if max(image.shape) > config.hd_strategy_resize_limit:
|
| 93 |
+
origin_size = image.shape[:2]
|
| 94 |
+
downsize_image = resize_max_size(
|
| 95 |
+
image, size_limit=config.hd_strategy_resize_limit
|
| 96 |
+
)
|
| 97 |
+
downsize_mask = resize_max_size(
|
| 98 |
+
mask, size_limit=config.hd_strategy_resize_limit
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
logger.info(
|
| 102 |
+
f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
|
| 103 |
+
)
|
| 104 |
+
inpaint_result = self._pad_forward(
|
| 105 |
+
downsize_image, downsize_mask, config
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# only paste masked area result
|
| 109 |
+
inpaint_result = cv2.resize(
|
| 110 |
+
inpaint_result,
|
| 111 |
+
(origin_size[1], origin_size[0]),
|
| 112 |
+
interpolation=cv2.INTER_CUBIC,
|
| 113 |
+
)
|
| 114 |
+
original_pixel_indices = mask < 127
|
| 115 |
+
inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
| 116 |
+
original_pixel_indices
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
if inpaint_result is None:
|
| 120 |
+
inpaint_result = self._pad_forward(image, mask, config)
|
| 121 |
+
|
| 122 |
+
return inpaint_result
|
| 123 |
+
|
| 124 |
+
def _crop_box(self, image, mask, box, config: Config):
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
image: [H, W, C] RGB
|
| 129 |
+
mask: [H, W, 1]
|
| 130 |
+
box: [left,top,right,bottom]
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
BGR IMAGE, (l, r, r, b)
|
| 134 |
+
"""
|
| 135 |
+
box_h = box[3] - box[1]
|
| 136 |
+
box_w = box[2] - box[0]
|
| 137 |
+
cx = (box[0] + box[2]) // 2
|
| 138 |
+
cy = (box[1] + box[3]) // 2
|
| 139 |
+
img_h, img_w = image.shape[:2]
|
| 140 |
+
|
| 141 |
+
w = box_w + config.hd_strategy_crop_margin * 2
|
| 142 |
+
h = box_h + config.hd_strategy_crop_margin * 2
|
| 143 |
+
|
| 144 |
+
_l = cx - w // 2
|
| 145 |
+
_r = cx + w // 2
|
| 146 |
+
_t = cy - h // 2
|
| 147 |
+
_b = cy + h // 2
|
| 148 |
+
|
| 149 |
+
l = max(_l, 0)
|
| 150 |
+
r = min(_r, img_w)
|
| 151 |
+
t = max(_t, 0)
|
| 152 |
+
b = min(_b, img_h)
|
| 153 |
+
|
| 154 |
+
# try to get more context when crop around image edge
|
| 155 |
+
if _l < 0:
|
| 156 |
+
r += abs(_l)
|
| 157 |
+
if _r > img_w:
|
| 158 |
+
l -= _r - img_w
|
| 159 |
+
if _t < 0:
|
| 160 |
+
b += abs(_t)
|
| 161 |
+
if _b > img_h:
|
| 162 |
+
t -= _b - img_h
|
| 163 |
+
|
| 164 |
+
l = max(l, 0)
|
| 165 |
+
r = min(r, img_w)
|
| 166 |
+
t = max(t, 0)
|
| 167 |
+
b = min(b, img_h)
|
| 168 |
+
|
| 169 |
+
crop_img = image[t:b, l:r, :]
|
| 170 |
+
crop_mask = mask[t:b, l:r]
|
| 171 |
+
|
| 172 |
+
logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
|
| 173 |
+
|
| 174 |
+
return crop_img, crop_mask, [l, t, r, b]
|
| 175 |
+
|
| 176 |
+
def _calculate_cdf(self, histogram):
|
| 177 |
+
cdf = histogram.cumsum()
|
| 178 |
+
normalized_cdf = cdf / float(cdf.max())
|
| 179 |
+
return normalized_cdf
|
| 180 |
+
|
| 181 |
+
def _calculate_lookup(self, source_cdf, reference_cdf):
|
| 182 |
+
lookup_table = np.zeros(256)
|
| 183 |
+
lookup_val = 0
|
| 184 |
+
for source_index, source_val in enumerate(source_cdf):
|
| 185 |
+
for reference_index, reference_val in enumerate(reference_cdf):
|
| 186 |
+
if reference_val >= source_val:
|
| 187 |
+
lookup_val = reference_index
|
| 188 |
+
break
|
| 189 |
+
lookup_table[source_index] = lookup_val
|
| 190 |
+
return lookup_table
|
| 191 |
+
|
| 192 |
+
def _match_histograms(self, source, reference, mask):
|
| 193 |
+
transformed_channels = []
|
| 194 |
+
for channel in range(source.shape[-1]):
|
| 195 |
+
source_channel = source[:, :, channel]
|
| 196 |
+
reference_channel = reference[:, :, channel]
|
| 197 |
+
|
| 198 |
+
# only calculate histograms for non-masked parts
|
| 199 |
+
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
|
| 200 |
+
reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256])
|
| 201 |
+
|
| 202 |
+
source_cdf = self._calculate_cdf(source_histogram)
|
| 203 |
+
reference_cdf = self._calculate_cdf(reference_histogram)
|
| 204 |
+
|
| 205 |
+
lookup = self._calculate_lookup(source_cdf, reference_cdf)
|
| 206 |
+
|
| 207 |
+
transformed_channels.append(cv2.LUT(source_channel, lookup))
|
| 208 |
+
|
| 209 |
+
result = cv2.merge(transformed_channels)
|
| 210 |
+
result = cv2.convertScaleAbs(result)
|
| 211 |
+
|
| 212 |
+
return result
|
| 213 |
+
|
| 214 |
+
def _apply_cropper(self, image, mask, config: Config):
|
| 215 |
+
img_h, img_w = image.shape[:2]
|
| 216 |
+
l, t, w, h = (
|
| 217 |
+
config.croper_x,
|
| 218 |
+
config.croper_y,
|
| 219 |
+
config.croper_width,
|
| 220 |
+
config.croper_height,
|
| 221 |
+
)
|
| 222 |
+
r = l + w
|
| 223 |
+
b = t + h
|
| 224 |
+
|
| 225 |
+
l = max(l, 0)
|
| 226 |
+
r = min(r, img_w)
|
| 227 |
+
t = max(t, 0)
|
| 228 |
+
b = min(b, img_h)
|
| 229 |
+
|
| 230 |
+
crop_img = image[t:b, l:r, :]
|
| 231 |
+
crop_mask = mask[t:b, l:r]
|
| 232 |
+
return crop_img, crop_mask, (l, t, r, b)
|
| 233 |
+
|
| 234 |
+
def _run_box(self, image, mask, box, config: Config):
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
image: [H, W, C] RGB
|
| 239 |
+
mask: [H, W, 1]
|
| 240 |
+
box: [left,top,right,bottom]
|
| 241 |
+
|
| 242 |
+
Returns:
|
| 243 |
+
BGR IMAGE
|
| 244 |
+
"""
|
| 245 |
+
crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
|
| 246 |
+
|
| 247 |
+
return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
|
lama_cleaner/model/ddim_sampler.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DDIMSampler(object):
|
| 10 |
+
def __init__(self, model, schedule="linear"):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.model = model
|
| 13 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
| 14 |
+
self.schedule = schedule
|
| 15 |
+
|
| 16 |
+
def register_buffer(self, name, attr):
|
| 17 |
+
setattr(self, name, attr)
|
| 18 |
+
|
| 19 |
+
def make_schedule(
|
| 20 |
+
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
| 21 |
+
):
|
| 22 |
+
self.ddim_timesteps = make_ddim_timesteps(
|
| 23 |
+
ddim_discr_method=ddim_discretize,
|
| 24 |
+
num_ddim_timesteps=ddim_num_steps,
|
| 25 |
+
# array([1])
|
| 26 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
| 27 |
+
verbose=verbose,
|
| 28 |
+
)
|
| 29 |
+
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
|
| 30 |
+
assert (
|
| 31 |
+
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
| 32 |
+
), "alphas have to be defined for each timestep"
|
| 33 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
| 34 |
+
|
| 35 |
+
self.register_buffer("betas", to_torch(self.model.betas))
|
| 36 |
+
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
| 37 |
+
self.register_buffer(
|
| 38 |
+
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 42 |
+
self.register_buffer(
|
| 43 |
+
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
| 44 |
+
)
|
| 45 |
+
self.register_buffer(
|
| 46 |
+
"sqrt_one_minus_alphas_cumprod",
|
| 47 |
+
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
| 48 |
+
)
|
| 49 |
+
self.register_buffer(
|
| 50 |
+
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
| 51 |
+
)
|
| 52 |
+
self.register_buffer(
|
| 53 |
+
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
| 54 |
+
)
|
| 55 |
+
self.register_buffer(
|
| 56 |
+
"sqrt_recipm1_alphas_cumprod",
|
| 57 |
+
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# ddim sampling parameters
|
| 61 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
| 62 |
+
alphacums=alphas_cumprod.cpu(),
|
| 63 |
+
ddim_timesteps=self.ddim_timesteps,
|
| 64 |
+
eta=ddim_eta,
|
| 65 |
+
verbose=verbose,
|
| 66 |
+
)
|
| 67 |
+
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
| 68 |
+
self.register_buffer("ddim_alphas", ddim_alphas)
|
| 69 |
+
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
| 70 |
+
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
| 71 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
| 72 |
+
(1 - self.alphas_cumprod_prev)
|
| 73 |
+
/ (1 - self.alphas_cumprod)
|
| 74 |
+
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
| 75 |
+
)
|
| 76 |
+
self.register_buffer(
|
| 77 |
+
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
@torch.no_grad()
|
| 81 |
+
def sample(self, steps, conditioning, batch_size, shape):
|
| 82 |
+
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
|
| 83 |
+
# sampling
|
| 84 |
+
C, H, W = shape
|
| 85 |
+
size = (batch_size, C, H, W)
|
| 86 |
+
|
| 87 |
+
# samples: 1,3,128,128
|
| 88 |
+
return self.ddim_sampling(
|
| 89 |
+
conditioning,
|
| 90 |
+
size,
|
| 91 |
+
quantize_denoised=False,
|
| 92 |
+
ddim_use_original_steps=False,
|
| 93 |
+
noise_dropout=0,
|
| 94 |
+
temperature=1.0,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
@torch.no_grad()
|
| 98 |
+
def ddim_sampling(
|
| 99 |
+
self,
|
| 100 |
+
cond,
|
| 101 |
+
shape,
|
| 102 |
+
ddim_use_original_steps=False,
|
| 103 |
+
quantize_denoised=False,
|
| 104 |
+
temperature=1.0,
|
| 105 |
+
noise_dropout=0.0,
|
| 106 |
+
):
|
| 107 |
+
device = self.model.betas.device
|
| 108 |
+
b = shape[0]
|
| 109 |
+
img = torch.randn(shape, device=device, dtype=cond.dtype)
|
| 110 |
+
timesteps = (
|
| 111 |
+
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
time_range = (
|
| 115 |
+
reversed(range(0, timesteps))
|
| 116 |
+
if ddim_use_original_steps
|
| 117 |
+
else np.flip(timesteps)
|
| 118 |
+
)
|
| 119 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
| 120 |
+
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 121 |
+
|
| 122 |
+
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
| 123 |
+
|
| 124 |
+
for i, step in enumerate(iterator):
|
| 125 |
+
index = total_steps - i - 1
|
| 126 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
| 127 |
+
|
| 128 |
+
outs = self.p_sample_ddim(
|
| 129 |
+
img,
|
| 130 |
+
cond,
|
| 131 |
+
ts,
|
| 132 |
+
index=index,
|
| 133 |
+
use_original_steps=ddim_use_original_steps,
|
| 134 |
+
quantize_denoised=quantize_denoised,
|
| 135 |
+
temperature=temperature,
|
| 136 |
+
noise_dropout=noise_dropout,
|
| 137 |
+
)
|
| 138 |
+
img, _ = outs
|
| 139 |
+
|
| 140 |
+
return img
|
| 141 |
+
|
| 142 |
+
@torch.no_grad()
|
| 143 |
+
def p_sample_ddim(
|
| 144 |
+
self,
|
| 145 |
+
x,
|
| 146 |
+
c,
|
| 147 |
+
t,
|
| 148 |
+
index,
|
| 149 |
+
repeat_noise=False,
|
| 150 |
+
use_original_steps=False,
|
| 151 |
+
quantize_denoised=False,
|
| 152 |
+
temperature=1.0,
|
| 153 |
+
noise_dropout=0.0,
|
| 154 |
+
):
|
| 155 |
+
b, *_, device = *x.shape, x.device
|
| 156 |
+
e_t = self.model.apply_model(x, t, c)
|
| 157 |
+
|
| 158 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
| 159 |
+
alphas_prev = (
|
| 160 |
+
self.model.alphas_cumprod_prev
|
| 161 |
+
if use_original_steps
|
| 162 |
+
else self.ddim_alphas_prev
|
| 163 |
+
)
|
| 164 |
+
sqrt_one_minus_alphas = (
|
| 165 |
+
self.model.sqrt_one_minus_alphas_cumprod
|
| 166 |
+
if use_original_steps
|
| 167 |
+
else self.ddim_sqrt_one_minus_alphas
|
| 168 |
+
)
|
| 169 |
+
sigmas = (
|
| 170 |
+
self.model.ddim_sigmas_for_original_num_steps
|
| 171 |
+
if use_original_steps
|
| 172 |
+
else self.ddim_sigmas
|
| 173 |
+
)
|
| 174 |
+
# select parameters corresponding to the currently considered timestep
|
| 175 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
| 176 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
| 177 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
| 178 |
+
sqrt_one_minus_at = torch.full(
|
| 179 |
+
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# current prediction for x_0
|
| 183 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
| 184 |
+
if quantize_denoised: # 没用
|
| 185 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
| 186 |
+
# direction pointing to x_t
|
| 187 |
+
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
| 188 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
| 189 |
+
if noise_dropout > 0.0: # 没用
|
| 190 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
| 191 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
| 192 |
+
return x_prev, pred_x0
|
lama_cleaner/model/fcf.py
ADDED
|
@@ -0,0 +1,1212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.fft as fft
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import conv2d, nn
|
| 10 |
+
|
| 11 |
+
from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img, boxes_from_mask, resize_max_size
|
| 12 |
+
from lama_cleaner.model.base import InpaintModel
|
| 13 |
+
from lama_cleaner.model.utils import setup_filter, _parse_scaling, _parse_padding, Conv2dLayer, FullyConnectedLayer, \
|
| 14 |
+
MinibatchStdLayer, activation_funcs, conv2d_resample, bias_act, upsample2d, normalize_2nd_moment, downsample2d
|
| 15 |
+
from lama_cleaner.schema import Config
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
| 19 |
+
assert isinstance(x, torch.Tensor)
|
| 20 |
+
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
| 24 |
+
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
| 25 |
+
"""
|
| 26 |
+
# Validate arguments.
|
| 27 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
| 28 |
+
if f is None:
|
| 29 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
| 30 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
| 31 |
+
assert f.dtype == torch.float32 and not f.requires_grad
|
| 32 |
+
batch_size, num_channels, in_height, in_width = x.shape
|
| 33 |
+
upx, upy = _parse_scaling(up)
|
| 34 |
+
downx, downy = _parse_scaling(down)
|
| 35 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
| 36 |
+
|
| 37 |
+
# Upsample by inserting zeros.
|
| 38 |
+
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
| 39 |
+
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
| 40 |
+
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
| 41 |
+
|
| 42 |
+
# Pad or crop.
|
| 43 |
+
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
| 44 |
+
x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)]
|
| 45 |
+
|
| 46 |
+
# Setup filter.
|
| 47 |
+
f = f * (gain ** (f.ndim / 2))
|
| 48 |
+
f = f.to(x.dtype)
|
| 49 |
+
if not flip_filter:
|
| 50 |
+
f = f.flip(list(range(f.ndim)))
|
| 51 |
+
|
| 52 |
+
# Convolve with the filter.
|
| 53 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
| 54 |
+
if f.ndim == 4:
|
| 55 |
+
x = conv2d(input=x, weight=f, groups=num_channels)
|
| 56 |
+
else:
|
| 57 |
+
x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
| 58 |
+
x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
| 59 |
+
|
| 60 |
+
# Downsample by throwing away pixels.
|
| 61 |
+
x = x[:, :, ::downy, ::downx]
|
| 62 |
+
return x
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class EncoderEpilogue(torch.nn.Module):
|
| 66 |
+
def __init__(self,
|
| 67 |
+
in_channels, # Number of input channels.
|
| 68 |
+
cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
|
| 69 |
+
z_dim, # Output Latent (Z) dimensionality.
|
| 70 |
+
resolution, # Resolution of this block.
|
| 71 |
+
img_channels, # Number of input color channels.
|
| 72 |
+
architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'.
|
| 73 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
| 74 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
| 75 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
| 76 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
| 77 |
+
):
|
| 78 |
+
assert architecture in ['orig', 'skip', 'resnet']
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.in_channels = in_channels
|
| 81 |
+
self.cmap_dim = cmap_dim
|
| 82 |
+
self.resolution = resolution
|
| 83 |
+
self.img_channels = img_channels
|
| 84 |
+
self.architecture = architecture
|
| 85 |
+
|
| 86 |
+
if architecture == 'skip':
|
| 87 |
+
self.fromrgb = Conv2dLayer(self.img_channels, in_channels, kernel_size=1, activation=activation)
|
| 88 |
+
self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size,
|
| 89 |
+
num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
|
| 90 |
+
self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation,
|
| 91 |
+
conv_clamp=conv_clamp)
|
| 92 |
+
self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), z_dim, activation=activation)
|
| 93 |
+
self.dropout = torch.nn.Dropout(p=0.5)
|
| 94 |
+
|
| 95 |
+
def forward(self, x, cmap, force_fp32=False):
|
| 96 |
+
_ = force_fp32 # unused
|
| 97 |
+
dtype = torch.float32
|
| 98 |
+
memory_format = torch.contiguous_format
|
| 99 |
+
|
| 100 |
+
# FromRGB.
|
| 101 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
| 102 |
+
|
| 103 |
+
# Main layers.
|
| 104 |
+
if self.mbstd is not None:
|
| 105 |
+
x = self.mbstd(x)
|
| 106 |
+
const_e = self.conv(x)
|
| 107 |
+
x = self.fc(const_e.flatten(1))
|
| 108 |
+
x = self.dropout(x)
|
| 109 |
+
|
| 110 |
+
# Conditioning.
|
| 111 |
+
if self.cmap_dim > 0:
|
| 112 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
| 113 |
+
|
| 114 |
+
assert x.dtype == dtype
|
| 115 |
+
return x, const_e
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class EncoderBlock(torch.nn.Module):
|
| 119 |
+
def __init__(self,
|
| 120 |
+
in_channels, # Number of input channels, 0 = first block.
|
| 121 |
+
tmp_channels, # Number of intermediate channels.
|
| 122 |
+
out_channels, # Number of output channels.
|
| 123 |
+
resolution, # Resolution of this block.
|
| 124 |
+
img_channels, # Number of input color channels.
|
| 125 |
+
first_layer_idx, # Index of the first layer.
|
| 126 |
+
architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
|
| 127 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
| 128 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
| 129 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
| 130 |
+
use_fp16=False, # Use FP16 for this block?
|
| 131 |
+
fp16_channels_last=False, # Use channels-last memory format with FP16?
|
| 132 |
+
freeze_layers=0, # Freeze-D: Number of layers to freeze.
|
| 133 |
+
):
|
| 134 |
+
assert in_channels in [0, tmp_channels]
|
| 135 |
+
assert architecture in ['orig', 'skip', 'resnet']
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.in_channels = in_channels
|
| 138 |
+
self.resolution = resolution
|
| 139 |
+
self.img_channels = img_channels + 1
|
| 140 |
+
self.first_layer_idx = first_layer_idx
|
| 141 |
+
self.architecture = architecture
|
| 142 |
+
self.use_fp16 = use_fp16
|
| 143 |
+
self.channels_last = (use_fp16 and fp16_channels_last)
|
| 144 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
| 145 |
+
|
| 146 |
+
self.num_layers = 0
|
| 147 |
+
|
| 148 |
+
def trainable_gen():
|
| 149 |
+
while True:
|
| 150 |
+
layer_idx = self.first_layer_idx + self.num_layers
|
| 151 |
+
trainable = (layer_idx >= freeze_layers)
|
| 152 |
+
self.num_layers += 1
|
| 153 |
+
yield trainable
|
| 154 |
+
|
| 155 |
+
trainable_iter = trainable_gen()
|
| 156 |
+
|
| 157 |
+
if in_channels == 0:
|
| 158 |
+
self.fromrgb = Conv2dLayer(self.img_channels, tmp_channels, kernel_size=1, activation=activation,
|
| 159 |
+
trainable=next(trainable_iter), conv_clamp=conv_clamp,
|
| 160 |
+
channels_last=self.channels_last)
|
| 161 |
+
|
| 162 |
+
self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
|
| 163 |
+
trainable=next(trainable_iter), conv_clamp=conv_clamp,
|
| 164 |
+
channels_last=self.channels_last)
|
| 165 |
+
|
| 166 |
+
self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
|
| 167 |
+
trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp,
|
| 168 |
+
channels_last=self.channels_last)
|
| 169 |
+
|
| 170 |
+
if architecture == 'resnet':
|
| 171 |
+
self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
|
| 172 |
+
trainable=next(trainable_iter), resample_filter=resample_filter,
|
| 173 |
+
channels_last=self.channels_last)
|
| 174 |
+
|
| 175 |
+
def forward(self, x, img, force_fp32=False):
|
| 176 |
+
# dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
| 177 |
+
dtype = torch.float32
|
| 178 |
+
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
| 179 |
+
|
| 180 |
+
# Input.
|
| 181 |
+
if x is not None:
|
| 182 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
| 183 |
+
|
| 184 |
+
# FromRGB.
|
| 185 |
+
if self.in_channels == 0:
|
| 186 |
+
img = img.to(dtype=dtype, memory_format=memory_format)
|
| 187 |
+
y = self.fromrgb(img)
|
| 188 |
+
x = x + y if x is not None else y
|
| 189 |
+
img = downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
|
| 190 |
+
|
| 191 |
+
# Main layers.
|
| 192 |
+
if self.architecture == 'resnet':
|
| 193 |
+
y = self.skip(x, gain=np.sqrt(0.5))
|
| 194 |
+
x = self.conv0(x)
|
| 195 |
+
feat = x.clone()
|
| 196 |
+
x = self.conv1(x, gain=np.sqrt(0.5))
|
| 197 |
+
x = y.add_(x)
|
| 198 |
+
else:
|
| 199 |
+
x = self.conv0(x)
|
| 200 |
+
feat = x.clone()
|
| 201 |
+
x = self.conv1(x)
|
| 202 |
+
|
| 203 |
+
assert x.dtype == dtype
|
| 204 |
+
return x, img, feat
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class EncoderNetwork(torch.nn.Module):
|
| 208 |
+
def __init__(self,
|
| 209 |
+
c_dim, # Conditioning label (C) dimensionality.
|
| 210 |
+
z_dim, # Input latent (Z) dimensionality.
|
| 211 |
+
img_resolution, # Input resolution.
|
| 212 |
+
img_channels, # Number of input color channels.
|
| 213 |
+
architecture='orig', # Architecture: 'orig', 'skip', 'resnet'.
|
| 214 |
+
channel_base=16384, # Overall multiplier for the number of channels.
|
| 215 |
+
channel_max=512, # Maximum number of channels in any layer.
|
| 216 |
+
num_fp16_res=0, # Use FP16 for the N highest resolutions.
|
| 217 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
| 218 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
| 219 |
+
block_kwargs={}, # Arguments for DiscriminatorBlock.
|
| 220 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
| 221 |
+
epilogue_kwargs={}, # Arguments for EncoderEpilogue.
|
| 222 |
+
):
|
| 223 |
+
super().__init__()
|
| 224 |
+
self.c_dim = c_dim
|
| 225 |
+
self.z_dim = z_dim
|
| 226 |
+
self.img_resolution = img_resolution
|
| 227 |
+
self.img_resolution_log2 = int(np.log2(img_resolution))
|
| 228 |
+
self.img_channels = img_channels
|
| 229 |
+
self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
|
| 230 |
+
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
|
| 231 |
+
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
|
| 232 |
+
|
| 233 |
+
if cmap_dim is None:
|
| 234 |
+
cmap_dim = channels_dict[4]
|
| 235 |
+
if c_dim == 0:
|
| 236 |
+
cmap_dim = 0
|
| 237 |
+
|
| 238 |
+
common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
|
| 239 |
+
cur_layer_idx = 0
|
| 240 |
+
for res in self.block_resolutions:
|
| 241 |
+
in_channels = channels_dict[res] if res < img_resolution else 0
|
| 242 |
+
tmp_channels = channels_dict[res]
|
| 243 |
+
out_channels = channels_dict[res // 2]
|
| 244 |
+
use_fp16 = (res >= fp16_resolution)
|
| 245 |
+
use_fp16 = False
|
| 246 |
+
block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res,
|
| 247 |
+
first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
|
| 248 |
+
setattr(self, f'b{res}', block)
|
| 249 |
+
cur_layer_idx += block.num_layers
|
| 250 |
+
if c_dim > 0:
|
| 251 |
+
self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None,
|
| 252 |
+
**mapping_kwargs)
|
| 253 |
+
self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs,
|
| 254 |
+
**common_kwargs)
|
| 255 |
+
|
| 256 |
+
def forward(self, img, c, **block_kwargs):
|
| 257 |
+
x = None
|
| 258 |
+
feats = {}
|
| 259 |
+
for res in self.block_resolutions:
|
| 260 |
+
block = getattr(self, f'b{res}')
|
| 261 |
+
x, img, feat = block(x, img, **block_kwargs)
|
| 262 |
+
feats[res] = feat
|
| 263 |
+
|
| 264 |
+
cmap = None
|
| 265 |
+
if self.c_dim > 0:
|
| 266 |
+
cmap = self.mapping(None, c)
|
| 267 |
+
x, const_e = self.b4(x, cmap)
|
| 268 |
+
feats[4] = const_e
|
| 269 |
+
|
| 270 |
+
B, _ = x.shape
|
| 271 |
+
z = torch.zeros((B, self.z_dim), requires_grad=False, dtype=x.dtype,
|
| 272 |
+
device=x.device) ## Noise for Co-Modulation
|
| 273 |
+
return x, z, feats
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def fma(a, b, c): # => a * b + c
|
| 277 |
+
return _FusedMultiplyAdd.apply(a, b, c)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
| 281 |
+
@staticmethod
|
| 282 |
+
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
| 283 |
+
out = torch.addcmul(c, a, b)
|
| 284 |
+
ctx.save_for_backward(a, b)
|
| 285 |
+
ctx.c_shape = c.shape
|
| 286 |
+
return out
|
| 287 |
+
|
| 288 |
+
@staticmethod
|
| 289 |
+
def backward(ctx, dout): # pylint: disable=arguments-differ
|
| 290 |
+
a, b = ctx.saved_tensors
|
| 291 |
+
c_shape = ctx.c_shape
|
| 292 |
+
da = None
|
| 293 |
+
db = None
|
| 294 |
+
dc = None
|
| 295 |
+
|
| 296 |
+
if ctx.needs_input_grad[0]:
|
| 297 |
+
da = _unbroadcast(dout * b, a.shape)
|
| 298 |
+
|
| 299 |
+
if ctx.needs_input_grad[1]:
|
| 300 |
+
db = _unbroadcast(dout * a, b.shape)
|
| 301 |
+
|
| 302 |
+
if ctx.needs_input_grad[2]:
|
| 303 |
+
dc = _unbroadcast(dout, c_shape)
|
| 304 |
+
|
| 305 |
+
return da, db, dc
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _unbroadcast(x, shape):
|
| 309 |
+
extra_dims = x.ndim - len(shape)
|
| 310 |
+
assert extra_dims >= 0
|
| 311 |
+
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
| 312 |
+
if len(dim):
|
| 313 |
+
x = x.sum(dim=dim, keepdim=True)
|
| 314 |
+
if extra_dims:
|
| 315 |
+
x = x.reshape(-1, *x.shape[extra_dims + 1:])
|
| 316 |
+
assert x.shape == shape
|
| 317 |
+
return x
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def modulated_conv2d(
|
| 321 |
+
x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
|
| 322 |
+
weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
|
| 323 |
+
styles, # Modulation coefficients of shape [batch_size, in_channels].
|
| 324 |
+
noise=None, # Optional noise tensor to add to the output activations.
|
| 325 |
+
up=1, # Integer upsampling factor.
|
| 326 |
+
down=1, # Integer downsampling factor.
|
| 327 |
+
padding=0, # Padding with respect to the upsampled image.
|
| 328 |
+
resample_filter=None,
|
| 329 |
+
# Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
|
| 330 |
+
demodulate=True, # Apply weight demodulation?
|
| 331 |
+
flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
|
| 332 |
+
fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation?
|
| 333 |
+
):
|
| 334 |
+
batch_size = x.shape[0]
|
| 335 |
+
out_channels, in_channels, kh, kw = weight.shape
|
| 336 |
+
|
| 337 |
+
# Pre-normalize inputs to avoid FP16 overflow.
|
| 338 |
+
if x.dtype == torch.float16 and demodulate:
|
| 339 |
+
weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3],
|
| 340 |
+
keepdim=True)) # max_Ikk
|
| 341 |
+
styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
|
| 342 |
+
|
| 343 |
+
# Calculate per-sample weights and demodulation coefficients.
|
| 344 |
+
w = None
|
| 345 |
+
dcoefs = None
|
| 346 |
+
if demodulate or fused_modconv:
|
| 347 |
+
w = weight.unsqueeze(0) # [NOIkk]
|
| 348 |
+
w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
|
| 349 |
+
if demodulate:
|
| 350 |
+
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]
|
| 351 |
+
if demodulate and fused_modconv:
|
| 352 |
+
w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
|
| 353 |
+
# Execute by scaling the activations before and after the convolution.
|
| 354 |
+
if not fused_modconv:
|
| 355 |
+
x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
|
| 356 |
+
x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down,
|
| 357 |
+
padding=padding, flip_weight=flip_weight)
|
| 358 |
+
if demodulate and noise is not None:
|
| 359 |
+
x = fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
|
| 360 |
+
elif demodulate:
|
| 361 |
+
x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
|
| 362 |
+
elif noise is not None:
|
| 363 |
+
x = x.add_(noise.to(x.dtype))
|
| 364 |
+
return x
|
| 365 |
+
|
| 366 |
+
# Execute as one fused op using grouped convolution.
|
| 367 |
+
batch_size = int(batch_size)
|
| 368 |
+
x = x.reshape(1, -1, *x.shape[2:])
|
| 369 |
+
w = w.reshape(-1, in_channels, kh, kw)
|
| 370 |
+
x = conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding,
|
| 371 |
+
groups=batch_size, flip_weight=flip_weight)
|
| 372 |
+
x = x.reshape(batch_size, -1, *x.shape[2:])
|
| 373 |
+
if noise is not None:
|
| 374 |
+
x = x.add_(noise)
|
| 375 |
+
return x
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class SynthesisLayer(torch.nn.Module):
|
| 379 |
+
def __init__(self,
|
| 380 |
+
in_channels, # Number of input channels.
|
| 381 |
+
out_channels, # Number of output channels.
|
| 382 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
| 383 |
+
resolution, # Resolution of this layer.
|
| 384 |
+
kernel_size=3, # Convolution kernel size.
|
| 385 |
+
up=1, # Integer upsampling factor.
|
| 386 |
+
use_noise=True, # Enable noise input?
|
| 387 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
| 388 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
| 389 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
| 390 |
+
channels_last=False, # Use channels_last format for the weights?
|
| 391 |
+
):
|
| 392 |
+
super().__init__()
|
| 393 |
+
self.resolution = resolution
|
| 394 |
+
self.up = up
|
| 395 |
+
self.use_noise = use_noise
|
| 396 |
+
self.activation = activation
|
| 397 |
+
self.conv_clamp = conv_clamp
|
| 398 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
| 399 |
+
self.padding = kernel_size // 2
|
| 400 |
+
self.act_gain = activation_funcs[activation].def_gain
|
| 401 |
+
|
| 402 |
+
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
|
| 403 |
+
memory_format = torch.channels_last if channels_last else torch.contiguous_format
|
| 404 |
+
self.weight = torch.nn.Parameter(
|
| 405 |
+
torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
|
| 406 |
+
if use_noise:
|
| 407 |
+
self.register_buffer('noise_const', torch.randn([resolution, resolution]))
|
| 408 |
+
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
|
| 409 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
| 410 |
+
|
| 411 |
+
def forward(self, x, w, noise_mode='none', fused_modconv=True, gain=1):
|
| 412 |
+
assert noise_mode in ['random', 'const', 'none']
|
| 413 |
+
in_resolution = self.resolution // self.up
|
| 414 |
+
styles = self.affine(w)
|
| 415 |
+
|
| 416 |
+
noise = None
|
| 417 |
+
if self.use_noise and noise_mode == 'random':
|
| 418 |
+
noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution],
|
| 419 |
+
device=x.device) * self.noise_strength
|
| 420 |
+
if self.use_noise and noise_mode == 'const':
|
| 421 |
+
noise = self.noise_const * self.noise_strength
|
| 422 |
+
|
| 423 |
+
flip_weight = (self.up == 1) # slightly faster
|
| 424 |
+
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
|
| 425 |
+
padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight,
|
| 426 |
+
fused_modconv=fused_modconv)
|
| 427 |
+
|
| 428 |
+
act_gain = self.act_gain * gain
|
| 429 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
| 430 |
+
x = F.leaky_relu(x, negative_slope=0.2, inplace=False)
|
| 431 |
+
if act_gain != 1:
|
| 432 |
+
x = x * act_gain
|
| 433 |
+
if act_clamp is not None:
|
| 434 |
+
x = x.clamp(-act_clamp, act_clamp)
|
| 435 |
+
return x
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
class ToRGBLayer(torch.nn.Module):
|
| 439 |
+
def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
|
| 440 |
+
super().__init__()
|
| 441 |
+
self.conv_clamp = conv_clamp
|
| 442 |
+
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
|
| 443 |
+
memory_format = torch.channels_last if channels_last else torch.contiguous_format
|
| 444 |
+
self.weight = torch.nn.Parameter(
|
| 445 |
+
torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
|
| 446 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
| 447 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
| 448 |
+
|
| 449 |
+
def forward(self, x, w, fused_modconv=True):
|
| 450 |
+
styles = self.affine(w) * self.weight_gain
|
| 451 |
+
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
|
| 452 |
+
x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
|
| 453 |
+
return x
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
class SynthesisForeword(torch.nn.Module):
|
| 457 |
+
def __init__(self,
|
| 458 |
+
z_dim, # Output Latent (Z) dimensionality.
|
| 459 |
+
resolution, # Resolution of this block.
|
| 460 |
+
in_channels,
|
| 461 |
+
img_channels, # Number of input color channels.
|
| 462 |
+
architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
|
| 463 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
| 464 |
+
|
| 465 |
+
):
|
| 466 |
+
super().__init__()
|
| 467 |
+
self.in_channels = in_channels
|
| 468 |
+
self.z_dim = z_dim
|
| 469 |
+
self.resolution = resolution
|
| 470 |
+
self.img_channels = img_channels
|
| 471 |
+
self.architecture = architecture
|
| 472 |
+
|
| 473 |
+
self.fc = FullyConnectedLayer(self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation)
|
| 474 |
+
self.conv = SynthesisLayer(self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4)
|
| 475 |
+
|
| 476 |
+
if architecture == 'skip':
|
| 477 |
+
self.torgb = ToRGBLayer(self.in_channels, self.img_channels, kernel_size=1, w_dim=(z_dim // 2) * 3)
|
| 478 |
+
|
| 479 |
+
def forward(self, x, ws, feats, img, force_fp32=False):
|
| 480 |
+
_ = force_fp32 # unused
|
| 481 |
+
dtype = torch.float32
|
| 482 |
+
memory_format = torch.contiguous_format
|
| 483 |
+
|
| 484 |
+
x_global = x.clone()
|
| 485 |
+
# ToRGB.
|
| 486 |
+
x = self.fc(x)
|
| 487 |
+
x = x.view(-1, self.z_dim // 2, 4, 4)
|
| 488 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
| 489 |
+
|
| 490 |
+
# Main layers.
|
| 491 |
+
x_skip = feats[4].clone()
|
| 492 |
+
x = x + x_skip
|
| 493 |
+
|
| 494 |
+
mod_vector = []
|
| 495 |
+
mod_vector.append(ws[:, 0])
|
| 496 |
+
mod_vector.append(x_global.clone())
|
| 497 |
+
mod_vector = torch.cat(mod_vector, dim=1)
|
| 498 |
+
|
| 499 |
+
x = self.conv(x, mod_vector)
|
| 500 |
+
|
| 501 |
+
mod_vector = []
|
| 502 |
+
mod_vector.append(ws[:, 2 * 2 - 3])
|
| 503 |
+
mod_vector.append(x_global.clone())
|
| 504 |
+
mod_vector = torch.cat(mod_vector, dim=1)
|
| 505 |
+
|
| 506 |
+
if self.architecture == 'skip':
|
| 507 |
+
img = self.torgb(x, mod_vector)
|
| 508 |
+
img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
|
| 509 |
+
|
| 510 |
+
assert x.dtype == dtype
|
| 511 |
+
return x, img
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class SELayer(nn.Module):
|
| 515 |
+
def __init__(self, channel, reduction=16):
|
| 516 |
+
super(SELayer, self).__init__()
|
| 517 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
| 518 |
+
self.fc = nn.Sequential(
|
| 519 |
+
nn.Linear(channel, channel // reduction, bias=False),
|
| 520 |
+
nn.ReLU(inplace=False),
|
| 521 |
+
nn.Linear(channel // reduction, channel, bias=False),
|
| 522 |
+
nn.Sigmoid()
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
def forward(self, x):
|
| 526 |
+
b, c, _, _ = x.size()
|
| 527 |
+
y = self.avg_pool(x).view(b, c)
|
| 528 |
+
y = self.fc(y).view(b, c, 1, 1)
|
| 529 |
+
res = x * y.expand_as(x)
|
| 530 |
+
return res
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
class FourierUnit(nn.Module):
|
| 534 |
+
|
| 535 |
+
def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
|
| 536 |
+
spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
|
| 537 |
+
# bn_layer not used
|
| 538 |
+
super(FourierUnit, self).__init__()
|
| 539 |
+
self.groups = groups
|
| 540 |
+
|
| 541 |
+
self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
|
| 542 |
+
out_channels=out_channels * 2,
|
| 543 |
+
kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
|
| 544 |
+
self.relu = torch.nn.ReLU(inplace=False)
|
| 545 |
+
|
| 546 |
+
# squeeze and excitation block
|
| 547 |
+
self.use_se = use_se
|
| 548 |
+
if use_se:
|
| 549 |
+
if se_kwargs is None:
|
| 550 |
+
se_kwargs = {}
|
| 551 |
+
self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
|
| 552 |
+
|
| 553 |
+
self.spatial_scale_factor = spatial_scale_factor
|
| 554 |
+
self.spatial_scale_mode = spatial_scale_mode
|
| 555 |
+
self.spectral_pos_encoding = spectral_pos_encoding
|
| 556 |
+
self.ffc3d = ffc3d
|
| 557 |
+
self.fft_norm = fft_norm
|
| 558 |
+
|
| 559 |
+
def forward(self, x):
|
| 560 |
+
batch = x.shape[0]
|
| 561 |
+
|
| 562 |
+
if self.spatial_scale_factor is not None:
|
| 563 |
+
orig_size = x.shape[-2:]
|
| 564 |
+
x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
|
| 565 |
+
align_corners=False)
|
| 566 |
+
|
| 567 |
+
r_size = x.size()
|
| 568 |
+
# (batch, c, h, w/2+1, 2)
|
| 569 |
+
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
|
| 570 |
+
ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
|
| 571 |
+
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
|
| 572 |
+
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
|
| 573 |
+
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
|
| 574 |
+
|
| 575 |
+
if self.spectral_pos_encoding:
|
| 576 |
+
height, width = ffted.shape[-2:]
|
| 577 |
+
coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
|
| 578 |
+
coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
|
| 579 |
+
ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
|
| 580 |
+
|
| 581 |
+
if self.use_se:
|
| 582 |
+
ffted = self.se(ffted)
|
| 583 |
+
|
| 584 |
+
ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
|
| 585 |
+
ffted = self.relu(ffted)
|
| 586 |
+
|
| 587 |
+
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
|
| 588 |
+
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
|
| 589 |
+
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
|
| 590 |
+
|
| 591 |
+
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
|
| 592 |
+
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
|
| 593 |
+
|
| 594 |
+
if self.spatial_scale_factor is not None:
|
| 595 |
+
output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
|
| 596 |
+
|
| 597 |
+
return output
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class SpectralTransform(nn.Module):
|
| 601 |
+
|
| 602 |
+
def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
|
| 603 |
+
# bn_layer not used
|
| 604 |
+
super(SpectralTransform, self).__init__()
|
| 605 |
+
self.enable_lfu = enable_lfu
|
| 606 |
+
if stride == 2:
|
| 607 |
+
self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
|
| 608 |
+
else:
|
| 609 |
+
self.downsample = nn.Identity()
|
| 610 |
+
|
| 611 |
+
self.stride = stride
|
| 612 |
+
self.conv1 = nn.Sequential(
|
| 613 |
+
nn.Conv2d(in_channels, out_channels //
|
| 614 |
+
2, kernel_size=1, groups=groups, bias=False),
|
| 615 |
+
# nn.BatchNorm2d(out_channels // 2),
|
| 616 |
+
nn.ReLU(inplace=True)
|
| 617 |
+
)
|
| 618 |
+
self.fu = FourierUnit(
|
| 619 |
+
out_channels // 2, out_channels // 2, groups, **fu_kwargs)
|
| 620 |
+
if self.enable_lfu:
|
| 621 |
+
self.lfu = FourierUnit(
|
| 622 |
+
out_channels // 2, out_channels // 2, groups)
|
| 623 |
+
self.conv2 = torch.nn.Conv2d(
|
| 624 |
+
out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
|
| 625 |
+
|
| 626 |
+
def forward(self, x):
|
| 627 |
+
|
| 628 |
+
x = self.downsample(x)
|
| 629 |
+
x = self.conv1(x)
|
| 630 |
+
output = self.fu(x)
|
| 631 |
+
|
| 632 |
+
if self.enable_lfu:
|
| 633 |
+
n, c, h, w = x.shape
|
| 634 |
+
split_no = 2
|
| 635 |
+
split_s = h // split_no
|
| 636 |
+
xs = torch.cat(torch.split(
|
| 637 |
+
x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
|
| 638 |
+
xs = torch.cat(torch.split(xs, split_s, dim=-1),
|
| 639 |
+
dim=1).contiguous()
|
| 640 |
+
xs = self.lfu(xs)
|
| 641 |
+
xs = xs.repeat(1, 1, split_no, split_no).contiguous()
|
| 642 |
+
else:
|
| 643 |
+
xs = 0
|
| 644 |
+
|
| 645 |
+
output = self.conv2(x + output + xs)
|
| 646 |
+
|
| 647 |
+
return output
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
class FFC(nn.Module):
|
| 651 |
+
|
| 652 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
| 653 |
+
ratio_gin, ratio_gout, stride=1, padding=0,
|
| 654 |
+
dilation=1, groups=1, bias=False, enable_lfu=True,
|
| 655 |
+
padding_type='reflect', gated=False, **spectral_kwargs):
|
| 656 |
+
super(FFC, self).__init__()
|
| 657 |
+
|
| 658 |
+
assert stride == 1 or stride == 2, "Stride should be 1 or 2."
|
| 659 |
+
self.stride = stride
|
| 660 |
+
|
| 661 |
+
in_cg = int(in_channels * ratio_gin)
|
| 662 |
+
in_cl = in_channels - in_cg
|
| 663 |
+
out_cg = int(out_channels * ratio_gout)
|
| 664 |
+
out_cl = out_channels - out_cg
|
| 665 |
+
# groups_g = 1 if groups == 1 else int(groups * ratio_gout)
|
| 666 |
+
# groups_l = 1 if groups == 1 else groups - groups_g
|
| 667 |
+
|
| 668 |
+
self.ratio_gin = ratio_gin
|
| 669 |
+
self.ratio_gout = ratio_gout
|
| 670 |
+
self.global_in_num = in_cg
|
| 671 |
+
|
| 672 |
+
module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
|
| 673 |
+
self.convl2l = module(in_cl, out_cl, kernel_size,
|
| 674 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
| 675 |
+
module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
|
| 676 |
+
self.convl2g = module(in_cl, out_cg, kernel_size,
|
| 677 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
| 678 |
+
module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
|
| 679 |
+
self.convg2l = module(in_cg, out_cl, kernel_size,
|
| 680 |
+
stride, padding, dilation, groups, bias, padding_mode=padding_type)
|
| 681 |
+
module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
|
| 682 |
+
self.convg2g = module(
|
| 683 |
+
in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
|
| 684 |
+
|
| 685 |
+
self.gated = gated
|
| 686 |
+
module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
|
| 687 |
+
self.gate = module(in_channels, 2, 1)
|
| 688 |
+
|
| 689 |
+
def forward(self, x, fname=None):
|
| 690 |
+
x_l, x_g = x if type(x) is tuple else (x, 0)
|
| 691 |
+
out_xl, out_xg = 0, 0
|
| 692 |
+
|
| 693 |
+
if self.gated:
|
| 694 |
+
total_input_parts = [x_l]
|
| 695 |
+
if torch.is_tensor(x_g):
|
| 696 |
+
total_input_parts.append(x_g)
|
| 697 |
+
total_input = torch.cat(total_input_parts, dim=1)
|
| 698 |
+
|
| 699 |
+
gates = torch.sigmoid(self.gate(total_input))
|
| 700 |
+
g2l_gate, l2g_gate = gates.chunk(2, dim=1)
|
| 701 |
+
else:
|
| 702 |
+
g2l_gate, l2g_gate = 1, 1
|
| 703 |
+
|
| 704 |
+
spec_x = self.convg2g(x_g)
|
| 705 |
+
|
| 706 |
+
if self.ratio_gout != 1:
|
| 707 |
+
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
|
| 708 |
+
if self.ratio_gout != 0:
|
| 709 |
+
out_xg = self.convl2g(x_l) * l2g_gate + spec_x
|
| 710 |
+
|
| 711 |
+
return out_xl, out_xg
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
class FFC_BN_ACT(nn.Module):
|
| 715 |
+
|
| 716 |
+
def __init__(self, in_channels, out_channels,
|
| 717 |
+
kernel_size, ratio_gin, ratio_gout,
|
| 718 |
+
stride=1, padding=0, dilation=1, groups=1, bias=False,
|
| 719 |
+
norm_layer=nn.SyncBatchNorm, activation_layer=nn.Identity,
|
| 720 |
+
padding_type='reflect',
|
| 721 |
+
enable_lfu=True, **kwargs):
|
| 722 |
+
super(FFC_BN_ACT, self).__init__()
|
| 723 |
+
self.ffc = FFC(in_channels, out_channels, kernel_size,
|
| 724 |
+
ratio_gin, ratio_gout, stride, padding, dilation,
|
| 725 |
+
groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
|
| 726 |
+
lnorm = nn.Identity if ratio_gout == 1 else norm_layer
|
| 727 |
+
gnorm = nn.Identity if ratio_gout == 0 else norm_layer
|
| 728 |
+
global_channels = int(out_channels * ratio_gout)
|
| 729 |
+
# self.bn_l = lnorm(out_channels - global_channels)
|
| 730 |
+
# self.bn_g = gnorm(global_channels)
|
| 731 |
+
|
| 732 |
+
lact = nn.Identity if ratio_gout == 1 else activation_layer
|
| 733 |
+
gact = nn.Identity if ratio_gout == 0 else activation_layer
|
| 734 |
+
self.act_l = lact(inplace=True)
|
| 735 |
+
self.act_g = gact(inplace=True)
|
| 736 |
+
|
| 737 |
+
def forward(self, x, fname=None):
|
| 738 |
+
x_l, x_g = self.ffc(x, fname=fname, )
|
| 739 |
+
x_l = self.act_l(x_l)
|
| 740 |
+
x_g = self.act_g(x_g)
|
| 741 |
+
return x_l, x_g
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
class FFCResnetBlock(nn.Module):
|
| 745 |
+
def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
|
| 746 |
+
spatial_transform_kwargs=None, inline=False, ratio_gin=0.75, ratio_gout=0.75):
|
| 747 |
+
super().__init__()
|
| 748 |
+
self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
|
| 749 |
+
norm_layer=norm_layer,
|
| 750 |
+
activation_layer=activation_layer,
|
| 751 |
+
padding_type=padding_type,
|
| 752 |
+
ratio_gin=ratio_gin, ratio_gout=ratio_gout)
|
| 753 |
+
self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
|
| 754 |
+
norm_layer=norm_layer,
|
| 755 |
+
activation_layer=activation_layer,
|
| 756 |
+
padding_type=padding_type,
|
| 757 |
+
ratio_gin=ratio_gin, ratio_gout=ratio_gout)
|
| 758 |
+
self.inline = inline
|
| 759 |
+
|
| 760 |
+
def forward(self, x, fname=None):
|
| 761 |
+
if self.inline:
|
| 762 |
+
x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
|
| 763 |
+
else:
|
| 764 |
+
x_l, x_g = x if type(x) is tuple else (x, 0)
|
| 765 |
+
|
| 766 |
+
id_l, id_g = x_l, x_g
|
| 767 |
+
|
| 768 |
+
x_l, x_g = self.conv1((x_l, x_g), fname=fname)
|
| 769 |
+
x_l, x_g = self.conv2((x_l, x_g), fname=fname)
|
| 770 |
+
|
| 771 |
+
x_l, x_g = id_l + x_l, id_g + x_g
|
| 772 |
+
out = x_l, x_g
|
| 773 |
+
if self.inline:
|
| 774 |
+
out = torch.cat(out, dim=1)
|
| 775 |
+
return out
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
class ConcatTupleLayer(nn.Module):
|
| 779 |
+
def forward(self, x):
|
| 780 |
+
assert isinstance(x, tuple)
|
| 781 |
+
x_l, x_g = x
|
| 782 |
+
assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
|
| 783 |
+
if not torch.is_tensor(x_g):
|
| 784 |
+
return x_l
|
| 785 |
+
return torch.cat(x, dim=1)
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
class FFCBlock(torch.nn.Module):
|
| 789 |
+
def __init__(self,
|
| 790 |
+
dim, # Number of output/input channels.
|
| 791 |
+
kernel_size, # Width and height of the convolution kernel.
|
| 792 |
+
padding,
|
| 793 |
+
ratio_gin=0.75,
|
| 794 |
+
ratio_gout=0.75,
|
| 795 |
+
activation='linear', # Activation function: 'relu', 'lrelu', etc.
|
| 796 |
+
):
|
| 797 |
+
super().__init__()
|
| 798 |
+
if activation == 'linear':
|
| 799 |
+
self.activation = nn.Identity
|
| 800 |
+
else:
|
| 801 |
+
self.activation = nn.ReLU
|
| 802 |
+
self.padding = padding
|
| 803 |
+
self.kernel_size = kernel_size
|
| 804 |
+
self.ffc_block = FFCResnetBlock(dim=dim,
|
| 805 |
+
padding_type='reflect',
|
| 806 |
+
norm_layer=nn.SyncBatchNorm,
|
| 807 |
+
activation_layer=self.activation,
|
| 808 |
+
dilation=1,
|
| 809 |
+
ratio_gin=ratio_gin,
|
| 810 |
+
ratio_gout=ratio_gout)
|
| 811 |
+
|
| 812 |
+
self.concat_layer = ConcatTupleLayer()
|
| 813 |
+
|
| 814 |
+
def forward(self, gen_ft, mask, fname=None):
|
| 815 |
+
x = gen_ft.float()
|
| 816 |
+
|
| 817 |
+
x_l, x_g = x[:, :-self.ffc_block.conv1.ffc.global_in_num], x[:, -self.ffc_block.conv1.ffc.global_in_num:]
|
| 818 |
+
id_l, id_g = x_l, x_g
|
| 819 |
+
|
| 820 |
+
x_l, x_g = self.ffc_block((x_l, x_g), fname=fname)
|
| 821 |
+
x_l, x_g = id_l + x_l, id_g + x_g
|
| 822 |
+
x = self.concat_layer((x_l, x_g))
|
| 823 |
+
|
| 824 |
+
return x + gen_ft.float()
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
class FFCSkipLayer(torch.nn.Module):
|
| 828 |
+
def __init__(self,
|
| 829 |
+
dim, # Number of input/output channels.
|
| 830 |
+
kernel_size=3, # Convolution kernel size.
|
| 831 |
+
ratio_gin=0.75,
|
| 832 |
+
ratio_gout=0.75,
|
| 833 |
+
):
|
| 834 |
+
super().__init__()
|
| 835 |
+
self.padding = kernel_size // 2
|
| 836 |
+
|
| 837 |
+
self.ffc_act = FFCBlock(dim=dim, kernel_size=kernel_size, activation=nn.ReLU,
|
| 838 |
+
padding=self.padding, ratio_gin=ratio_gin, ratio_gout=ratio_gout)
|
| 839 |
+
|
| 840 |
+
def forward(self, gen_ft, mask, fname=None):
|
| 841 |
+
x = self.ffc_act(gen_ft, mask, fname=fname)
|
| 842 |
+
return x
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
class SynthesisBlock(torch.nn.Module):
|
| 846 |
+
def __init__(self,
|
| 847 |
+
in_channels, # Number of input channels, 0 = first block.
|
| 848 |
+
out_channels, # Number of output channels.
|
| 849 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
| 850 |
+
resolution, # Resolution of this block.
|
| 851 |
+
img_channels, # Number of output color channels.
|
| 852 |
+
is_last, # Is this the last block?
|
| 853 |
+
architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
|
| 854 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
| 855 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
| 856 |
+
use_fp16=False, # Use FP16 for this block?
|
| 857 |
+
fp16_channels_last=False, # Use channels-last memory format with FP16?
|
| 858 |
+
**layer_kwargs, # Arguments for SynthesisLayer.
|
| 859 |
+
):
|
| 860 |
+
assert architecture in ['orig', 'skip', 'resnet']
|
| 861 |
+
super().__init__()
|
| 862 |
+
self.in_channels = in_channels
|
| 863 |
+
self.w_dim = w_dim
|
| 864 |
+
self.resolution = resolution
|
| 865 |
+
self.img_channels = img_channels
|
| 866 |
+
self.is_last = is_last
|
| 867 |
+
self.architecture = architecture
|
| 868 |
+
self.use_fp16 = use_fp16
|
| 869 |
+
self.channels_last = (use_fp16 and fp16_channels_last)
|
| 870 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
| 871 |
+
self.num_conv = 0
|
| 872 |
+
self.num_torgb = 0
|
| 873 |
+
self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}
|
| 874 |
+
|
| 875 |
+
if in_channels != 0 and resolution >= 8:
|
| 876 |
+
self.ffc_skip = nn.ModuleList()
|
| 877 |
+
for _ in range(self.res_ffc[resolution]):
|
| 878 |
+
self.ffc_skip.append(FFCSkipLayer(dim=out_channels))
|
| 879 |
+
|
| 880 |
+
if in_channels == 0:
|
| 881 |
+
self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
|
| 882 |
+
|
| 883 |
+
if in_channels != 0:
|
| 884 |
+
self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim * 3, resolution=resolution, up=2,
|
| 885 |
+
resample_filter=resample_filter, conv_clamp=conv_clamp,
|
| 886 |
+
channels_last=self.channels_last, **layer_kwargs)
|
| 887 |
+
self.num_conv += 1
|
| 888 |
+
|
| 889 |
+
self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim * 3, resolution=resolution,
|
| 890 |
+
conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
|
| 891 |
+
self.num_conv += 1
|
| 892 |
+
|
| 893 |
+
if is_last or architecture == 'skip':
|
| 894 |
+
self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim * 3,
|
| 895 |
+
conv_clamp=conv_clamp, channels_last=self.channels_last)
|
| 896 |
+
self.num_torgb += 1
|
| 897 |
+
|
| 898 |
+
if in_channels != 0 and architecture == 'resnet':
|
| 899 |
+
self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
|
| 900 |
+
resample_filter=resample_filter, channels_last=self.channels_last)
|
| 901 |
+
|
| 902 |
+
def forward(self, x, mask, feats, img, ws, fname=None, force_fp32=False, fused_modconv=None, **layer_kwargs):
|
| 903 |
+
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
| 904 |
+
dtype = torch.float32
|
| 905 |
+
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
| 906 |
+
if fused_modconv is None:
|
| 907 |
+
fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
|
| 908 |
+
|
| 909 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
| 910 |
+
x_skip = feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)
|
| 911 |
+
|
| 912 |
+
# Main layers.
|
| 913 |
+
if self.in_channels == 0:
|
| 914 |
+
x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
|
| 915 |
+
elif self.architecture == 'resnet':
|
| 916 |
+
y = self.skip(x, gain=np.sqrt(0.5))
|
| 917 |
+
x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
|
| 918 |
+
if len(self.ffc_skip) > 0:
|
| 919 |
+
mask = F.interpolate(mask, size=x_skip.shape[2:], )
|
| 920 |
+
z = x + x_skip
|
| 921 |
+
for fres in self.ffc_skip:
|
| 922 |
+
z = fres(z, mask)
|
| 923 |
+
x = x + z
|
| 924 |
+
else:
|
| 925 |
+
x = x + x_skip
|
| 926 |
+
x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
|
| 927 |
+
x = y.add_(x)
|
| 928 |
+
else:
|
| 929 |
+
x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
|
| 930 |
+
if len(self.ffc_skip) > 0:
|
| 931 |
+
mask = F.interpolate(mask, size=x_skip.shape[2:], )
|
| 932 |
+
z = x + x_skip
|
| 933 |
+
for fres in self.ffc_skip:
|
| 934 |
+
z = fres(z, mask)
|
| 935 |
+
x = x + z
|
| 936 |
+
else:
|
| 937 |
+
x = x + x_skip
|
| 938 |
+
x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs)
|
| 939 |
+
# ToRGB.
|
| 940 |
+
if img is not None:
|
| 941 |
+
img = upsample2d(img, self.resample_filter)
|
| 942 |
+
if self.is_last or self.architecture == 'skip':
|
| 943 |
+
y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
|
| 944 |
+
y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
|
| 945 |
+
img = img.add_(y) if img is not None else y
|
| 946 |
+
|
| 947 |
+
x = x.to(dtype=dtype)
|
| 948 |
+
assert x.dtype == dtype
|
| 949 |
+
assert img is None or img.dtype == torch.float32
|
| 950 |
+
return x, img
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
class SynthesisNetwork(torch.nn.Module):
|
| 954 |
+
def __init__(self,
|
| 955 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
| 956 |
+
z_dim, # Output Latent (Z) dimensionality.
|
| 957 |
+
img_resolution, # Output image resolution.
|
| 958 |
+
img_channels, # Number of color channels.
|
| 959 |
+
channel_base=16384, # Overall multiplier for the number of channels.
|
| 960 |
+
channel_max=512, # Maximum number of channels in any layer.
|
| 961 |
+
num_fp16_res=0, # Use FP16 for the N highest resolutions.
|
| 962 |
+
**block_kwargs, # Arguments for SynthesisBlock.
|
| 963 |
+
):
|
| 964 |
+
assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
|
| 965 |
+
super().__init__()
|
| 966 |
+
self.w_dim = w_dim
|
| 967 |
+
self.img_resolution = img_resolution
|
| 968 |
+
self.img_resolution_log2 = int(np.log2(img_resolution))
|
| 969 |
+
self.img_channels = img_channels
|
| 970 |
+
self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)]
|
| 971 |
+
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
|
| 972 |
+
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
|
| 973 |
+
|
| 974 |
+
self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=min(channel_base // 4, channel_max),
|
| 975 |
+
z_dim=z_dim * 2, resolution=4)
|
| 976 |
+
|
| 977 |
+
self.num_ws = self.img_resolution_log2 * 2 - 2
|
| 978 |
+
for res in self.block_resolutions:
|
| 979 |
+
if res // 2 in channels_dict.keys():
|
| 980 |
+
in_channels = channels_dict[res // 2] if res > 4 else 0
|
| 981 |
+
else:
|
| 982 |
+
in_channels = min(channel_base // (res // 2), channel_max)
|
| 983 |
+
out_channels = channels_dict[res]
|
| 984 |
+
use_fp16 = (res >= fp16_resolution)
|
| 985 |
+
use_fp16 = False
|
| 986 |
+
is_last = (res == self.img_resolution)
|
| 987 |
+
block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
|
| 988 |
+
img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
|
| 989 |
+
setattr(self, f'b{res}', block)
|
| 990 |
+
|
| 991 |
+
def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
|
| 992 |
+
|
| 993 |
+
img = None
|
| 994 |
+
|
| 995 |
+
x, img = self.foreword(x_global, ws, feats, img)
|
| 996 |
+
|
| 997 |
+
for res in self.block_resolutions:
|
| 998 |
+
block = getattr(self, f'b{res}')
|
| 999 |
+
mod_vector0 = []
|
| 1000 |
+
mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5])
|
| 1001 |
+
mod_vector0.append(x_global.clone())
|
| 1002 |
+
mod_vector0 = torch.cat(mod_vector0, dim=1)
|
| 1003 |
+
|
| 1004 |
+
mod_vector1 = []
|
| 1005 |
+
mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4])
|
| 1006 |
+
mod_vector1.append(x_global.clone())
|
| 1007 |
+
mod_vector1 = torch.cat(mod_vector1, dim=1)
|
| 1008 |
+
|
| 1009 |
+
mod_vector_rgb = []
|
| 1010 |
+
mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3])
|
| 1011 |
+
mod_vector_rgb.append(x_global.clone())
|
| 1012 |
+
mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1)
|
| 1013 |
+
x, img = block(x, mask, feats, img, (mod_vector0, mod_vector1, mod_vector_rgb), fname=fname, **block_kwargs)
|
| 1014 |
+
return img
|
| 1015 |
+
|
| 1016 |
+
|
| 1017 |
+
class MappingNetwork(torch.nn.Module):
|
| 1018 |
+
def __init__(self,
|
| 1019 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
| 1020 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
| 1021 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
| 1022 |
+
num_ws, # Number of intermediate latents to output, None = do not broadcast.
|
| 1023 |
+
num_layers=8, # Number of mapping layers.
|
| 1024 |
+
embed_features=None, # Label embedding dimensionality, None = same as w_dim.
|
| 1025 |
+
layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
|
| 1026 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
| 1027 |
+
lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
|
| 1028 |
+
w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
|
| 1029 |
+
):
|
| 1030 |
+
super().__init__()
|
| 1031 |
+
self.z_dim = z_dim
|
| 1032 |
+
self.c_dim = c_dim
|
| 1033 |
+
self.w_dim = w_dim
|
| 1034 |
+
self.num_ws = num_ws
|
| 1035 |
+
self.num_layers = num_layers
|
| 1036 |
+
self.w_avg_beta = w_avg_beta
|
| 1037 |
+
|
| 1038 |
+
if embed_features is None:
|
| 1039 |
+
embed_features = w_dim
|
| 1040 |
+
if c_dim == 0:
|
| 1041 |
+
embed_features = 0
|
| 1042 |
+
if layer_features is None:
|
| 1043 |
+
layer_features = w_dim
|
| 1044 |
+
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
|
| 1045 |
+
|
| 1046 |
+
if c_dim > 0:
|
| 1047 |
+
self.embed = FullyConnectedLayer(c_dim, embed_features)
|
| 1048 |
+
for idx in range(num_layers):
|
| 1049 |
+
in_features = features_list[idx]
|
| 1050 |
+
out_features = features_list[idx + 1]
|
| 1051 |
+
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
|
| 1052 |
+
setattr(self, f'fc{idx}', layer)
|
| 1053 |
+
|
| 1054 |
+
if num_ws is not None and w_avg_beta is not None:
|
| 1055 |
+
self.register_buffer('w_avg', torch.zeros([w_dim]))
|
| 1056 |
+
|
| 1057 |
+
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
|
| 1058 |
+
# Embed, normalize, and concat inputs.
|
| 1059 |
+
x = None
|
| 1060 |
+
with torch.autograd.profiler.record_function('input'):
|
| 1061 |
+
if self.z_dim > 0:
|
| 1062 |
+
x = normalize_2nd_moment(z.to(torch.float32))
|
| 1063 |
+
if self.c_dim > 0:
|
| 1064 |
+
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
|
| 1065 |
+
x = torch.cat([x, y], dim=1) if x is not None else y
|
| 1066 |
+
|
| 1067 |
+
# Main layers.
|
| 1068 |
+
for idx in range(self.num_layers):
|
| 1069 |
+
layer = getattr(self, f'fc{idx}')
|
| 1070 |
+
x = layer(x)
|
| 1071 |
+
|
| 1072 |
+
# Update moving average of W.
|
| 1073 |
+
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
|
| 1074 |
+
with torch.autograd.profiler.record_function('update_w_avg'):
|
| 1075 |
+
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
| 1076 |
+
|
| 1077 |
+
# Broadcast.
|
| 1078 |
+
if self.num_ws is not None:
|
| 1079 |
+
with torch.autograd.profiler.record_function('broadcast'):
|
| 1080 |
+
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
| 1081 |
+
|
| 1082 |
+
# Apply truncation.
|
| 1083 |
+
if truncation_psi != 1:
|
| 1084 |
+
with torch.autograd.profiler.record_function('truncate'):
|
| 1085 |
+
assert self.w_avg_beta is not None
|
| 1086 |
+
if self.num_ws is None or truncation_cutoff is None:
|
| 1087 |
+
x = self.w_avg.lerp(x, truncation_psi)
|
| 1088 |
+
else:
|
| 1089 |
+
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
|
| 1090 |
+
return x
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
class Generator(torch.nn.Module):
|
| 1094 |
+
def __init__(self,
|
| 1095 |
+
z_dim, # Input latent (Z) dimensionality.
|
| 1096 |
+
c_dim, # Conditioning label (C) dimensionality.
|
| 1097 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
| 1098 |
+
img_resolution, # Output resolution.
|
| 1099 |
+
img_channels, # Number of output color channels.
|
| 1100 |
+
encoder_kwargs={}, # Arguments for EncoderNetwork.
|
| 1101 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
| 1102 |
+
synthesis_kwargs={}, # Arguments for SynthesisNetwork.
|
| 1103 |
+
):
|
| 1104 |
+
super().__init__()
|
| 1105 |
+
self.z_dim = z_dim
|
| 1106 |
+
self.c_dim = c_dim
|
| 1107 |
+
self.w_dim = w_dim
|
| 1108 |
+
self.img_resolution = img_resolution
|
| 1109 |
+
self.img_channels = img_channels
|
| 1110 |
+
self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution,
|
| 1111 |
+
img_channels=img_channels, **encoder_kwargs)
|
| 1112 |
+
self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution,
|
| 1113 |
+
img_channels=img_channels, **synthesis_kwargs)
|
| 1114 |
+
self.num_ws = self.synthesis.num_ws
|
| 1115 |
+
self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
|
| 1116 |
+
|
| 1117 |
+
def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
|
| 1118 |
+
mask = img[:, -1].unsqueeze(1)
|
| 1119 |
+
x_global, z, feats = self.encoder(img, c)
|
| 1120 |
+
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
|
| 1121 |
+
img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
|
| 1122 |
+
return img
|
| 1123 |
+
|
| 1124 |
+
|
| 1125 |
+
FCF_MODEL_URL = os.environ.get(
|
| 1126 |
+
"FCF_MODEL_URL",
|
| 1127 |
+
"https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth",
|
| 1128 |
+
)
|
| 1129 |
+
|
| 1130 |
+
|
| 1131 |
+
class FcF(InpaintModel):
|
| 1132 |
+
min_size = 512
|
| 1133 |
+
pad_mod = 512
|
| 1134 |
+
pad_to_square = True
|
| 1135 |
+
|
| 1136 |
+
def init_model(self, device, **kwargs):
|
| 1137 |
+
seed = 0
|
| 1138 |
+
random.seed(seed)
|
| 1139 |
+
np.random.seed(seed)
|
| 1140 |
+
torch.manual_seed(seed)
|
| 1141 |
+
torch.cuda.manual_seed_all(seed)
|
| 1142 |
+
torch.backends.cudnn.deterministic = True
|
| 1143 |
+
torch.backends.cudnn.benchmark = False
|
| 1144 |
+
|
| 1145 |
+
kwargs = {'channel_base': 1 * 32768, 'channel_max': 512, 'num_fp16_res': 4, 'conv_clamp': 256}
|
| 1146 |
+
G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3,
|
| 1147 |
+
synthesis_kwargs=kwargs, encoder_kwargs=kwargs, mapping_kwargs={'num_layers': 2})
|
| 1148 |
+
self.model = load_model(G, FCF_MODEL_URL, device)
|
| 1149 |
+
self.label = torch.zeros([1, self.model.c_dim], device=device)
|
| 1150 |
+
|
| 1151 |
+
@staticmethod
|
| 1152 |
+
def is_downloaded() -> bool:
|
| 1153 |
+
return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))
|
| 1154 |
+
|
| 1155 |
+
@torch.no_grad()
|
| 1156 |
+
def __call__(self, image, mask, config: Config):
|
| 1157 |
+
"""
|
| 1158 |
+
images: [H, W, C] RGB, not normalized
|
| 1159 |
+
masks: [H, W]
|
| 1160 |
+
return: BGR IMAGE
|
| 1161 |
+
"""
|
| 1162 |
+
if image.shape[0] == 512 and image.shape[1] == 512:
|
| 1163 |
+
return self._pad_forward(image, mask, config)
|
| 1164 |
+
|
| 1165 |
+
boxes = boxes_from_mask(mask)
|
| 1166 |
+
crop_result = []
|
| 1167 |
+
config.hd_strategy_crop_margin = 128
|
| 1168 |
+
for box in boxes:
|
| 1169 |
+
crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
|
| 1170 |
+
origin_size = crop_image.shape[:2]
|
| 1171 |
+
resize_image = resize_max_size(crop_image, size_limit=512)
|
| 1172 |
+
resize_mask = resize_max_size(crop_mask, size_limit=512)
|
| 1173 |
+
inpaint_result = self._pad_forward(resize_image, resize_mask, config)
|
| 1174 |
+
|
| 1175 |
+
# only paste masked area result
|
| 1176 |
+
inpaint_result = cv2.resize(inpaint_result, (origin_size[1], origin_size[0]), interpolation=cv2.INTER_CUBIC)
|
| 1177 |
+
|
| 1178 |
+
original_pixel_indices = crop_mask < 127
|
| 1179 |
+
inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][original_pixel_indices]
|
| 1180 |
+
|
| 1181 |
+
crop_result.append((inpaint_result, crop_box))
|
| 1182 |
+
|
| 1183 |
+
inpaint_result = image[:, :, ::-1]
|
| 1184 |
+
for crop_image, crop_box in crop_result:
|
| 1185 |
+
x1, y1, x2, y2 = crop_box
|
| 1186 |
+
inpaint_result[y1:y2, x1:x2, :] = crop_image
|
| 1187 |
+
|
| 1188 |
+
return inpaint_result
|
| 1189 |
+
|
| 1190 |
+
def forward(self, image, mask, config: Config):
|
| 1191 |
+
"""Input images and output images have same size
|
| 1192 |
+
images: [H, W, C] RGB
|
| 1193 |
+
masks: [H, W] mask area == 255
|
| 1194 |
+
return: BGR IMAGE
|
| 1195 |
+
"""
|
| 1196 |
+
|
| 1197 |
+
image = norm_img(image) # [0, 1]
|
| 1198 |
+
image = image * 2 - 1 # [0, 1] -> [-1, 1]
|
| 1199 |
+
mask = (mask > 120) * 255
|
| 1200 |
+
mask = norm_img(mask)
|
| 1201 |
+
|
| 1202 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
| 1203 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
| 1204 |
+
|
| 1205 |
+
erased_img = image * (1 - mask)
|
| 1206 |
+
input_image = torch.cat([0.5 - mask, erased_img], dim=1)
|
| 1207 |
+
|
| 1208 |
+
output = self.model(input_image, self.label, truncation_psi=0.1, noise_mode='none')
|
| 1209 |
+
output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
|
| 1210 |
+
output = output[0].cpu().numpy()
|
| 1211 |
+
cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
| 1212 |
+
return cur_res
|
lama_cleaner/model/lama.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from loguru import logger
|
| 7 |
+
|
| 8 |
+
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
| 9 |
+
from lama_cleaner.model.base import InpaintModel
|
| 10 |
+
from lama_cleaner.schema import Config
|
| 11 |
+
|
| 12 |
+
LAMA_MODEL_URL = os.environ.get(
|
| 13 |
+
"LAMA_MODEL_URL",
|
| 14 |
+
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LaMa(InpaintModel):
|
| 19 |
+
pad_mod = 8
|
| 20 |
+
|
| 21 |
+
def init_model(self, device, **kwargs):
|
| 22 |
+
if os.environ.get("LAMA_MODEL"):
|
| 23 |
+
model_path = os.environ.get("LAMA_MODEL")
|
| 24 |
+
if not os.path.exists(model_path):
|
| 25 |
+
raise FileNotFoundError(
|
| 26 |
+
f"lama torchscript model not found: {model_path}"
|
| 27 |
+
)
|
| 28 |
+
else:
|
| 29 |
+
model_path = download_model(LAMA_MODEL_URL)
|
| 30 |
+
# TODO used to create a lambda docker image
|
| 31 |
+
# model_path = '../app/big-lama.pt'
|
| 32 |
+
logger.info(f"Load LaMa model from: {model_path}")
|
| 33 |
+
model = torch.jit.load(model_path, map_location="cpu")
|
| 34 |
+
model = model.to(device)
|
| 35 |
+
model.eval()
|
| 36 |
+
self.model = model
|
| 37 |
+
self.model_path = model_path
|
| 38 |
+
|
| 39 |
+
@staticmethod
|
| 40 |
+
def is_downloaded() -> bool:
|
| 41 |
+
return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
|
| 42 |
+
|
| 43 |
+
def forward(self, image, mask, config: Config):
|
| 44 |
+
"""Input image and output image have same size
|
| 45 |
+
image: [H, W, C] RGB
|
| 46 |
+
mask: [H, W]
|
| 47 |
+
return: BGR IMAGE
|
| 48 |
+
"""
|
| 49 |
+
image = norm_img(image)
|
| 50 |
+
mask = norm_img(mask)
|
| 51 |
+
|
| 52 |
+
mask = (mask > 0) * 1
|
| 53 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
| 54 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
| 55 |
+
|
| 56 |
+
inpainted_image = self.model(image, mask)
|
| 57 |
+
|
| 58 |
+
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
| 59 |
+
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
| 60 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
| 61 |
+
return cur_res
|
lama_cleaner/model/ldm.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from lama_cleaner.model.base import InpaintModel
|
| 7 |
+
from lama_cleaner.model.ddim_sampler import DDIMSampler
|
| 8 |
+
from lama_cleaner.model.plms_sampler import PLMSSampler
|
| 9 |
+
from lama_cleaner.schema import Config, LDMSampler
|
| 10 |
+
|
| 11 |
+
torch.manual_seed(42)
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from lama_cleaner.helper import (
|
| 14 |
+
norm_img,
|
| 15 |
+
get_cache_path_by_url,
|
| 16 |
+
load_jit_model,
|
| 17 |
+
)
|
| 18 |
+
from lama_cleaner.model.utils import (
|
| 19 |
+
make_beta_schedule,
|
| 20 |
+
timestep_embedding,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
LDM_ENCODE_MODEL_URL = os.environ.get(
|
| 24 |
+
"LDM_ENCODE_MODEL_URL",
|
| 25 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
LDM_DECODE_MODEL_URL = os.environ.get(
|
| 29 |
+
"LDM_DECODE_MODEL_URL",
|
| 30 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
LDM_DIFFUSION_MODEL_URL = os.environ.get(
|
| 34 |
+
"LDM_DIFFUSION_MODEL_URL",
|
| 35 |
+
"https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class DDPM(nn.Module):
|
| 40 |
+
# classic DDPM with Gaussian diffusion, in image space
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
device,
|
| 44 |
+
timesteps=1000,
|
| 45 |
+
beta_schedule="linear",
|
| 46 |
+
linear_start=0.0015,
|
| 47 |
+
linear_end=0.0205,
|
| 48 |
+
cosine_s=0.008,
|
| 49 |
+
original_elbo_weight=0.0,
|
| 50 |
+
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
| 51 |
+
l_simple_weight=1.0,
|
| 52 |
+
parameterization="eps", # all assuming fixed variance schedules
|
| 53 |
+
use_positional_encodings=False,
|
| 54 |
+
):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.device = device
|
| 57 |
+
self.parameterization = parameterization
|
| 58 |
+
self.use_positional_encodings = use_positional_encodings
|
| 59 |
+
|
| 60 |
+
self.v_posterior = v_posterior
|
| 61 |
+
self.original_elbo_weight = original_elbo_weight
|
| 62 |
+
self.l_simple_weight = l_simple_weight
|
| 63 |
+
|
| 64 |
+
self.register_schedule(
|
| 65 |
+
beta_schedule=beta_schedule,
|
| 66 |
+
timesteps=timesteps,
|
| 67 |
+
linear_start=linear_start,
|
| 68 |
+
linear_end=linear_end,
|
| 69 |
+
cosine_s=cosine_s,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
def register_schedule(
|
| 73 |
+
self,
|
| 74 |
+
given_betas=None,
|
| 75 |
+
beta_schedule="linear",
|
| 76 |
+
timesteps=1000,
|
| 77 |
+
linear_start=1e-4,
|
| 78 |
+
linear_end=2e-2,
|
| 79 |
+
cosine_s=8e-3,
|
| 80 |
+
):
|
| 81 |
+
betas = make_beta_schedule(
|
| 82 |
+
self.device,
|
| 83 |
+
beta_schedule,
|
| 84 |
+
timesteps,
|
| 85 |
+
linear_start=linear_start,
|
| 86 |
+
linear_end=linear_end,
|
| 87 |
+
cosine_s=cosine_s,
|
| 88 |
+
)
|
| 89 |
+
alphas = 1.0 - betas
|
| 90 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
| 91 |
+
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
|
| 92 |
+
|
| 93 |
+
(timesteps,) = betas.shape
|
| 94 |
+
self.num_timesteps = int(timesteps)
|
| 95 |
+
self.linear_start = linear_start
|
| 96 |
+
self.linear_end = linear_end
|
| 97 |
+
assert (
|
| 98 |
+
alphas_cumprod.shape[0] == self.num_timesteps
|
| 99 |
+
), "alphas have to be defined for each timestep"
|
| 100 |
+
|
| 101 |
+
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
|
| 102 |
+
|
| 103 |
+
self.register_buffer("betas", to_torch(betas))
|
| 104 |
+
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
| 105 |
+
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
|
| 106 |
+
|
| 107 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 108 |
+
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
|
| 109 |
+
self.register_buffer(
|
| 110 |
+
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
|
| 111 |
+
)
|
| 112 |
+
self.register_buffer(
|
| 113 |
+
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
|
| 114 |
+
)
|
| 115 |
+
self.register_buffer(
|
| 116 |
+
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
|
| 117 |
+
)
|
| 118 |
+
self.register_buffer(
|
| 119 |
+
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
| 123 |
+
posterior_variance = (1 - self.v_posterior) * betas * (
|
| 124 |
+
1.0 - alphas_cumprod_prev
|
| 125 |
+
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
| 126 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
| 127 |
+
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
| 128 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
| 129 |
+
self.register_buffer(
|
| 130 |
+
"posterior_log_variance_clipped",
|
| 131 |
+
to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
|
| 132 |
+
)
|
| 133 |
+
self.register_buffer(
|
| 134 |
+
"posterior_mean_coef1",
|
| 135 |
+
to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
|
| 136 |
+
)
|
| 137 |
+
self.register_buffer(
|
| 138 |
+
"posterior_mean_coef2",
|
| 139 |
+
to_torch(
|
| 140 |
+
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
|
| 141 |
+
),
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
if self.parameterization == "eps":
|
| 145 |
+
lvlb_weights = self.betas**2 / (
|
| 146 |
+
2
|
| 147 |
+
* self.posterior_variance
|
| 148 |
+
* to_torch(alphas)
|
| 149 |
+
* (1 - self.alphas_cumprod)
|
| 150 |
+
)
|
| 151 |
+
elif self.parameterization == "x0":
|
| 152 |
+
lvlb_weights = (
|
| 153 |
+
0.5
|
| 154 |
+
* np.sqrt(torch.Tensor(alphas_cumprod))
|
| 155 |
+
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
| 156 |
+
)
|
| 157 |
+
else:
|
| 158 |
+
raise NotImplementedError("mu not supported")
|
| 159 |
+
# TODO how to choose this term
|
| 160 |
+
lvlb_weights[0] = lvlb_weights[1]
|
| 161 |
+
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
|
| 162 |
+
assert not torch.isnan(self.lvlb_weights).all()
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class LatentDiffusion(DDPM):
|
| 166 |
+
def __init__(
|
| 167 |
+
self,
|
| 168 |
+
diffusion_model,
|
| 169 |
+
device,
|
| 170 |
+
cond_stage_key="image",
|
| 171 |
+
cond_stage_trainable=False,
|
| 172 |
+
concat_mode=True,
|
| 173 |
+
scale_factor=1.0,
|
| 174 |
+
scale_by_std=False,
|
| 175 |
+
*args,
|
| 176 |
+
**kwargs,
|
| 177 |
+
):
|
| 178 |
+
self.num_timesteps_cond = 1
|
| 179 |
+
self.scale_by_std = scale_by_std
|
| 180 |
+
super().__init__(device, *args, **kwargs)
|
| 181 |
+
self.diffusion_model = diffusion_model
|
| 182 |
+
self.concat_mode = concat_mode
|
| 183 |
+
self.cond_stage_trainable = cond_stage_trainable
|
| 184 |
+
self.cond_stage_key = cond_stage_key
|
| 185 |
+
self.num_downs = 2
|
| 186 |
+
self.scale_factor = scale_factor
|
| 187 |
+
|
| 188 |
+
def make_cond_schedule(
|
| 189 |
+
self,
|
| 190 |
+
):
|
| 191 |
+
self.cond_ids = torch.full(
|
| 192 |
+
size=(self.num_timesteps,),
|
| 193 |
+
fill_value=self.num_timesteps - 1,
|
| 194 |
+
dtype=torch.long,
|
| 195 |
+
)
|
| 196 |
+
ids = torch.round(
|
| 197 |
+
torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
|
| 198 |
+
).long()
|
| 199 |
+
self.cond_ids[: self.num_timesteps_cond] = ids
|
| 200 |
+
|
| 201 |
+
def register_schedule(
|
| 202 |
+
self,
|
| 203 |
+
given_betas=None,
|
| 204 |
+
beta_schedule="linear",
|
| 205 |
+
timesteps=1000,
|
| 206 |
+
linear_start=1e-4,
|
| 207 |
+
linear_end=2e-2,
|
| 208 |
+
cosine_s=8e-3,
|
| 209 |
+
):
|
| 210 |
+
super().register_schedule(
|
| 211 |
+
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
| 215 |
+
if self.shorten_cond_schedule:
|
| 216 |
+
self.make_cond_schedule()
|
| 217 |
+
|
| 218 |
+
def apply_model(self, x_noisy, t, cond):
|
| 219 |
+
# x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128
|
| 220 |
+
t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
|
| 221 |
+
x_recon = self.diffusion_model(x_noisy, t_emb, cond)
|
| 222 |
+
return x_recon
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class LDM(InpaintModel):
|
| 226 |
+
pad_mod = 32
|
| 227 |
+
|
| 228 |
+
def __init__(self, device, fp16: bool = True, **kwargs):
|
| 229 |
+
self.fp16 = fp16
|
| 230 |
+
super().__init__(device)
|
| 231 |
+
self.device = device
|
| 232 |
+
|
| 233 |
+
def init_model(self, device, **kwargs):
|
| 234 |
+
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
|
| 235 |
+
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
|
| 236 |
+
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
|
| 237 |
+
if self.fp16 and "cuda" in str(device):
|
| 238 |
+
self.diffusion_model = self.diffusion_model.half()
|
| 239 |
+
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
|
| 240 |
+
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
|
| 241 |
+
|
| 242 |
+
self.model = LatentDiffusion(self.diffusion_model, device)
|
| 243 |
+
|
| 244 |
+
@staticmethod
|
| 245 |
+
def is_downloaded() -> bool:
|
| 246 |
+
model_paths = [
|
| 247 |
+
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
|
| 248 |
+
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
|
| 249 |
+
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
|
| 250 |
+
]
|
| 251 |
+
return all([os.path.exists(it) for it in model_paths])
|
| 252 |
+
|
| 253 |
+
@torch.cuda.amp.autocast()
|
| 254 |
+
def forward(self, image, mask, config: Config):
|
| 255 |
+
"""
|
| 256 |
+
image: [H, W, C] RGB
|
| 257 |
+
mask: [H, W, 1]
|
| 258 |
+
return: BGR IMAGE
|
| 259 |
+
"""
|
| 260 |
+
# image [1,3,512,512] float32
|
| 261 |
+
# mask: [1,1,512,512] float32
|
| 262 |
+
# masked_image: [1,3,512,512] float32
|
| 263 |
+
if config.ldm_sampler == LDMSampler.ddim:
|
| 264 |
+
sampler = DDIMSampler(self.model)
|
| 265 |
+
elif config.ldm_sampler == LDMSampler.plms:
|
| 266 |
+
sampler = PLMSSampler(self.model)
|
| 267 |
+
else:
|
| 268 |
+
raise ValueError()
|
| 269 |
+
|
| 270 |
+
steps = config.ldm_steps
|
| 271 |
+
image = norm_img(image)
|
| 272 |
+
mask = norm_img(mask)
|
| 273 |
+
|
| 274 |
+
mask[mask < 0.5] = 0
|
| 275 |
+
mask[mask >= 0.5] = 1
|
| 276 |
+
|
| 277 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
| 278 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
| 279 |
+
masked_image = (1 - mask) * image
|
| 280 |
+
|
| 281 |
+
mask = self._norm(mask)
|
| 282 |
+
masked_image = self._norm(masked_image)
|
| 283 |
+
|
| 284 |
+
c = self.cond_stage_model_encode(masked_image)
|
| 285 |
+
torch.cuda.empty_cache()
|
| 286 |
+
|
| 287 |
+
cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
|
| 288 |
+
c = torch.cat((c, cc), dim=1) # 1,4,128,128
|
| 289 |
+
|
| 290 |
+
shape = (c.shape[1] - 1,) + c.shape[2:]
|
| 291 |
+
samples_ddim = sampler.sample(
|
| 292 |
+
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
|
| 293 |
+
)
|
| 294 |
+
torch.cuda.empty_cache()
|
| 295 |
+
x_samples_ddim = self.cond_stage_model_decode(
|
| 296 |
+
samples_ddim
|
| 297 |
+
) # samples_ddim: 1, 3, 128, 128 float32
|
| 298 |
+
torch.cuda.empty_cache()
|
| 299 |
+
|
| 300 |
+
# image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
| 301 |
+
# mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
|
| 302 |
+
inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
| 303 |
+
|
| 304 |
+
# inpainted = (1 - mask) * image + mask * predicted_image
|
| 305 |
+
inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
| 306 |
+
inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
|
| 307 |
+
return inpainted_image
|
| 308 |
+
|
| 309 |
+
def _norm(self, tensor):
|
| 310 |
+
return tensor * 2.0 - 1.0
|
lama_cleaner/model/manga.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from loguru import logger
|
| 9 |
+
|
| 10 |
+
from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
|
| 11 |
+
from lama_cleaner.model.base import InpaintModel
|
| 12 |
+
from lama_cleaner.schema import Config
|
| 13 |
+
|
| 14 |
+
# def norm(np_img):
|
| 15 |
+
# return np_img / 255 * 2 - 1.0
|
| 16 |
+
#
|
| 17 |
+
#
|
| 18 |
+
# @torch.no_grad()
|
| 19 |
+
# def run():
|
| 20 |
+
# name = 'manga_1080x740.jpg'
|
| 21 |
+
# img_p = f'/Users/qing/code/github/MangaInpainting/examples/test/imgs/{name}'
|
| 22 |
+
# mask_p = f'/Users/qing/code/github/MangaInpainting/examples/test/masks/mask_{name}'
|
| 23 |
+
# erika_model = torch.jit.load('erika.jit')
|
| 24 |
+
# manga_inpaintor_model = torch.jit.load('manga_inpaintor.jit')
|
| 25 |
+
#
|
| 26 |
+
# img = cv2.imread(img_p)
|
| 27 |
+
# gray_img = cv2.imread(img_p, cv2.IMREAD_GRAYSCALE)
|
| 28 |
+
# mask = cv2.imread(mask_p, cv2.IMREAD_GRAYSCALE)
|
| 29 |
+
#
|
| 30 |
+
# kernel = np.ones((9, 9), dtype=np.uint8)
|
| 31 |
+
# mask = cv2.dilate(mask, kernel, 2)
|
| 32 |
+
# # cv2.imwrite("mask.jpg", mask)
|
| 33 |
+
# # cv2.imshow('dilated_mask', cv2.hconcat([mask, dilated_mask]))
|
| 34 |
+
# # cv2.waitKey(0)
|
| 35 |
+
# # exit()
|
| 36 |
+
#
|
| 37 |
+
# # img = pad(img)
|
| 38 |
+
# gray_img = pad(gray_img).astype(np.float32)
|
| 39 |
+
# mask = pad(mask)
|
| 40 |
+
#
|
| 41 |
+
# # pad_mod = 16
|
| 42 |
+
# import time
|
| 43 |
+
# start = time.time()
|
| 44 |
+
# y = erika_model(torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :]))
|
| 45 |
+
# y = torch.clamp(y, 0, 255)
|
| 46 |
+
# lines = y.cpu().numpy()
|
| 47 |
+
# print(f"erika_model time: {time.time() - start}")
|
| 48 |
+
#
|
| 49 |
+
# cv2.imwrite('lines.png', lines[0][0])
|
| 50 |
+
#
|
| 51 |
+
# start = time.time()
|
| 52 |
+
# masks = torch.from_numpy(mask[np.newaxis, np.newaxis, :, :])
|
| 53 |
+
# masks = torch.where(masks > 0.5, torch.tensor(1.0), torch.tensor(0.0))
|
| 54 |
+
# noise = torch.randn_like(masks)
|
| 55 |
+
#
|
| 56 |
+
# images = torch.from_numpy(norm(gray_img)[np.newaxis, np.newaxis, :, :])
|
| 57 |
+
# lines = torch.from_numpy(norm(lines))
|
| 58 |
+
#
|
| 59 |
+
# outputs = manga_inpaintor_model(images, lines, masks, noise)
|
| 60 |
+
# print(f"manga_inpaintor_model time: {time.time() - start}")
|
| 61 |
+
#
|
| 62 |
+
# outputs_merged = (outputs * masks) + (images * (1 - masks))
|
| 63 |
+
# outputs_merged = outputs_merged * 127.5 + 127.5
|
| 64 |
+
# outputs_merged = outputs_merged.permute(0, 2, 3, 1)[0].detach().cpu().numpy().astype(np.uint8)
|
| 65 |
+
# cv2.imwrite(f'output_{name}', outputs_merged)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
MANGA_INPAINTOR_MODEL_URL = os.environ.get(
|
| 69 |
+
"MANGA_INPAINTOR_MODEL_URL",
|
| 70 |
+
"https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit"
|
| 71 |
+
)
|
| 72 |
+
MANGA_LINE_MODEL_URL = os.environ.get(
|
| 73 |
+
"MANGA_LINE_MODEL_URL",
|
| 74 |
+
"https://github.com/Sanster/models/releases/download/manga/erika.jit"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Manga(InpaintModel):
|
| 79 |
+
pad_mod = 16
|
| 80 |
+
|
| 81 |
+
def init_model(self, device, **kwargs):
|
| 82 |
+
self.inpaintor_model = load_jit_model(MANGA_INPAINTOR_MODEL_URL, device)
|
| 83 |
+
self.line_model = load_jit_model(MANGA_LINE_MODEL_URL, device)
|
| 84 |
+
self.seed = 42
|
| 85 |
+
|
| 86 |
+
@staticmethod
|
| 87 |
+
def is_downloaded() -> bool:
|
| 88 |
+
model_paths = [
|
| 89 |
+
get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL),
|
| 90 |
+
get_cache_path_by_url(MANGA_LINE_MODEL_URL),
|
| 91 |
+
]
|
| 92 |
+
return all([os.path.exists(it) for it in model_paths])
|
| 93 |
+
|
| 94 |
+
def forward(self, image, mask, config: Config):
|
| 95 |
+
"""
|
| 96 |
+
image: [H, W, C] RGB
|
| 97 |
+
mask: [H, W, 1]
|
| 98 |
+
return: BGR IMAGE
|
| 99 |
+
"""
|
| 100 |
+
seed = self.seed
|
| 101 |
+
random.seed(seed)
|
| 102 |
+
np.random.seed(seed)
|
| 103 |
+
torch.manual_seed(seed)
|
| 104 |
+
torch.cuda.manual_seed_all(seed)
|
| 105 |
+
|
| 106 |
+
gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
| 107 |
+
gray_img = torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)).to(self.device)
|
| 108 |
+
start = time.time()
|
| 109 |
+
lines = self.line_model(gray_img)
|
| 110 |
+
torch.cuda.empty_cache()
|
| 111 |
+
lines = torch.clamp(lines, 0, 255)
|
| 112 |
+
logger.info(f"erika_model time: {time.time() - start}")
|
| 113 |
+
|
| 114 |
+
mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device)
|
| 115 |
+
mask = mask.permute(0, 3, 1, 2)
|
| 116 |
+
mask = torch.where(mask > 0.5, 1.0, 0.0)
|
| 117 |
+
noise = torch.randn_like(mask)
|
| 118 |
+
ones = torch.ones_like(mask)
|
| 119 |
+
|
| 120 |
+
gray_img = gray_img / 255 * 2 - 1.0
|
| 121 |
+
lines = lines / 255 * 2 - 1.0
|
| 122 |
+
|
| 123 |
+
start = time.time()
|
| 124 |
+
inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones)
|
| 125 |
+
logger.info(f"image_inpaintor_model time: {time.time() - start}")
|
| 126 |
+
|
| 127 |
+
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
| 128 |
+
cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8)
|
| 129 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR)
|
| 130 |
+
return cur_res
|
lama_cleaner/model/mat.py
ADDED
|
@@ -0,0 +1,1444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.utils.checkpoint as checkpoint
|
| 10 |
+
|
| 11 |
+
from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img
|
| 12 |
+
from lama_cleaner.model.base import InpaintModel
|
| 13 |
+
from lama_cleaner.model.utils import setup_filter, Conv2dLayer, FullyConnectedLayer, conv2d_resample, bias_act, \
|
| 14 |
+
upsample2d, activation_funcs, MinibatchStdLayer, to_2tuple, normalize_2nd_moment
|
| 15 |
+
from lama_cleaner.schema import Config
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ModulatedConv2d(nn.Module):
|
| 19 |
+
def __init__(self,
|
| 20 |
+
in_channels, # Number of input channels.
|
| 21 |
+
out_channels, # Number of output channels.
|
| 22 |
+
kernel_size, # Width and height of the convolution kernel.
|
| 23 |
+
style_dim, # dimension of the style code
|
| 24 |
+
demodulate=True, # perfrom demodulation
|
| 25 |
+
up=1, # Integer upsampling factor.
|
| 26 |
+
down=1, # Integer downsampling factor.
|
| 27 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
| 28 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
| 29 |
+
):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.demodulate = demodulate
|
| 32 |
+
|
| 33 |
+
self.weight = torch.nn.Parameter(torch.randn([1, out_channels, in_channels, kernel_size, kernel_size]))
|
| 34 |
+
self.out_channels = out_channels
|
| 35 |
+
self.kernel_size = kernel_size
|
| 36 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
| 37 |
+
self.padding = self.kernel_size // 2
|
| 38 |
+
self.up = up
|
| 39 |
+
self.down = down
|
| 40 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
| 41 |
+
self.conv_clamp = conv_clamp
|
| 42 |
+
|
| 43 |
+
self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)
|
| 44 |
+
|
| 45 |
+
def forward(self, x, style):
|
| 46 |
+
batch, in_channels, height, width = x.shape
|
| 47 |
+
style = self.affine(style).view(batch, 1, in_channels, 1, 1)
|
| 48 |
+
weight = self.weight * self.weight_gain * style
|
| 49 |
+
|
| 50 |
+
if self.demodulate:
|
| 51 |
+
decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
|
| 52 |
+
weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1)
|
| 53 |
+
|
| 54 |
+
weight = weight.view(batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size)
|
| 55 |
+
x = x.view(1, batch * in_channels, height, width)
|
| 56 |
+
x = conv2d_resample(x=x, w=weight, f=self.resample_filter, up=self.up, down=self.down,
|
| 57 |
+
padding=self.padding, groups=batch)
|
| 58 |
+
out = x.view(batch, self.out_channels, *x.shape[2:])
|
| 59 |
+
|
| 60 |
+
return out
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class StyleConv(torch.nn.Module):
|
| 64 |
+
def __init__(self,
|
| 65 |
+
in_channels, # Number of input channels.
|
| 66 |
+
out_channels, # Number of output channels.
|
| 67 |
+
style_dim, # Intermediate latent (W) dimensionality.
|
| 68 |
+
resolution, # Resolution of this layer.
|
| 69 |
+
kernel_size=3, # Convolution kernel size.
|
| 70 |
+
up=1, # Integer upsampling factor.
|
| 71 |
+
use_noise=False, # Enable noise input?
|
| 72 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
| 73 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
| 74 |
+
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
| 75 |
+
demodulate=True, # perform demodulation
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
|
| 79 |
+
self.conv = ModulatedConv2d(in_channels=in_channels,
|
| 80 |
+
out_channels=out_channels,
|
| 81 |
+
kernel_size=kernel_size,
|
| 82 |
+
style_dim=style_dim,
|
| 83 |
+
demodulate=demodulate,
|
| 84 |
+
up=up,
|
| 85 |
+
resample_filter=resample_filter,
|
| 86 |
+
conv_clamp=conv_clamp)
|
| 87 |
+
|
| 88 |
+
self.use_noise = use_noise
|
| 89 |
+
self.resolution = resolution
|
| 90 |
+
if use_noise:
|
| 91 |
+
self.register_buffer('noise_const', torch.randn([resolution, resolution]))
|
| 92 |
+
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
|
| 93 |
+
|
| 94 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
| 95 |
+
self.activation = activation
|
| 96 |
+
self.act_gain = activation_funcs[activation].def_gain
|
| 97 |
+
self.conv_clamp = conv_clamp
|
| 98 |
+
|
| 99 |
+
def forward(self, x, style, noise_mode='random', gain=1):
|
| 100 |
+
x = self.conv(x, style)
|
| 101 |
+
|
| 102 |
+
assert noise_mode in ['random', 'const', 'none']
|
| 103 |
+
|
| 104 |
+
if self.use_noise:
|
| 105 |
+
if noise_mode == 'random':
|
| 106 |
+
xh, xw = x.size()[-2:]
|
| 107 |
+
noise = torch.randn([x.shape[0], 1, xh, xw], device=x.device) \
|
| 108 |
+
* self.noise_strength
|
| 109 |
+
if noise_mode == 'const':
|
| 110 |
+
noise = self.noise_const * self.noise_strength
|
| 111 |
+
x = x + noise
|
| 112 |
+
|
| 113 |
+
act_gain = self.act_gain * gain
|
| 114 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
| 115 |
+
out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
|
| 116 |
+
|
| 117 |
+
return out
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class ToRGB(torch.nn.Module):
|
| 121 |
+
def __init__(self,
|
| 122 |
+
in_channels,
|
| 123 |
+
out_channels,
|
| 124 |
+
style_dim,
|
| 125 |
+
kernel_size=1,
|
| 126 |
+
resample_filter=[1, 3, 3, 1],
|
| 127 |
+
conv_clamp=None,
|
| 128 |
+
demodulate=False):
|
| 129 |
+
super().__init__()
|
| 130 |
+
|
| 131 |
+
self.conv = ModulatedConv2d(in_channels=in_channels,
|
| 132 |
+
out_channels=out_channels,
|
| 133 |
+
kernel_size=kernel_size,
|
| 134 |
+
style_dim=style_dim,
|
| 135 |
+
demodulate=demodulate,
|
| 136 |
+
resample_filter=resample_filter,
|
| 137 |
+
conv_clamp=conv_clamp)
|
| 138 |
+
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
| 139 |
+
self.register_buffer('resample_filter', setup_filter(resample_filter))
|
| 140 |
+
self.conv_clamp = conv_clamp
|
| 141 |
+
|
| 142 |
+
def forward(self, x, style, skip=None):
|
| 143 |
+
x = self.conv(x, style)
|
| 144 |
+
out = bias_act(x, self.bias, clamp=self.conv_clamp)
|
| 145 |
+
|
| 146 |
+
if skip is not None:
|
| 147 |
+
if skip.shape != out.shape:
|
| 148 |
+
skip = upsample2d(skip, self.resample_filter)
|
| 149 |
+
out = out + skip
|
| 150 |
+
|
| 151 |
+
return out
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_style_code(a, b):
|
| 155 |
+
return torch.cat([a, b], dim=1)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class DecBlockFirst(nn.Module):
|
| 159 |
+
def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.fc = FullyConnectedLayer(in_features=in_channels * 2,
|
| 162 |
+
out_features=in_channels * 4 ** 2,
|
| 163 |
+
activation=activation)
|
| 164 |
+
self.conv = StyleConv(in_channels=in_channels,
|
| 165 |
+
out_channels=out_channels,
|
| 166 |
+
style_dim=style_dim,
|
| 167 |
+
resolution=4,
|
| 168 |
+
kernel_size=3,
|
| 169 |
+
use_noise=use_noise,
|
| 170 |
+
activation=activation,
|
| 171 |
+
demodulate=demodulate,
|
| 172 |
+
)
|
| 173 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
| 174 |
+
out_channels=img_channels,
|
| 175 |
+
style_dim=style_dim,
|
| 176 |
+
kernel_size=1,
|
| 177 |
+
demodulate=False,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
| 181 |
+
x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
| 182 |
+
x = x + E_features[2]
|
| 183 |
+
style = get_style_code(ws[:, 0], gs)
|
| 184 |
+
x = self.conv(x, style, noise_mode=noise_mode)
|
| 185 |
+
style = get_style_code(ws[:, 1], gs)
|
| 186 |
+
img = self.toRGB(x, style, skip=None)
|
| 187 |
+
|
| 188 |
+
return x, img
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class DecBlockFirstV2(nn.Module):
|
| 192 |
+
def __init__(self, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
| 193 |
+
super().__init__()
|
| 194 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
| 195 |
+
out_channels=in_channels,
|
| 196 |
+
kernel_size=3,
|
| 197 |
+
activation=activation,
|
| 198 |
+
)
|
| 199 |
+
self.conv1 = StyleConv(in_channels=in_channels,
|
| 200 |
+
out_channels=out_channels,
|
| 201 |
+
style_dim=style_dim,
|
| 202 |
+
resolution=4,
|
| 203 |
+
kernel_size=3,
|
| 204 |
+
use_noise=use_noise,
|
| 205 |
+
activation=activation,
|
| 206 |
+
demodulate=demodulate,
|
| 207 |
+
)
|
| 208 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
| 209 |
+
out_channels=img_channels,
|
| 210 |
+
style_dim=style_dim,
|
| 211 |
+
kernel_size=1,
|
| 212 |
+
demodulate=False,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
| 216 |
+
# x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
| 217 |
+
x = self.conv0(x)
|
| 218 |
+
x = x + E_features[2]
|
| 219 |
+
style = get_style_code(ws[:, 0], gs)
|
| 220 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
| 221 |
+
style = get_style_code(ws[:, 1], gs)
|
| 222 |
+
img = self.toRGB(x, style, skip=None)
|
| 223 |
+
|
| 224 |
+
return x, img
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class DecBlock(nn.Module):
|
| 228 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate,
|
| 229 |
+
img_channels): # res = 2, ..., resolution_log2
|
| 230 |
+
super().__init__()
|
| 231 |
+
self.res = res
|
| 232 |
+
|
| 233 |
+
self.conv0 = StyleConv(in_channels=in_channels,
|
| 234 |
+
out_channels=out_channels,
|
| 235 |
+
style_dim=style_dim,
|
| 236 |
+
resolution=2 ** res,
|
| 237 |
+
kernel_size=3,
|
| 238 |
+
up=2,
|
| 239 |
+
use_noise=use_noise,
|
| 240 |
+
activation=activation,
|
| 241 |
+
demodulate=demodulate,
|
| 242 |
+
)
|
| 243 |
+
self.conv1 = StyleConv(in_channels=out_channels,
|
| 244 |
+
out_channels=out_channels,
|
| 245 |
+
style_dim=style_dim,
|
| 246 |
+
resolution=2 ** res,
|
| 247 |
+
kernel_size=3,
|
| 248 |
+
use_noise=use_noise,
|
| 249 |
+
activation=activation,
|
| 250 |
+
demodulate=demodulate,
|
| 251 |
+
)
|
| 252 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
| 253 |
+
out_channels=img_channels,
|
| 254 |
+
style_dim=style_dim,
|
| 255 |
+
kernel_size=1,
|
| 256 |
+
demodulate=False,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
|
| 260 |
+
style = get_style_code(ws[:, self.res * 2 - 5], gs)
|
| 261 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
| 262 |
+
x = x + E_features[self.res]
|
| 263 |
+
style = get_style_code(ws[:, self.res * 2 - 4], gs)
|
| 264 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
| 265 |
+
style = get_style_code(ws[:, self.res * 2 - 3], gs)
|
| 266 |
+
img = self.toRGB(x, style, skip=img)
|
| 267 |
+
|
| 268 |
+
return x, img
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class MappingNet(torch.nn.Module):
|
| 272 |
+
def __init__(self,
|
| 273 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
| 274 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
| 275 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
| 276 |
+
num_ws, # Number of intermediate latents to output, None = do not broadcast.
|
| 277 |
+
num_layers=8, # Number of mapping layers.
|
| 278 |
+
embed_features=None, # Label embedding dimensionality, None = same as w_dim.
|
| 279 |
+
layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
|
| 280 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
| 281 |
+
lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
|
| 282 |
+
w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
|
| 283 |
+
):
|
| 284 |
+
super().__init__()
|
| 285 |
+
self.z_dim = z_dim
|
| 286 |
+
self.c_dim = c_dim
|
| 287 |
+
self.w_dim = w_dim
|
| 288 |
+
self.num_ws = num_ws
|
| 289 |
+
self.num_layers = num_layers
|
| 290 |
+
self.w_avg_beta = w_avg_beta
|
| 291 |
+
|
| 292 |
+
if embed_features is None:
|
| 293 |
+
embed_features = w_dim
|
| 294 |
+
if c_dim == 0:
|
| 295 |
+
embed_features = 0
|
| 296 |
+
if layer_features is None:
|
| 297 |
+
layer_features = w_dim
|
| 298 |
+
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
|
| 299 |
+
|
| 300 |
+
if c_dim > 0:
|
| 301 |
+
self.embed = FullyConnectedLayer(c_dim, embed_features)
|
| 302 |
+
for idx in range(num_layers):
|
| 303 |
+
in_features = features_list[idx]
|
| 304 |
+
out_features = features_list[idx + 1]
|
| 305 |
+
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
|
| 306 |
+
setattr(self, f'fc{idx}', layer)
|
| 307 |
+
|
| 308 |
+
if num_ws is not None and w_avg_beta is not None:
|
| 309 |
+
self.register_buffer('w_avg', torch.zeros([w_dim]))
|
| 310 |
+
|
| 311 |
+
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
|
| 312 |
+
# Embed, normalize, and concat inputs.
|
| 313 |
+
x = None
|
| 314 |
+
with torch.autograd.profiler.record_function('input'):
|
| 315 |
+
if self.z_dim > 0:
|
| 316 |
+
x = normalize_2nd_moment(z.to(torch.float32))
|
| 317 |
+
if self.c_dim > 0:
|
| 318 |
+
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
|
| 319 |
+
x = torch.cat([x, y], dim=1) if x is not None else y
|
| 320 |
+
|
| 321 |
+
# Main layers.
|
| 322 |
+
for idx in range(self.num_layers):
|
| 323 |
+
layer = getattr(self, f'fc{idx}')
|
| 324 |
+
x = layer(x)
|
| 325 |
+
|
| 326 |
+
# Update moving average of W.
|
| 327 |
+
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
|
| 328 |
+
with torch.autograd.profiler.record_function('update_w_avg'):
|
| 329 |
+
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
| 330 |
+
|
| 331 |
+
# Broadcast.
|
| 332 |
+
if self.num_ws is not None:
|
| 333 |
+
with torch.autograd.profiler.record_function('broadcast'):
|
| 334 |
+
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
| 335 |
+
|
| 336 |
+
# Apply truncation.
|
| 337 |
+
if truncation_psi != 1:
|
| 338 |
+
with torch.autograd.profiler.record_function('truncate'):
|
| 339 |
+
assert self.w_avg_beta is not None
|
| 340 |
+
if self.num_ws is None or truncation_cutoff is None:
|
| 341 |
+
x = self.w_avg.lerp(x, truncation_psi)
|
| 342 |
+
else:
|
| 343 |
+
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
|
| 344 |
+
|
| 345 |
+
return x
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class DisFromRGB(nn.Module):
|
| 349 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
|
| 350 |
+
super().__init__()
|
| 351 |
+
self.conv = Conv2dLayer(in_channels=in_channels,
|
| 352 |
+
out_channels=out_channels,
|
| 353 |
+
kernel_size=1,
|
| 354 |
+
activation=activation,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
def forward(self, x):
|
| 358 |
+
return self.conv(x)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
class DisBlock(nn.Module):
|
| 362 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
|
| 363 |
+
super().__init__()
|
| 364 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
| 365 |
+
out_channels=in_channels,
|
| 366 |
+
kernel_size=3,
|
| 367 |
+
activation=activation,
|
| 368 |
+
)
|
| 369 |
+
self.conv1 = Conv2dLayer(in_channels=in_channels,
|
| 370 |
+
out_channels=out_channels,
|
| 371 |
+
kernel_size=3,
|
| 372 |
+
down=2,
|
| 373 |
+
activation=activation,
|
| 374 |
+
)
|
| 375 |
+
self.skip = Conv2dLayer(in_channels=in_channels,
|
| 376 |
+
out_channels=out_channels,
|
| 377 |
+
kernel_size=1,
|
| 378 |
+
down=2,
|
| 379 |
+
bias=False,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
def forward(self, x):
|
| 383 |
+
skip = self.skip(x, gain=np.sqrt(0.5))
|
| 384 |
+
x = self.conv0(x)
|
| 385 |
+
x = self.conv1(x, gain=np.sqrt(0.5))
|
| 386 |
+
out = skip + x
|
| 387 |
+
|
| 388 |
+
return out
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class Discriminator(torch.nn.Module):
|
| 392 |
+
def __init__(self,
|
| 393 |
+
c_dim, # Conditioning label (C) dimensionality.
|
| 394 |
+
img_resolution, # Input resolution.
|
| 395 |
+
img_channels, # Number of input color channels.
|
| 396 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
| 397 |
+
channel_max=512, # Maximum number of channels in any layer.
|
| 398 |
+
channel_decay=1,
|
| 399 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
| 400 |
+
activation='lrelu',
|
| 401 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
| 402 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
| 403 |
+
):
|
| 404 |
+
super().__init__()
|
| 405 |
+
self.c_dim = c_dim
|
| 406 |
+
self.img_resolution = img_resolution
|
| 407 |
+
self.img_channels = img_channels
|
| 408 |
+
|
| 409 |
+
resolution_log2 = int(np.log2(img_resolution))
|
| 410 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
| 411 |
+
self.resolution_log2 = resolution_log2
|
| 412 |
+
|
| 413 |
+
def nf(stage):
|
| 414 |
+
return np.clip(int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max)
|
| 415 |
+
|
| 416 |
+
if cmap_dim == None:
|
| 417 |
+
cmap_dim = nf(2)
|
| 418 |
+
if c_dim == 0:
|
| 419 |
+
cmap_dim = 0
|
| 420 |
+
self.cmap_dim = cmap_dim
|
| 421 |
+
|
| 422 |
+
if c_dim > 0:
|
| 423 |
+
self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)
|
| 424 |
+
|
| 425 |
+
Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
|
| 426 |
+
for res in range(resolution_log2, 2, -1):
|
| 427 |
+
Dis.append(DisBlock(nf(res), nf(res - 1), activation))
|
| 428 |
+
|
| 429 |
+
if mbstd_num_channels > 0:
|
| 430 |
+
Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
|
| 431 |
+
Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
|
| 432 |
+
self.Dis = nn.Sequential(*Dis)
|
| 433 |
+
|
| 434 |
+
self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
|
| 435 |
+
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
|
| 436 |
+
|
| 437 |
+
def forward(self, images_in, masks_in, c):
|
| 438 |
+
x = torch.cat([masks_in - 0.5, images_in], dim=1)
|
| 439 |
+
x = self.Dis(x)
|
| 440 |
+
x = self.fc1(self.fc0(x.flatten(start_dim=1)))
|
| 441 |
+
|
| 442 |
+
if self.c_dim > 0:
|
| 443 |
+
cmap = self.mapping(None, c)
|
| 444 |
+
|
| 445 |
+
if self.cmap_dim > 0:
|
| 446 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
| 447 |
+
|
| 448 |
+
return x
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
|
| 452 |
+
NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
|
| 453 |
+
return NF[2 ** stage]
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
class Mlp(nn.Module):
|
| 457 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 458 |
+
super().__init__()
|
| 459 |
+
out_features = out_features or in_features
|
| 460 |
+
hidden_features = hidden_features or in_features
|
| 461 |
+
self.fc1 = FullyConnectedLayer(in_features=in_features, out_features=hidden_features, activation='lrelu')
|
| 462 |
+
self.fc2 = FullyConnectedLayer(in_features=hidden_features, out_features=out_features)
|
| 463 |
+
|
| 464 |
+
def forward(self, x):
|
| 465 |
+
x = self.fc1(x)
|
| 466 |
+
x = self.fc2(x)
|
| 467 |
+
return x
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def window_partition(x, window_size):
|
| 471 |
+
"""
|
| 472 |
+
Args:
|
| 473 |
+
x: (B, H, W, C)
|
| 474 |
+
window_size (int): window size
|
| 475 |
+
Returns:
|
| 476 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 477 |
+
"""
|
| 478 |
+
B, H, W, C = x.shape
|
| 479 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
| 480 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
| 481 |
+
return windows
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def window_reverse(windows, window_size: int, H: int, W: int):
|
| 485 |
+
"""
|
| 486 |
+
Args:
|
| 487 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 488 |
+
window_size (int): Window size
|
| 489 |
+
H (int): Height of image
|
| 490 |
+
W (int): Width of image
|
| 491 |
+
Returns:
|
| 492 |
+
x: (B, H, W, C)
|
| 493 |
+
"""
|
| 494 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
| 495 |
+
# B = windows.shape[0] / (H * W / window_size / window_size)
|
| 496 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
| 497 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 498 |
+
return x
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class Conv2dLayerPartial(nn.Module):
|
| 502 |
+
def __init__(self,
|
| 503 |
+
in_channels, # Number of input channels.
|
| 504 |
+
out_channels, # Number of output channels.
|
| 505 |
+
kernel_size, # Width and height of the convolution kernel.
|
| 506 |
+
bias=True, # Apply additive bias before the activation function?
|
| 507 |
+
activation='linear', # Activation function: 'relu', 'lrelu', etc.
|
| 508 |
+
up=1, # Integer upsampling factor.
|
| 509 |
+
down=1, # Integer downsampling factor.
|
| 510 |
+
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
|
| 511 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
| 512 |
+
trainable=True, # Update the weights of this layer during training?
|
| 513 |
+
):
|
| 514 |
+
super().__init__()
|
| 515 |
+
self.conv = Conv2dLayer(in_channels, out_channels, kernel_size, bias, activation, up, down, resample_filter,
|
| 516 |
+
conv_clamp, trainable)
|
| 517 |
+
|
| 518 |
+
self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
|
| 519 |
+
self.slide_winsize = kernel_size ** 2
|
| 520 |
+
self.stride = down
|
| 521 |
+
self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
|
| 522 |
+
|
| 523 |
+
def forward(self, x, mask=None):
|
| 524 |
+
if mask is not None:
|
| 525 |
+
with torch.no_grad():
|
| 526 |
+
if self.weight_maskUpdater.type() != x.type():
|
| 527 |
+
self.weight_maskUpdater = self.weight_maskUpdater.to(x)
|
| 528 |
+
update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
|
| 529 |
+
padding=self.padding)
|
| 530 |
+
mask_ratio = self.slide_winsize / (update_mask + 1e-8)
|
| 531 |
+
update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1
|
| 532 |
+
mask_ratio = torch.mul(mask_ratio, update_mask)
|
| 533 |
+
x = self.conv(x)
|
| 534 |
+
x = torch.mul(x, mask_ratio)
|
| 535 |
+
return x, update_mask
|
| 536 |
+
else:
|
| 537 |
+
x = self.conv(x)
|
| 538 |
+
return x, None
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
class WindowAttention(nn.Module):
|
| 542 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 543 |
+
It supports both of shifted and non-shifted window.
|
| 544 |
+
Args:
|
| 545 |
+
dim (int): Number of input channels.
|
| 546 |
+
window_size (tuple[int]): The height and width of the window.
|
| 547 |
+
num_heads (int): Number of attention heads.
|
| 548 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 549 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 550 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 551 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 552 |
+
"""
|
| 553 |
+
|
| 554 |
+
def __init__(self, dim, window_size, num_heads, down_ratio=1, qkv_bias=True, qk_scale=None, attn_drop=0.,
|
| 555 |
+
proj_drop=0.):
|
| 556 |
+
|
| 557 |
+
super().__init__()
|
| 558 |
+
self.dim = dim
|
| 559 |
+
self.window_size = window_size # Wh, Ww
|
| 560 |
+
self.num_heads = num_heads
|
| 561 |
+
head_dim = dim // num_heads
|
| 562 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 563 |
+
|
| 564 |
+
self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
|
| 565 |
+
self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
|
| 566 |
+
self.v = FullyConnectedLayer(in_features=dim, out_features=dim)
|
| 567 |
+
self.proj = FullyConnectedLayer(in_features=dim, out_features=dim)
|
| 568 |
+
|
| 569 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 570 |
+
|
| 571 |
+
def forward(self, x, mask_windows=None, mask=None):
|
| 572 |
+
"""
|
| 573 |
+
Args:
|
| 574 |
+
x: input features with shape of (num_windows*B, N, C)
|
| 575 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
| 576 |
+
"""
|
| 577 |
+
B_, N, C = x.shape
|
| 578 |
+
norm_x = F.normalize(x, p=2.0, dim=-1)
|
| 579 |
+
q = self.q(norm_x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
| 580 |
+
k = self.k(norm_x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 3, 1)
|
| 581 |
+
v = self.v(x).view(B_, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
| 582 |
+
|
| 583 |
+
attn = (q @ k) * self.scale
|
| 584 |
+
|
| 585 |
+
if mask is not None:
|
| 586 |
+
nW = mask.shape[0]
|
| 587 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
| 588 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
| 589 |
+
|
| 590 |
+
if mask_windows is not None:
|
| 591 |
+
attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1)
|
| 592 |
+
attn = attn + attn_mask_windows.masked_fill(attn_mask_windows == 0, float(-100.0)).masked_fill(
|
| 593 |
+
attn_mask_windows == 1, float(0.0))
|
| 594 |
+
with torch.no_grad():
|
| 595 |
+
mask_windows = torch.clamp(torch.sum(mask_windows, dim=1, keepdim=True), 0, 1).repeat(1, N, 1)
|
| 596 |
+
|
| 597 |
+
attn = self.softmax(attn)
|
| 598 |
+
|
| 599 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
| 600 |
+
x = self.proj(x)
|
| 601 |
+
return x, mask_windows
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
class SwinTransformerBlock(nn.Module):
|
| 605 |
+
r""" Swin Transformer Block.
|
| 606 |
+
Args:
|
| 607 |
+
dim (int): Number of input channels.
|
| 608 |
+
input_resolution (tuple[int]): Input resulotion.
|
| 609 |
+
num_heads (int): Number of attention heads.
|
| 610 |
+
window_size (int): Window size.
|
| 611 |
+
shift_size (int): Shift size for SW-MSA.
|
| 612 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 613 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 614 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 615 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 616 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 617 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 618 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 619 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 620 |
+
"""
|
| 621 |
+
|
| 622 |
+
def __init__(self, dim, input_resolution, num_heads, down_ratio=1, window_size=7, shift_size=0,
|
| 623 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
| 624 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
| 625 |
+
super().__init__()
|
| 626 |
+
self.dim = dim
|
| 627 |
+
self.input_resolution = input_resolution
|
| 628 |
+
self.num_heads = num_heads
|
| 629 |
+
self.window_size = window_size
|
| 630 |
+
self.shift_size = shift_size
|
| 631 |
+
self.mlp_ratio = mlp_ratio
|
| 632 |
+
if min(self.input_resolution) <= self.window_size:
|
| 633 |
+
# if window size is larger than input resolution, we don't partition windows
|
| 634 |
+
self.shift_size = 0
|
| 635 |
+
self.window_size = min(self.input_resolution)
|
| 636 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
| 637 |
+
|
| 638 |
+
if self.shift_size > 0:
|
| 639 |
+
down_ratio = 1
|
| 640 |
+
self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
| 641 |
+
down_ratio=down_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
| 642 |
+
proj_drop=drop)
|
| 643 |
+
|
| 644 |
+
self.fuse = FullyConnectedLayer(in_features=dim * 2, out_features=dim, activation='lrelu')
|
| 645 |
+
|
| 646 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 647 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 648 |
+
|
| 649 |
+
if self.shift_size > 0:
|
| 650 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
| 651 |
+
else:
|
| 652 |
+
attn_mask = None
|
| 653 |
+
|
| 654 |
+
self.register_buffer("attn_mask", attn_mask)
|
| 655 |
+
|
| 656 |
+
def calculate_mask(self, x_size):
|
| 657 |
+
# calculate attention mask for SW-MSA
|
| 658 |
+
H, W = x_size
|
| 659 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
| 660 |
+
h_slices = (slice(0, -self.window_size),
|
| 661 |
+
slice(-self.window_size, -self.shift_size),
|
| 662 |
+
slice(-self.shift_size, None))
|
| 663 |
+
w_slices = (slice(0, -self.window_size),
|
| 664 |
+
slice(-self.window_size, -self.shift_size),
|
| 665 |
+
slice(-self.shift_size, None))
|
| 666 |
+
cnt = 0
|
| 667 |
+
for h in h_slices:
|
| 668 |
+
for w in w_slices:
|
| 669 |
+
img_mask[:, h, w, :] = cnt
|
| 670 |
+
cnt += 1
|
| 671 |
+
|
| 672 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 673 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
| 674 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 675 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 676 |
+
|
| 677 |
+
return attn_mask
|
| 678 |
+
|
| 679 |
+
def forward(self, x, x_size, mask=None):
|
| 680 |
+
# H, W = self.input_resolution
|
| 681 |
+
H, W = x_size
|
| 682 |
+
B, L, C = x.shape
|
| 683 |
+
# assert L == H * W, "input feature has wrong size"
|
| 684 |
+
|
| 685 |
+
shortcut = x
|
| 686 |
+
x = x.view(B, H, W, C)
|
| 687 |
+
if mask is not None:
|
| 688 |
+
mask = mask.view(B, H, W, 1)
|
| 689 |
+
|
| 690 |
+
# cyclic shift
|
| 691 |
+
if self.shift_size > 0:
|
| 692 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
| 693 |
+
if mask is not None:
|
| 694 |
+
shifted_mask = torch.roll(mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
| 695 |
+
else:
|
| 696 |
+
shifted_x = x
|
| 697 |
+
if mask is not None:
|
| 698 |
+
shifted_mask = mask
|
| 699 |
+
|
| 700 |
+
# partition windows
|
| 701 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
| 702 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
| 703 |
+
if mask is not None:
|
| 704 |
+
mask_windows = window_partition(shifted_mask, self.window_size)
|
| 705 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1)
|
| 706 |
+
else:
|
| 707 |
+
mask_windows = None
|
| 708 |
+
|
| 709 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
| 710 |
+
if self.input_resolution == x_size:
|
| 711 |
+
attn_windows, mask_windows = self.attn(x_windows, mask_windows,
|
| 712 |
+
mask=self.attn_mask) # nW*B, window_size*window_size, C
|
| 713 |
+
else:
|
| 714 |
+
attn_windows, mask_windows = self.attn(x_windows, mask_windows, mask=self.calculate_mask(x_size).to(
|
| 715 |
+
x.device)) # nW*B, window_size*window_size, C
|
| 716 |
+
|
| 717 |
+
# merge windows
|
| 718 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
| 719 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
| 720 |
+
if mask is not None:
|
| 721 |
+
mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1)
|
| 722 |
+
shifted_mask = window_reverse(mask_windows, self.window_size, H, W)
|
| 723 |
+
|
| 724 |
+
# reverse cyclic shift
|
| 725 |
+
if self.shift_size > 0:
|
| 726 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
| 727 |
+
if mask is not None:
|
| 728 |
+
mask = torch.roll(shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
| 729 |
+
else:
|
| 730 |
+
x = shifted_x
|
| 731 |
+
if mask is not None:
|
| 732 |
+
mask = shifted_mask
|
| 733 |
+
x = x.view(B, H * W, C)
|
| 734 |
+
if mask is not None:
|
| 735 |
+
mask = mask.view(B, H * W, 1)
|
| 736 |
+
|
| 737 |
+
# FFN
|
| 738 |
+
x = self.fuse(torch.cat([shortcut, x], dim=-1))
|
| 739 |
+
x = self.mlp(x)
|
| 740 |
+
|
| 741 |
+
return x, mask
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
class PatchMerging(nn.Module):
|
| 745 |
+
def __init__(self, in_channels, out_channels, down=2):
|
| 746 |
+
super().__init__()
|
| 747 |
+
self.conv = Conv2dLayerPartial(in_channels=in_channels,
|
| 748 |
+
out_channels=out_channels,
|
| 749 |
+
kernel_size=3,
|
| 750 |
+
activation='lrelu',
|
| 751 |
+
down=down,
|
| 752 |
+
)
|
| 753 |
+
self.down = down
|
| 754 |
+
|
| 755 |
+
def forward(self, x, x_size, mask=None):
|
| 756 |
+
x = token2feature(x, x_size)
|
| 757 |
+
if mask is not None:
|
| 758 |
+
mask = token2feature(mask, x_size)
|
| 759 |
+
x, mask = self.conv(x, mask)
|
| 760 |
+
if self.down != 1:
|
| 761 |
+
ratio = 1 / self.down
|
| 762 |
+
x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio))
|
| 763 |
+
x = feature2token(x)
|
| 764 |
+
if mask is not None:
|
| 765 |
+
mask = feature2token(mask)
|
| 766 |
+
return x, x_size, mask
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
class PatchUpsampling(nn.Module):
|
| 770 |
+
def __init__(self, in_channels, out_channels, up=2):
|
| 771 |
+
super().__init__()
|
| 772 |
+
self.conv = Conv2dLayerPartial(in_channels=in_channels,
|
| 773 |
+
out_channels=out_channels,
|
| 774 |
+
kernel_size=3,
|
| 775 |
+
activation='lrelu',
|
| 776 |
+
up=up,
|
| 777 |
+
)
|
| 778 |
+
self.up = up
|
| 779 |
+
|
| 780 |
+
def forward(self, x, x_size, mask=None):
|
| 781 |
+
x = token2feature(x, x_size)
|
| 782 |
+
if mask is not None:
|
| 783 |
+
mask = token2feature(mask, x_size)
|
| 784 |
+
x, mask = self.conv(x, mask)
|
| 785 |
+
if self.up != 1:
|
| 786 |
+
x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up))
|
| 787 |
+
x = feature2token(x)
|
| 788 |
+
if mask is not None:
|
| 789 |
+
mask = feature2token(mask)
|
| 790 |
+
return x, x_size, mask
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
class BasicLayer(nn.Module):
|
| 794 |
+
""" A basic Swin Transformer layer for one stage.
|
| 795 |
+
Args:
|
| 796 |
+
dim (int): Number of input channels.
|
| 797 |
+
input_resolution (tuple[int]): Input resolution.
|
| 798 |
+
depth (int): Number of blocks.
|
| 799 |
+
num_heads (int): Number of attention heads.
|
| 800 |
+
window_size (int): Local window size.
|
| 801 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 802 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 803 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 804 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 805 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 806 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 807 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 808 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 809 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 810 |
+
"""
|
| 811 |
+
|
| 812 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size, down_ratio=1,
|
| 813 |
+
mlp_ratio=2., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
| 814 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
| 815 |
+
|
| 816 |
+
super().__init__()
|
| 817 |
+
self.dim = dim
|
| 818 |
+
self.input_resolution = input_resolution
|
| 819 |
+
self.depth = depth
|
| 820 |
+
self.use_checkpoint = use_checkpoint
|
| 821 |
+
|
| 822 |
+
# patch merging layer
|
| 823 |
+
if downsample is not None:
|
| 824 |
+
# self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
| 825 |
+
self.downsample = downsample
|
| 826 |
+
else:
|
| 827 |
+
self.downsample = None
|
| 828 |
+
|
| 829 |
+
# build blocks
|
| 830 |
+
self.blocks = nn.ModuleList([
|
| 831 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
| 832 |
+
num_heads=num_heads, down_ratio=down_ratio, window_size=window_size,
|
| 833 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
| 834 |
+
mlp_ratio=mlp_ratio,
|
| 835 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 836 |
+
drop=drop, attn_drop=attn_drop,
|
| 837 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
| 838 |
+
norm_layer=norm_layer)
|
| 839 |
+
for i in range(depth)])
|
| 840 |
+
|
| 841 |
+
self.conv = Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, activation='lrelu')
|
| 842 |
+
|
| 843 |
+
def forward(self, x, x_size, mask=None):
|
| 844 |
+
if self.downsample is not None:
|
| 845 |
+
x, x_size, mask = self.downsample(x, x_size, mask)
|
| 846 |
+
identity = x
|
| 847 |
+
for blk in self.blocks:
|
| 848 |
+
if self.use_checkpoint:
|
| 849 |
+
x, mask = checkpoint.checkpoint(blk, x, x_size, mask)
|
| 850 |
+
else:
|
| 851 |
+
x, mask = blk(x, x_size, mask)
|
| 852 |
+
if mask is not None:
|
| 853 |
+
mask = token2feature(mask, x_size)
|
| 854 |
+
x, mask = self.conv(token2feature(x, x_size), mask)
|
| 855 |
+
x = feature2token(x) + identity
|
| 856 |
+
if mask is not None:
|
| 857 |
+
mask = feature2token(mask)
|
| 858 |
+
return x, x_size, mask
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
class ToToken(nn.Module):
|
| 862 |
+
def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1):
|
| 863 |
+
super().__init__()
|
| 864 |
+
|
| 865 |
+
self.proj = Conv2dLayerPartial(in_channels=in_channels, out_channels=dim, kernel_size=kernel_size,
|
| 866 |
+
activation='lrelu')
|
| 867 |
+
|
| 868 |
+
def forward(self, x, mask):
|
| 869 |
+
x, mask = self.proj(x, mask)
|
| 870 |
+
|
| 871 |
+
return x, mask
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
class EncFromRGB(nn.Module):
|
| 875 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log2
|
| 876 |
+
super().__init__()
|
| 877 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
| 878 |
+
out_channels=out_channels,
|
| 879 |
+
kernel_size=1,
|
| 880 |
+
activation=activation,
|
| 881 |
+
)
|
| 882 |
+
self.conv1 = Conv2dLayer(in_channels=out_channels,
|
| 883 |
+
out_channels=out_channels,
|
| 884 |
+
kernel_size=3,
|
| 885 |
+
activation=activation,
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
def forward(self, x):
|
| 889 |
+
x = self.conv0(x)
|
| 890 |
+
x = self.conv1(x)
|
| 891 |
+
|
| 892 |
+
return x
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
class ConvBlockDown(nn.Module):
|
| 896 |
+
def __init__(self, in_channels, out_channels, activation): # res = 2, ..., resolution_log
|
| 897 |
+
super().__init__()
|
| 898 |
+
|
| 899 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
| 900 |
+
out_channels=out_channels,
|
| 901 |
+
kernel_size=3,
|
| 902 |
+
activation=activation,
|
| 903 |
+
down=2,
|
| 904 |
+
)
|
| 905 |
+
self.conv1 = Conv2dLayer(in_channels=out_channels,
|
| 906 |
+
out_channels=out_channels,
|
| 907 |
+
kernel_size=3,
|
| 908 |
+
activation=activation,
|
| 909 |
+
)
|
| 910 |
+
|
| 911 |
+
def forward(self, x):
|
| 912 |
+
x = self.conv0(x)
|
| 913 |
+
x = self.conv1(x)
|
| 914 |
+
|
| 915 |
+
return x
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
def token2feature(x, x_size):
|
| 919 |
+
B, N, C = x.shape
|
| 920 |
+
h, w = x_size
|
| 921 |
+
x = x.permute(0, 2, 1).reshape(B, C, h, w)
|
| 922 |
+
return x
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
def feature2token(x):
|
| 926 |
+
B, C, H, W = x.shape
|
| 927 |
+
x = x.view(B, C, -1).transpose(1, 2)
|
| 928 |
+
return x
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
class Encoder(nn.Module):
|
| 932 |
+
def __init__(self, res_log2, img_channels, activation, patch_size=5, channels=16, drop_path_rate=0.1):
|
| 933 |
+
super().__init__()
|
| 934 |
+
|
| 935 |
+
self.resolution = []
|
| 936 |
+
|
| 937 |
+
for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16
|
| 938 |
+
res = 2 ** i
|
| 939 |
+
self.resolution.append(res)
|
| 940 |
+
if i == res_log2:
|
| 941 |
+
block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
|
| 942 |
+
else:
|
| 943 |
+
block = ConvBlockDown(nf(i + 1), nf(i), activation)
|
| 944 |
+
setattr(self, 'EncConv_Block_%dx%d' % (res, res), block)
|
| 945 |
+
|
| 946 |
+
def forward(self, x):
|
| 947 |
+
out = {}
|
| 948 |
+
for res in self.resolution:
|
| 949 |
+
res_log2 = int(np.log2(res))
|
| 950 |
+
x = getattr(self, 'EncConv_Block_%dx%d' % (res, res))(x)
|
| 951 |
+
out[res_log2] = x
|
| 952 |
+
|
| 953 |
+
return out
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
class ToStyle(nn.Module):
|
| 957 |
+
def __init__(self, in_channels, out_channels, activation, drop_rate):
|
| 958 |
+
super().__init__()
|
| 959 |
+
self.conv = nn.Sequential(
|
| 960 |
+
Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
|
| 961 |
+
down=2),
|
| 962 |
+
Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
|
| 963 |
+
down=2),
|
| 964 |
+
Conv2dLayer(in_channels=in_channels, out_channels=in_channels, kernel_size=3, activation=activation,
|
| 965 |
+
down=2),
|
| 966 |
+
)
|
| 967 |
+
|
| 968 |
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
| 969 |
+
self.fc = FullyConnectedLayer(in_features=in_channels,
|
| 970 |
+
out_features=out_channels,
|
| 971 |
+
activation=activation)
|
| 972 |
+
# self.dropout = nn.Dropout(drop_rate)
|
| 973 |
+
|
| 974 |
+
def forward(self, x):
|
| 975 |
+
x = self.conv(x)
|
| 976 |
+
x = self.pool(x)
|
| 977 |
+
x = self.fc(x.flatten(start_dim=1))
|
| 978 |
+
# x = self.dropout(x)
|
| 979 |
+
|
| 980 |
+
return x
|
| 981 |
+
|
| 982 |
+
|
| 983 |
+
class DecBlockFirstV2(nn.Module):
|
| 984 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
| 985 |
+
super().__init__()
|
| 986 |
+
self.res = res
|
| 987 |
+
|
| 988 |
+
self.conv0 = Conv2dLayer(in_channels=in_channels,
|
| 989 |
+
out_channels=in_channels,
|
| 990 |
+
kernel_size=3,
|
| 991 |
+
activation=activation,
|
| 992 |
+
)
|
| 993 |
+
self.conv1 = StyleConv(in_channels=in_channels,
|
| 994 |
+
out_channels=out_channels,
|
| 995 |
+
style_dim=style_dim,
|
| 996 |
+
resolution=2 ** res,
|
| 997 |
+
kernel_size=3,
|
| 998 |
+
use_noise=use_noise,
|
| 999 |
+
activation=activation,
|
| 1000 |
+
demodulate=demodulate,
|
| 1001 |
+
)
|
| 1002 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
| 1003 |
+
out_channels=img_channels,
|
| 1004 |
+
style_dim=style_dim,
|
| 1005 |
+
kernel_size=1,
|
| 1006 |
+
demodulate=False,
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
| 1010 |
+
# x = self.fc(x).view(x.shape[0], -1, 4, 4)
|
| 1011 |
+
x = self.conv0(x)
|
| 1012 |
+
x = x + E_features[self.res]
|
| 1013 |
+
style = get_style_code(ws[:, 0], gs)
|
| 1014 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
| 1015 |
+
style = get_style_code(ws[:, 1], gs)
|
| 1016 |
+
img = self.toRGB(x, style, skip=None)
|
| 1017 |
+
|
| 1018 |
+
return x, img
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
class DecBlock(nn.Module):
|
| 1022 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate,
|
| 1023 |
+
img_channels): # res = 4, ..., resolution_log2
|
| 1024 |
+
super().__init__()
|
| 1025 |
+
self.res = res
|
| 1026 |
+
|
| 1027 |
+
self.conv0 = StyleConv(in_channels=in_channels,
|
| 1028 |
+
out_channels=out_channels,
|
| 1029 |
+
style_dim=style_dim,
|
| 1030 |
+
resolution=2 ** res,
|
| 1031 |
+
kernel_size=3,
|
| 1032 |
+
up=2,
|
| 1033 |
+
use_noise=use_noise,
|
| 1034 |
+
activation=activation,
|
| 1035 |
+
demodulate=demodulate,
|
| 1036 |
+
)
|
| 1037 |
+
self.conv1 = StyleConv(in_channels=out_channels,
|
| 1038 |
+
out_channels=out_channels,
|
| 1039 |
+
style_dim=style_dim,
|
| 1040 |
+
resolution=2 ** res,
|
| 1041 |
+
kernel_size=3,
|
| 1042 |
+
use_noise=use_noise,
|
| 1043 |
+
activation=activation,
|
| 1044 |
+
demodulate=demodulate,
|
| 1045 |
+
)
|
| 1046 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
| 1047 |
+
out_channels=img_channels,
|
| 1048 |
+
style_dim=style_dim,
|
| 1049 |
+
kernel_size=1,
|
| 1050 |
+
demodulate=False,
|
| 1051 |
+
)
|
| 1052 |
+
|
| 1053 |
+
def forward(self, x, img, ws, gs, E_features, noise_mode='random'):
|
| 1054 |
+
style = get_style_code(ws[:, self.res * 2 - 9], gs)
|
| 1055 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
| 1056 |
+
x = x + E_features[self.res]
|
| 1057 |
+
style = get_style_code(ws[:, self.res * 2 - 8], gs)
|
| 1058 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
| 1059 |
+
style = get_style_code(ws[:, self.res * 2 - 7], gs)
|
| 1060 |
+
img = self.toRGB(x, style, skip=img)
|
| 1061 |
+
|
| 1062 |
+
return x, img
|
| 1063 |
+
|
| 1064 |
+
|
| 1065 |
+
class Decoder(nn.Module):
|
| 1066 |
+
def __init__(self, res_log2, activation, style_dim, use_noise, demodulate, img_channels):
|
| 1067 |
+
super().__init__()
|
| 1068 |
+
self.Dec_16x16 = DecBlockFirstV2(4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels)
|
| 1069 |
+
for res in range(5, res_log2 + 1):
|
| 1070 |
+
setattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res),
|
| 1071 |
+
DecBlock(res, nf(res - 1), nf(res), activation, style_dim, use_noise, demodulate, img_channels))
|
| 1072 |
+
self.res_log2 = res_log2
|
| 1073 |
+
|
| 1074 |
+
def forward(self, x, ws, gs, E_features, noise_mode='random'):
|
| 1075 |
+
x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
|
| 1076 |
+
for res in range(5, self.res_log2 + 1):
|
| 1077 |
+
block = getattr(self, 'Dec_%dx%d' % (2 ** res, 2 ** res))
|
| 1078 |
+
x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
|
| 1079 |
+
|
| 1080 |
+
return img
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
+
class DecStyleBlock(nn.Module):
|
| 1084 |
+
def __init__(self, res, in_channels, out_channels, activation, style_dim, use_noise, demodulate, img_channels):
|
| 1085 |
+
super().__init__()
|
| 1086 |
+
self.res = res
|
| 1087 |
+
|
| 1088 |
+
self.conv0 = StyleConv(in_channels=in_channels,
|
| 1089 |
+
out_channels=out_channels,
|
| 1090 |
+
style_dim=style_dim,
|
| 1091 |
+
resolution=2 ** res,
|
| 1092 |
+
kernel_size=3,
|
| 1093 |
+
up=2,
|
| 1094 |
+
use_noise=use_noise,
|
| 1095 |
+
activation=activation,
|
| 1096 |
+
demodulate=demodulate,
|
| 1097 |
+
)
|
| 1098 |
+
self.conv1 = StyleConv(in_channels=out_channels,
|
| 1099 |
+
out_channels=out_channels,
|
| 1100 |
+
style_dim=style_dim,
|
| 1101 |
+
resolution=2 ** res,
|
| 1102 |
+
kernel_size=3,
|
| 1103 |
+
use_noise=use_noise,
|
| 1104 |
+
activation=activation,
|
| 1105 |
+
demodulate=demodulate,
|
| 1106 |
+
)
|
| 1107 |
+
self.toRGB = ToRGB(in_channels=out_channels,
|
| 1108 |
+
out_channels=img_channels,
|
| 1109 |
+
style_dim=style_dim,
|
| 1110 |
+
kernel_size=1,
|
| 1111 |
+
demodulate=False,
|
| 1112 |
+
)
|
| 1113 |
+
|
| 1114 |
+
def forward(self, x, img, style, skip, noise_mode='random'):
|
| 1115 |
+
x = self.conv0(x, style, noise_mode=noise_mode)
|
| 1116 |
+
x = x + skip
|
| 1117 |
+
x = self.conv1(x, style, noise_mode=noise_mode)
|
| 1118 |
+
img = self.toRGB(x, style, skip=img)
|
| 1119 |
+
|
| 1120 |
+
return x, img
|
| 1121 |
+
|
| 1122 |
+
|
| 1123 |
+
class FirstStage(nn.Module):
|
| 1124 |
+
def __init__(self, img_channels, img_resolution=256, dim=180, w_dim=512, use_noise=False, demodulate=True,
|
| 1125 |
+
activation='lrelu'):
|
| 1126 |
+
super().__init__()
|
| 1127 |
+
res = 64
|
| 1128 |
+
|
| 1129 |
+
self.conv_first = Conv2dLayerPartial(in_channels=img_channels + 1, out_channels=dim, kernel_size=3,
|
| 1130 |
+
activation=activation)
|
| 1131 |
+
self.enc_conv = nn.ModuleList()
|
| 1132 |
+
down_time = int(np.log2(img_resolution // res))
|
| 1133 |
+
# 根据图片尺寸构建 swim transformer 的层数
|
| 1134 |
+
for i in range(down_time): # from input size to 64
|
| 1135 |
+
self.enc_conv.append(
|
| 1136 |
+
Conv2dLayerPartial(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation)
|
| 1137 |
+
)
|
| 1138 |
+
|
| 1139 |
+
# from 64 -> 16 -> 64
|
| 1140 |
+
depths = [2, 3, 4, 3, 2]
|
| 1141 |
+
ratios = [1, 1 / 2, 1 / 2, 2, 2]
|
| 1142 |
+
num_heads = 6
|
| 1143 |
+
window_sizes = [8, 16, 16, 16, 8]
|
| 1144 |
+
drop_path_rate = 0.1
|
| 1145 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
| 1146 |
+
|
| 1147 |
+
self.tran = nn.ModuleList()
|
| 1148 |
+
for i, depth in enumerate(depths):
|
| 1149 |
+
res = int(res * ratios[i])
|
| 1150 |
+
if ratios[i] < 1:
|
| 1151 |
+
merge = PatchMerging(dim, dim, down=int(1 / ratios[i]))
|
| 1152 |
+
elif ratios[i] > 1:
|
| 1153 |
+
merge = PatchUpsampling(dim, dim, up=ratios[i])
|
| 1154 |
+
else:
|
| 1155 |
+
merge = None
|
| 1156 |
+
self.tran.append(
|
| 1157 |
+
BasicLayer(dim=dim, input_resolution=[res, res], depth=depth, num_heads=num_heads,
|
| 1158 |
+
window_size=window_sizes[i], drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
| 1159 |
+
downsample=merge)
|
| 1160 |
+
)
|
| 1161 |
+
|
| 1162 |
+
# global style
|
| 1163 |
+
down_conv = []
|
| 1164 |
+
for i in range(int(np.log2(16))):
|
| 1165 |
+
down_conv.append(
|
| 1166 |
+
Conv2dLayer(in_channels=dim, out_channels=dim, kernel_size=3, down=2, activation=activation))
|
| 1167 |
+
down_conv.append(nn.AdaptiveAvgPool2d((1, 1)))
|
| 1168 |
+
self.down_conv = nn.Sequential(*down_conv)
|
| 1169 |
+
self.to_style = FullyConnectedLayer(in_features=dim, out_features=dim * 2, activation=activation)
|
| 1170 |
+
self.ws_style = FullyConnectedLayer(in_features=w_dim, out_features=dim, activation=activation)
|
| 1171 |
+
self.to_square = FullyConnectedLayer(in_features=dim, out_features=16 * 16, activation=activation)
|
| 1172 |
+
|
| 1173 |
+
style_dim = dim * 3
|
| 1174 |
+
self.dec_conv = nn.ModuleList()
|
| 1175 |
+
for i in range(down_time): # from 64 to input size
|
| 1176 |
+
res = res * 2
|
| 1177 |
+
self.dec_conv.append(
|
| 1178 |
+
DecStyleBlock(res, dim, dim, activation, style_dim, use_noise, demodulate, img_channels))
|
| 1179 |
+
|
| 1180 |
+
def forward(self, images_in, masks_in, ws, noise_mode='random'):
|
| 1181 |
+
x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1)
|
| 1182 |
+
|
| 1183 |
+
skips = []
|
| 1184 |
+
x, mask = self.conv_first(x, masks_in) # input size
|
| 1185 |
+
skips.append(x)
|
| 1186 |
+
for i, block in enumerate(self.enc_conv): # input size to 64
|
| 1187 |
+
x, mask = block(x, mask)
|
| 1188 |
+
if i != len(self.enc_conv) - 1:
|
| 1189 |
+
skips.append(x)
|
| 1190 |
+
|
| 1191 |
+
x_size = x.size()[-2:]
|
| 1192 |
+
x = feature2token(x)
|
| 1193 |
+
mask = feature2token(mask)
|
| 1194 |
+
mid = len(self.tran) // 2
|
| 1195 |
+
for i, block in enumerate(self.tran): # 64 to 16
|
| 1196 |
+
if i < mid:
|
| 1197 |
+
x, x_size, mask = block(x, x_size, mask)
|
| 1198 |
+
skips.append(x)
|
| 1199 |
+
elif i > mid:
|
| 1200 |
+
x, x_size, mask = block(x, x_size, None)
|
| 1201 |
+
x = x + skips[mid - i]
|
| 1202 |
+
else:
|
| 1203 |
+
x, x_size, mask = block(x, x_size, None)
|
| 1204 |
+
|
| 1205 |
+
mul_map = torch.ones_like(x) * 0.5
|
| 1206 |
+
mul_map = F.dropout(mul_map, training=True)
|
| 1207 |
+
ws = self.ws_style(ws[:, -1])
|
| 1208 |
+
add_n = self.to_square(ws).unsqueeze(1)
|
| 1209 |
+
add_n = F.interpolate(add_n, size=x.size(1), mode='linear', align_corners=False).squeeze(1).unsqueeze(
|
| 1210 |
+
-1)
|
| 1211 |
+
x = x * mul_map + add_n * (1 - mul_map)
|
| 1212 |
+
gs = self.to_style(self.down_conv(token2feature(x, x_size)).flatten(start_dim=1))
|
| 1213 |
+
style = torch.cat([gs, ws], dim=1)
|
| 1214 |
+
|
| 1215 |
+
x = token2feature(x, x_size).contiguous()
|
| 1216 |
+
img = None
|
| 1217 |
+
for i, block in enumerate(self.dec_conv):
|
| 1218 |
+
x, img = block(x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode)
|
| 1219 |
+
|
| 1220 |
+
# ensemble
|
| 1221 |
+
img = img * (1 - masks_in) + images_in * masks_in
|
| 1222 |
+
|
| 1223 |
+
return img
|
| 1224 |
+
|
| 1225 |
+
|
| 1226 |
+
class SynthesisNet(nn.Module):
|
| 1227 |
+
def __init__(self,
|
| 1228 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
| 1229 |
+
img_resolution, # Output image resolution.
|
| 1230 |
+
img_channels=3, # Number of color channels.
|
| 1231 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
| 1232 |
+
channel_decay=1.0,
|
| 1233 |
+
channel_max=512, # Maximum number of channels in any layer.
|
| 1234 |
+
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
|
| 1235 |
+
drop_rate=0.5,
|
| 1236 |
+
use_noise=False,
|
| 1237 |
+
demodulate=True,
|
| 1238 |
+
):
|
| 1239 |
+
super().__init__()
|
| 1240 |
+
resolution_log2 = int(np.log2(img_resolution))
|
| 1241 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
| 1242 |
+
|
| 1243 |
+
self.num_layers = resolution_log2 * 2 - 3 * 2
|
| 1244 |
+
self.img_resolution = img_resolution
|
| 1245 |
+
self.resolution_log2 = resolution_log2
|
| 1246 |
+
|
| 1247 |
+
# first stage
|
| 1248 |
+
self.first_stage = FirstStage(img_channels, img_resolution=img_resolution, w_dim=w_dim, use_noise=False,
|
| 1249 |
+
demodulate=demodulate)
|
| 1250 |
+
|
| 1251 |
+
# second stage
|
| 1252 |
+
self.enc = Encoder(resolution_log2, img_channels, activation, patch_size=5, channels=16)
|
| 1253 |
+
self.to_square = FullyConnectedLayer(in_features=w_dim, out_features=16 * 16, activation=activation)
|
| 1254 |
+
self.to_style = ToStyle(in_channels=nf(4), out_channels=nf(2) * 2, activation=activation, drop_rate=drop_rate)
|
| 1255 |
+
style_dim = w_dim + nf(2) * 2
|
| 1256 |
+
self.dec = Decoder(resolution_log2, activation, style_dim, use_noise, demodulate, img_channels)
|
| 1257 |
+
|
| 1258 |
+
def forward(self, images_in, masks_in, ws, noise_mode='random', return_stg1=False):
|
| 1259 |
+
out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)
|
| 1260 |
+
|
| 1261 |
+
# encoder
|
| 1262 |
+
x = images_in * masks_in + out_stg1 * (1 - masks_in)
|
| 1263 |
+
x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1)
|
| 1264 |
+
E_features = self.enc(x)
|
| 1265 |
+
|
| 1266 |
+
fea_16 = E_features[4]
|
| 1267 |
+
mul_map = torch.ones_like(fea_16) * 0.5
|
| 1268 |
+
mul_map = F.dropout(mul_map, training=True)
|
| 1269 |
+
add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1)
|
| 1270 |
+
add_n = F.interpolate(add_n, size=fea_16.size()[-2:], mode='bilinear', align_corners=False)
|
| 1271 |
+
fea_16 = fea_16 * mul_map + add_n * (1 - mul_map)
|
| 1272 |
+
E_features[4] = fea_16
|
| 1273 |
+
|
| 1274 |
+
# style
|
| 1275 |
+
gs = self.to_style(fea_16)
|
| 1276 |
+
|
| 1277 |
+
# decoder
|
| 1278 |
+
img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode)
|
| 1279 |
+
|
| 1280 |
+
# ensemble
|
| 1281 |
+
img = img * (1 - masks_in) + images_in * masks_in
|
| 1282 |
+
|
| 1283 |
+
if not return_stg1:
|
| 1284 |
+
return img
|
| 1285 |
+
else:
|
| 1286 |
+
return img, out_stg1
|
| 1287 |
+
|
| 1288 |
+
|
| 1289 |
+
class Generator(nn.Module):
|
| 1290 |
+
def __init__(self,
|
| 1291 |
+
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
| 1292 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
| 1293 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
| 1294 |
+
img_resolution, # resolution of generated image
|
| 1295 |
+
img_channels, # Number of input color channels.
|
| 1296 |
+
synthesis_kwargs={}, # Arguments for SynthesisNetwork.
|
| 1297 |
+
mapping_kwargs={}, # Arguments for MappingNetwork.
|
| 1298 |
+
):
|
| 1299 |
+
super().__init__()
|
| 1300 |
+
self.z_dim = z_dim
|
| 1301 |
+
self.c_dim = c_dim
|
| 1302 |
+
self.w_dim = w_dim
|
| 1303 |
+
self.img_resolution = img_resolution
|
| 1304 |
+
self.img_channels = img_channels
|
| 1305 |
+
|
| 1306 |
+
self.synthesis = SynthesisNet(w_dim=w_dim,
|
| 1307 |
+
img_resolution=img_resolution,
|
| 1308 |
+
img_channels=img_channels,
|
| 1309 |
+
**synthesis_kwargs)
|
| 1310 |
+
self.mapping = MappingNet(z_dim=z_dim,
|
| 1311 |
+
c_dim=c_dim,
|
| 1312 |
+
w_dim=w_dim,
|
| 1313 |
+
num_ws=self.synthesis.num_layers,
|
| 1314 |
+
**mapping_kwargs)
|
| 1315 |
+
|
| 1316 |
+
def forward(self, images_in, masks_in, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False,
|
| 1317 |
+
noise_mode='none', return_stg1=False):
|
| 1318 |
+
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff,
|
| 1319 |
+
skip_w_avg_update=skip_w_avg_update)
|
| 1320 |
+
img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode)
|
| 1321 |
+
return img
|
| 1322 |
+
|
| 1323 |
+
|
| 1324 |
+
class Discriminator(torch.nn.Module):
|
| 1325 |
+
def __init__(self,
|
| 1326 |
+
c_dim, # Conditioning label (C) dimensionality.
|
| 1327 |
+
img_resolution, # Input resolution.
|
| 1328 |
+
img_channels, # Number of input color channels.
|
| 1329 |
+
channel_base=32768, # Overall multiplier for the number of channels.
|
| 1330 |
+
channel_max=512, # Maximum number of channels in any layer.
|
| 1331 |
+
channel_decay=1,
|
| 1332 |
+
cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
|
| 1333 |
+
activation='lrelu',
|
| 1334 |
+
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
| 1335 |
+
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
| 1336 |
+
):
|
| 1337 |
+
super().__init__()
|
| 1338 |
+
self.c_dim = c_dim
|
| 1339 |
+
self.img_resolution = img_resolution
|
| 1340 |
+
self.img_channels = img_channels
|
| 1341 |
+
|
| 1342 |
+
resolution_log2 = int(np.log2(img_resolution))
|
| 1343 |
+
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
| 1344 |
+
self.resolution_log2 = resolution_log2
|
| 1345 |
+
|
| 1346 |
+
if cmap_dim == None:
|
| 1347 |
+
cmap_dim = nf(2)
|
| 1348 |
+
if c_dim == 0:
|
| 1349 |
+
cmap_dim = 0
|
| 1350 |
+
self.cmap_dim = cmap_dim
|
| 1351 |
+
|
| 1352 |
+
if c_dim > 0:
|
| 1353 |
+
self.mapping = MappingNet(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None)
|
| 1354 |
+
|
| 1355 |
+
Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
|
| 1356 |
+
for res in range(resolution_log2, 2, -1):
|
| 1357 |
+
Dis.append(DisBlock(nf(res), nf(res - 1), activation))
|
| 1358 |
+
|
| 1359 |
+
if mbstd_num_channels > 0:
|
| 1360 |
+
Dis.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
|
| 1361 |
+
Dis.append(Conv2dLayer(nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation))
|
| 1362 |
+
self.Dis = nn.Sequential(*Dis)
|
| 1363 |
+
|
| 1364 |
+
self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
|
| 1365 |
+
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
|
| 1366 |
+
|
| 1367 |
+
# for 64x64
|
| 1368 |
+
Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)]
|
| 1369 |
+
for res in range(resolution_log2, 2, -1):
|
| 1370 |
+
Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation))
|
| 1371 |
+
|
| 1372 |
+
if mbstd_num_channels > 0:
|
| 1373 |
+
Dis_stg1.append(MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels))
|
| 1374 |
+
Dis_stg1.append(Conv2dLayer(nf(2) // 2 + mbstd_num_channels, nf(2) // 2, kernel_size=3, activation=activation))
|
| 1375 |
+
self.Dis_stg1 = nn.Sequential(*Dis_stg1)
|
| 1376 |
+
|
| 1377 |
+
self.fc0_stg1 = FullyConnectedLayer(nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation)
|
| 1378 |
+
self.fc1_stg1 = FullyConnectedLayer(nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim)
|
| 1379 |
+
|
| 1380 |
+
def forward(self, images_in, masks_in, images_stg1, c):
|
| 1381 |
+
x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1))
|
| 1382 |
+
x = self.fc1(self.fc0(x.flatten(start_dim=1)))
|
| 1383 |
+
|
| 1384 |
+
x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1))
|
| 1385 |
+
x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1)))
|
| 1386 |
+
|
| 1387 |
+
if self.c_dim > 0:
|
| 1388 |
+
cmap = self.mapping(None, c)
|
| 1389 |
+
|
| 1390 |
+
if self.cmap_dim > 0:
|
| 1391 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
| 1392 |
+
x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
| 1393 |
+
|
| 1394 |
+
return x, x_stg1
|
| 1395 |
+
|
| 1396 |
+
|
| 1397 |
+
MAT_MODEL_URL = os.environ.get(
|
| 1398 |
+
"MAT_MODEL_URL",
|
| 1399 |
+
"https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth",
|
| 1400 |
+
)
|
| 1401 |
+
|
| 1402 |
+
|
| 1403 |
+
class MAT(InpaintModel):
|
| 1404 |
+
min_size = 512
|
| 1405 |
+
pad_mod = 512
|
| 1406 |
+
pad_to_square = True
|
| 1407 |
+
|
| 1408 |
+
def init_model(self, device, **kwargs):
|
| 1409 |
+
seed = 240 # pick up a random number
|
| 1410 |
+
random.seed(seed)
|
| 1411 |
+
np.random.seed(seed)
|
| 1412 |
+
torch.manual_seed(seed)
|
| 1413 |
+
|
| 1414 |
+
G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3)
|
| 1415 |
+
self.model = load_model(G, MAT_MODEL_URL, device)
|
| 1416 |
+
self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device) # [1., 512]
|
| 1417 |
+
self.label = torch.zeros([1, self.model.c_dim], device=device)
|
| 1418 |
+
|
| 1419 |
+
@staticmethod
|
| 1420 |
+
def is_downloaded() -> bool:
|
| 1421 |
+
return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
|
| 1422 |
+
|
| 1423 |
+
def forward(self, image, mask, config: Config):
|
| 1424 |
+
"""Input images and output images have same size
|
| 1425 |
+
images: [H, W, C] RGB
|
| 1426 |
+
masks: [H, W] mask area == 255
|
| 1427 |
+
return: BGR IMAGE
|
| 1428 |
+
"""
|
| 1429 |
+
|
| 1430 |
+
image = norm_img(image) # [0, 1]
|
| 1431 |
+
image = image * 2 - 1 # [0, 1] -> [-1, 1]
|
| 1432 |
+
|
| 1433 |
+
mask = (mask > 127) * 255
|
| 1434 |
+
mask = 255 - mask
|
| 1435 |
+
mask = norm_img(mask)
|
| 1436 |
+
|
| 1437 |
+
image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
| 1438 |
+
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
| 1439 |
+
|
| 1440 |
+
output = self.model(image, mask, self.z, self.label, truncation_psi=1, noise_mode='none')
|
| 1441 |
+
output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
|
| 1442 |
+
output = output[0].cpu().numpy()
|
| 1443 |
+
cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
| 1444 |
+
return cur_res
|
lama_cleaner/model/opencv2.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
|
| 3 |
+
from lama_cleaner.model.base import InpaintModel
|
| 4 |
+
from lama_cleaner.schema import Config
|
| 5 |
+
|
| 6 |
+
flag_map = {
|
| 7 |
+
"INPAINT_NS": cv2.INPAINT_NS,
|
| 8 |
+
"INPAINT_TELEA": cv2.INPAINT_TELEA
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
class OpenCV2(InpaintModel):
|
| 12 |
+
pad_mod = 1
|
| 13 |
+
|
| 14 |
+
@staticmethod
|
| 15 |
+
def is_downloaded() -> bool:
|
| 16 |
+
return True
|
| 17 |
+
|
| 18 |
+
def forward(self, image, mask, config: Config):
|
| 19 |
+
"""Input image and output image have same size
|
| 20 |
+
image: [H, W, C] RGB
|
| 21 |
+
mask: [H, W, 1]
|
| 22 |
+
return: BGR IMAGE
|
| 23 |
+
"""
|
| 24 |
+
cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag])
|
| 25 |
+
return cur_res
|