# BATUTO_imagen / app.py
# Author: ivanoctaviogaitansantos — commit 8c0cd05 (verified), 4.79 kB
# NOTE: the lines above are Hugging Face Hub page metadata ("raw / history /
# blame" chrome) captured in a copy-paste; kept here as comments so the file
# remains valid Python.
import gradio as gr
import torch
import gc
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AIImageGeneratorNSFW:
    """Lazy-loading wrapper around a Stable Diffusion XL text-to-image pipeline.

    The model weights are only downloaded/loaded on the first generation
    request (or an explicit ``load_model()`` call), so constructing this
    class is cheap.
    """

    def __init__(self):
        # Pipeline is created lazily in load_model().
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Hugging Face Hub repository id of the SDXL checkpoint to load.
        self.model_id = "segmind/Segmind-DE-XL"
        self.is_model_loaded = False
        logger.info(f"Inicializando en dispositivo: {self.device}")

    def load_model(self):
        """Load the SDXL pipeline once. Returns True on success, False on failure."""
        if self.is_model_loaded:
            return True
        try:
            logger.info("Cargando modelo NSFW...")
            # fp16 halves GPU memory usage; CPU inference requires fp32.
            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
            tokenizer_1 = CLIPTokenizer.from_pretrained(self.model_id, subfolder="tokenizer", use_fast=False)
            tokenizer_2 = CLIPTokenizer.from_pretrained(self.model_id, subfolder="tokenizer_2", use_fast=False)
            text_encoder_1 = CLIPTextModel.from_pretrained(self.model_id, subfolder="text_encoder", torch_dtype=torch_dtype, low_cpu_mem_usage=True)
            text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(self.model_id, subfolder="text_encoder_2", torch_dtype=torch_dtype, low_cpu_mem_usage=True)
            # BUG FIX: StableDiffusionXLPipeline takes its two tokenizers and
            # text encoders as SEPARATE keyword arguments (tokenizer /
            # tokenizer_2 and text_encoder / text_encoder_2) — passing lists,
            # as the original code did, fails at load time. Also dropped
            # safety_checker=None: the SDXL pipeline has no such component
            # and the kwarg is unexpected.
            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                self.model_id,
                tokenizer=tokenizer_1,
                tokenizer_2=tokenizer_2,
                text_encoder=text_encoder_1,
                text_encoder_2=text_encoder_2,
                torch_dtype=torch_dtype,
                scheduler=EulerDiscreteScheduler.from_pretrained(self.model_id, subfolder="scheduler"),
                use_safetensors=True,
                variant="fp16" if self.device == "cuda" else None,
            )
            self.pipeline.to(self.device)
            self.is_model_loaded = True
            logger.info("Modelo NSFW cargado correctamente.")
            return True
        except Exception as e:
            logger.error(f"Error cargando modelo NSFW: {e}")
            return False

    def generate_image(self, prompt, width=1024, height=576, steps=35, guidance_scale=12.0):
        """Generate one image from *prompt*.

        Returns a PIL image, or None if the model failed to load or
        generation raised.
        """
        if not self.is_model_loaded and not self.load_model():
            return None
        try:
            with torch.inference_mode():
                # Fresh random seed per call so repeated prompts vary.
                seed = torch.randint(0, 2**32, (1,)).item()
                generator = torch.Generator(self.device).manual_seed(seed)
                result = self.pipeline(
                    prompt=prompt,
                    # SDXL requires dimensions divisible by 8.
                    width=(width // 8) * 8,
                    height=(height // 8) * 8,
                    num_inference_steps=steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                    output_type="pil",
                )
            return result.images[0]
        except Exception as e:
            logger.error(f"Error generando imagen NSFW: {e}")
            return None
        finally:
            # Reclaim memory whether generation succeeded or failed
            # (the original duplicated gc.collect() on both paths).
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()
# Process-wide singleton so the heavy model wrapper is built at most once.
generator_nsfw = None


def initialize_generator_nsfw():
    """Return the shared AIImageGeneratorNSFW, creating it on first use."""
    global generator_nsfw
    if generator_nsfw is not None:
        return generator_nsfw
    generator_nsfw = AIImageGeneratorNSFW()
    return generator_nsfw
def generate_image_nsfw(prompt, width, height, steps, guidance_scale):
    """Gradio callback: validate inputs and delegate to the shared generator.

    Returns a PIL image, or None for a missing/blank prompt (or a
    generation failure inside the generator).
    """
    # Validate BEFORE touching the expensive generator singleton, and
    # tolerate None — Gradio passes None for a cleared textbox, which made
    # the original `prompt.strip()` raise AttributeError.
    if not prompt or not prompt.strip():
        return None
    gen = initialize_generator_nsfw()
    return gen.generate_image(
        prompt=prompt,
        width=int(width),
        height=int(height),
        steps=int(steps),
        guidance_scale=float(guidance_scale),
    )
def create_nsfw_interface():
    """Build and return the Gradio Blocks UI wired to generate_image_nsfw."""
    with gr.Blocks(title="Generador de Imágenes NSFW con IA - Stable Diffusion") as demo:
        gr.Markdown("# 🎨 Generador NSFW basado en Stable Diffusion\n_Uso responsable y solo para adultos_")
        prompt_box = gr.Textbox(label="Prompt para la imagen NSFW", placeholder="Describe el contenido explícito...", lines=3)
        width_slider = gr.Slider(512, 1536, value=1024, step=8, label="Ancho (pixeles)")
        height_slider = gr.Slider(512, 1536, value=576, step=8, label="Alto (pixeles)")
        steps_slider = gr.Slider(10, 50, value=35, step=1, label="Pasos de inferencia")
        guidance_slider = gr.Slider(1.0, 20.0, value=12.0, step=0.1, label="Escala de guía")
        generate_btn = gr.Button("Generar Imagen NSFW")
        output_image = gr.Image(label="Imagen generada")
        # fn, inputs, outputs are Button.click's first three positional args.
        generate_btn.click(
            generate_image_nsfw,
            [prompt_box, width_slider, height_slider, steps_slider, guidance_slider],
            output_image,
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    create_nsfw_interface().launch()