|
|
import gradio as gr |
|
|
import torch |
|
|
import gc |
|
|
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler |
|
|
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection |
|
|
import logging |
|
|
|
|
|
# Module-level logger used by every block in this file. basicConfig installs
# a root handler at INFO level (it is a no-op if the root logger already has
# handlers), so this module's INFO messages are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
class AIImageGeneratorNSFW:
    """Lazily load a Stable Diffusion XL pipeline and generate single images.

    Nothing heavy happens in ``__init__``; the pipeline is loaded on the first
    call to :meth:`load_model` (directly or via :meth:`generate_image`) and is
    reused afterwards.

    Attributes:
        pipeline: The loaded ``StableDiffusionXLPipeline``, or ``None`` until
            :meth:`load_model` succeeds.
        device: ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``.
        model_id: Hugging Face repository id all components are loaded from.
        is_model_loaded: ``True`` once the pipeline is loaded and moved to
            ``device``.
    """

    def __init__(self):
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "segmind/Segmind-DE-XL"
        self.is_model_loaded = False
        logger.info("Inicializando en dispositivo: %s", self.device)

    def _free_memory(self):
        """Run the Python GC and, on CUDA, release cached allocator blocks."""
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def load_model(self):
        """Load the SDXL pipeline once and move it to ``self.device``.

        Returns:
            ``True`` if the pipeline is loaded (now or by an earlier call),
            ``False`` if loading failed. Failures are logged with their
            traceback and are not re-raised.
        """
        if self.is_model_loaded:
            return True

        try:
            logger.info("Cargando modelo NSFW...")
            # fp16 only on CUDA; the CPU path runs in fp32.
            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

            tokenizer_1 = CLIPTokenizer.from_pretrained(
                self.model_id, subfolder="tokenizer", use_fast=False
            )
            tokenizer_2 = CLIPTokenizer.from_pretrained(
                self.model_id, subfolder="tokenizer_2", use_fast=False
            )
            text_encoder_1 = CLIPTextModel.from_pretrained(
                self.model_id,
                subfolder="text_encoder",
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
            )
            text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
                self.model_id,
                subfolder="text_encoder_2",
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
            )

            # SDXL has two tokenizer/text-encoder pairs, and each one is its own
            # pipeline component. They must be passed as distinct keyword
            # arguments (``tokenizer_2``/``text_encoder_2``), not as lists.
            # SDXL has no ``safety_checker`` component, so none is passed.
            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                self.model_id,
                tokenizer=tokenizer_1,
                tokenizer_2=tokenizer_2,
                text_encoder=text_encoder_1,
                text_encoder_2=text_encoder_2,
                torch_dtype=torch_dtype,
                scheduler=EulerDiscreteScheduler.from_pretrained(
                    self.model_id, subfolder="scheduler"
                ),
                use_safetensors=True,
                variant="fp16" if self.device == "cuda" else None,
            )
            self.pipeline.to(self.device)
        except Exception as e:
            logger.exception("Error cargando modelo NSFW: %s", e)
            # Drop any half-initialized pipeline so its weights can be freed
            # and the next call retries from scratch.
            self.pipeline = None
            self._free_memory()
            return False

        self.is_model_loaded = True
        logger.info("Modelo NSFW cargado correctamente.")
        return True

    def generate_image(self, prompt, width=1024, height=576, steps=35,
                       guidance_scale=12.0, seed=None):
        """Generate one image for *prompt*, loading the model if needed.

        Args:
            prompt: Text prompt passed to the pipeline.
            width: Requested width in pixels, rounded down to a multiple of 8.
            height: Requested height in pixels, rounded down to a multiple of 8.
            steps: Number of inference steps.
            guidance_scale: Guidance scale passed to the pipeline.
            seed: Optional integer seed for reproducible output. When ``None``,
                a random seed in ``[0, 2**32)`` is drawn. The seed used is
                logged either way.

        Returns:
            The first generated PIL image, or ``None`` if the model could not
            be loaded or generation raised (the error is logged).
        """
        if not self.is_model_loaded and not self.load_model():
            return None

        if seed is None:
            seed = torch.randint(0, 2**32, (1,)).item()
        logger.info("Generando imagen con semilla %d", seed)

        try:
            with torch.inference_mode():
                generator = torch.Generator(self.device).manual_seed(seed)
                result = self.pipeline(
                    prompt=prompt,
                    # Round down to a multiple of 8, as the original code did.
                    width=(width // 8) * 8,
                    height=(height // 8) * 8,
                    num_inference_steps=steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                    output_type="pil",
                )
            return result.images[0]
        except Exception as e:
            logger.exception("Error generando imagen NSFW: %s", e)
            return None
        finally:
            # Release memory after every attempt, successful or not.
            self._free_memory()
|
|
|
|
|
|
|
|
# Process-wide generator instance; stays None until
# initialize_generator_nsfw() creates it on first use.
generator_nsfw: "AIImageGeneratorNSFW | None" = None
|
|
|
|
|
def initialize_generator_nsfw():
    """Return the shared ``AIImageGeneratorNSFW``, creating it on first use.

    Only the wrapper object is constructed here; its constructor does not load
    the model, which happens later on demand.
    """
    global generator_nsfw
    if generator_nsfw is not None:
        return generator_nsfw
    generator_nsfw = AIImageGeneratorNSFW()
    return generator_nsfw
|
|
|
|
|
def generate_image_nsfw(prompt, width, height, steps, guidance_scale):
    """Gradio click handler: validate UI inputs and generate one image.

    Args:
        prompt: Text from the prompt box. ``None`` is tolerated (for example,
            if the endpoint is called without a value) and treated as empty.
        width: Width slider value, coerced to ``int``.
        height: Height slider value, coerced to ``int``.
        steps: Inference-steps slider value, coerced to ``int``.
        guidance_scale: Guidance slider value, coerced to ``float``.

    Returns:
        The generated image, or ``None`` for a missing/blank prompt or when
        generation fails.
    """
    # Reject empty input before creating the (lazily built) generator; the
    # original ``prompt.strip()`` raised AttributeError for a None prompt.
    if prompt is None or not prompt.strip():
        return None

    gen = initialize_generator_nsfw()
    return gen.generate_image(
        prompt=prompt,
        width=int(width),
        height=int(height),
        steps=int(steps),
        guidance_scale=float(guidance_scale),
    )
|
|
|
|
|
def create_nsfw_interface():
    """Build and return the Gradio Blocks UI for the image generator.

    Components are created (and therefore laid out) in this order: header
    markdown, prompt box, width / height / steps / guidance sliders, generate
    button, output image. Clicking the button calls ``generate_image_nsfw``
    with the prompt and the four slider values.
    """
    with gr.Blocks(title="Generador de Imágenes NSFW con IA - Stable Diffusion") as demo:
        gr.Markdown("# 🎨 Generador NSFW basado en Stable Diffusion\n_Uso responsable y solo para adultos_")

        prompt_box = gr.Textbox(
            label="Prompt para la imagen NSFW",
            placeholder="Describe el contenido explícito...",
            lines=3,
        )
        # Dimension sliders step by 8 to match the pipeline's size rounding.
        width_slider = gr.Slider(minimum=512, maximum=1536, value=1024, step=8, label="Ancho (pixeles)")
        height_slider = gr.Slider(minimum=512, maximum=1536, value=576, step=8, label="Alto (pixeles)")
        steps_slider = gr.Slider(minimum=10, maximum=50, value=35, step=1, label="Pasos de inferencia")
        guidance_slider = gr.Slider(minimum=1.0, maximum=20.0, value=12.0, step=0.1, label="Escala de guía")

        generate_button = gr.Button("Generar Imagen NSFW")
        output_image = gr.Image(label="Imagen generada")

        # Input order must match generate_image_nsfw's positional parameters.
        handler_inputs = [prompt_box, width_slider, height_slider, steps_slider, guidance_slider]
        generate_button.click(
            fn=generate_image_nsfw,
            inputs=handler_inputs,
            outputs=output_image,
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and launch the Gradio app.
    create_nsfw_interface().launch()