Spaces:
Sleeping
Sleeping
Commit
·
5cf7dee
1
Parent(s):
089f816
Update main.py: Refactor code for CPU compatibility
Browse files
main.py
CHANGED
|
@@ -5,12 +5,12 @@ from starlette.middleware.cors import CORSMiddleware
|
|
| 5 |
|
| 6 |
from PIL import Image
|
| 7 |
from io import BytesIO
|
| 8 |
-
from transformers import CLIPFeatureExtractor
|
| 9 |
from diffusers import (
|
| 10 |
AutoPipelineForText2Image,
|
| 11 |
AutoPipelineForImage2Image,
|
| 12 |
AutoPipelineForInpainting,
|
| 13 |
)
|
|
|
|
| 14 |
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
| 15 |
|
| 16 |
|
|
@@ -18,11 +18,11 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
|
| 18 |
async def lifespan(app: FastAPI):
|
| 19 |
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
| 20 |
"openai/clip-vit-base-patch32"
|
| 21 |
-
)
|
| 22 |
|
| 23 |
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
| 24 |
"CompVis/stable-diffusion-safety-checker"
|
| 25 |
-
)
|
| 26 |
|
| 27 |
text2img = AutoPipelineForText2Image.from_pretrained(
|
| 28 |
"stabilityai/sd-turbo",
|
|
@@ -39,6 +39,7 @@ async def lifespan(app: FastAPI):
|
|
| 39 |
del inpaint
|
| 40 |
del img2img
|
| 41 |
del text2img
|
|
|
|
| 42 |
del safety_checker
|
| 43 |
del feature_extractor
|
| 44 |
|
|
@@ -68,7 +69,9 @@ async def text_to_image(
|
|
| 68 |
num_inference_steps: int = Form(1),
|
| 69 |
):
|
| 70 |
results = request.state.text2img(
|
| 71 |
-
prompt=prompt,
|
|
|
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
if not results.nsfw_content_detected[0]:
|
|
|
|
| 5 |
|
| 6 |
from PIL import Image
|
| 7 |
from io import BytesIO
|
|
|
|
| 8 |
from diffusers import (
|
| 9 |
AutoPipelineForText2Image,
|
| 10 |
AutoPipelineForImage2Image,
|
| 11 |
AutoPipelineForInpainting,
|
| 12 |
)
|
| 13 |
+
from transformers import CLIPFeatureExtractor
|
| 14 |
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
| 15 |
|
| 16 |
|
|
|
|
| 18 |
async def lifespan(app: FastAPI):
|
| 19 |
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
| 20 |
"openai/clip-vit-base-patch32"
|
| 21 |
+
).to("cpu")
|
| 22 |
|
| 23 |
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
| 24 |
"CompVis/stable-diffusion-safety-checker"
|
| 25 |
+
).to("cpu")
|
| 26 |
|
| 27 |
text2img = AutoPipelineForText2Image.from_pretrained(
|
| 28 |
"stabilityai/sd-turbo",
|
|
|
|
| 39 |
del inpaint
|
| 40 |
del img2img
|
| 41 |
del text2img
|
| 42 |
+
|
| 43 |
del safety_checker
|
| 44 |
del feature_extractor
|
| 45 |
|
|
|
|
| 69 |
num_inference_steps: int = Form(1),
|
| 70 |
):
|
| 71 |
results = request.state.text2img(
|
| 72 |
+
prompt=prompt,
|
| 73 |
+
num_inference_steps=num_inference_steps,
|
| 74 |
+
guidance_scale=0.0,
|
| 75 |
)
|
| 76 |
|
| 77 |
if not results.nsfw_content_detected[0]:
|