ArtSimilarity / inference.py
Dant33's picture
Upload inference.py with huggingface_hub
0b18c05 verified
import torch
import torch.nn.functional as F
from PIL import Image
import os
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from transformers import CLIPProcessor
from model import SiameseCLIPModel
from config import Config
class SimilarityInference:
def __init__(self, checkpoint_path, device=None):
"""
Inicializa el modelo para inferencia.
Args:
checkpoint_path: Ruta al checkpoint del modelo entrenado
device: Dispositivo para inferencia (cuda o cpu)
"""
# Configurar dispositivo
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)
print(f"Utilizando dispositivo: {self.device}")
# Cargar configuraci贸n
self.config = Config
# Inicializar modelo
self.model = SiameseCLIPModel(self.config)
# Cargar checkpoint
self._load_checkpoint(checkpoint_path)
# Mover modelo al dispositivo
self.model.to(self.device)
self.model.eval()
# Inicializar procesador de CLIP
self.processor = CLIPProcessor.from_pretrained(self.config.CLIP_MODEL_NAME)
def _load_checkpoint(self, checkpoint_path):
"""
Carga los pesos del modelo desde un checkpoint.
Args:
checkpoint_path: Ruta al checkpoint
"""
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"No se encontr贸 el checkpoint en {checkpoint_path}")
# Cargar checkpoint
checkpoint = torch.load(checkpoint_path, map_location=self.device)
# Cargar estado del modelo
if 'model_state_dict' in checkpoint:
self.model.load_state_dict(checkpoint['model_state_dict'])
print(f"Modelo cargado desde {checkpoint_path}")
print(f"脡poca: {checkpoint.get('epoch', 'N/A')}")
print(f"Mejor p茅rdida de validaci贸n: {checkpoint.get('best_val_loss', 'N/A')}")
else:
# Si el checkpoint solo contiene los pesos del modelo
self.model.load_state_dict(checkpoint)
print(f"Modelo cargado desde {checkpoint_path} (formato simple)")
def preprocess_image(self, image_path):
"""
Preprocesa una imagen para inferencia.
Args:
image_path: Ruta a la imagen
Returns:
Tensor de la imagen procesada
"""
if not os.path.exists(image_path):
raise FileNotFoundError(f"No se encontr贸 la imagen en {image_path}")
# Cargar imagen
image = Image.open(image_path).convert('RGB')
# Procesar imagen con el procesador de CLIP
inputs = self.processor(images=image, return_tensors="pt")
return inputs.pixel_values.to(self.device)
def preprocess_text(self, text):
"""
Preprocesa un texto para inferencia.
Args:
text: Texto a procesar
Returns:
Tensores de input_ids y attention_mask
"""
# Procesar texto con el procesador de CLIP
inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True)
return {
'input_ids': inputs.input_ids.to(self.device),
'attention_mask': inputs.attention_mask.to(self.device)
}
def calculate_similarity(self, image1_path, image2_path, text1=None, text2=None):
"""
Calcula la similitud entre dos im谩genes (y opcionalmente sus textos).
Args:
image1_path: Ruta a la primera imagen
image2_path: Ruta a la segunda imagen
text1: Descripci贸n de la primera imagen (opcional)
text2: Descripci贸n de la segunda imagen (opcional)
Returns:
Similitud del coseno entre los embeddings [-1, 1]
"""
# Preprocesar im谩genes
image1 = self.preprocess_image(image1_path)
image2 = self.preprocess_image(image2_path)
# Preprocesar textos si se proporcionan
text1_inputs = None
text2_inputs = None
if text1 is not None and text2 is not None and self.config.USE_TEXT_EMBEDDINGS:
text1_inputs = self.preprocess_text(text1)
text2_inputs = self.preprocess_text(text2)
# Calcular similitud
with torch.no_grad():
if text1_inputs is not None and text2_inputs is not None:
similarity = self.model.calculate_similarity(
image1_pixel_values=image1,
image2_pixel_values=image2,
text1_input_ids=text1_inputs['input_ids'],
text2_input_ids=text2_inputs['input_ids'],
text1_attention_mask=text1_inputs['attention_mask'],
text2_attention_mask=text2_inputs['attention_mask']
)
else:
similarity = self.model.calculate_similarity(
image1_pixel_values=image1,
image2_pixel_values=image2
)
return similarity.item()
def calculate_batch_similarities(self, reference_image, comparison_images, reference_text=None, comparison_texts=None):
"""
Calcula similitudes entre una imagen de referencia y m煤ltiples im谩genes de comparaci贸n.
Args:
reference_image: Ruta a la imagen de referencia
comparison_images: Lista de rutas a im谩genes para comparar
reference_text: Descripci贸n de la imagen de referencia (opcional)
comparison_texts: Lista de descripciones para las im谩genes de comparaci贸n (opcional)
Returns:
Lista de similitudes ordenadas de mayor a menor
"""
# Preprocesar imagen de referencia
ref_image = self.preprocess_image(reference_image)
# Preprocesar texto de referencia si se proporciona
ref_text_inputs = None
if reference_text is not None and self.config.USE_TEXT_EMBEDDINGS:
ref_text_inputs = self.preprocess_text(reference_text)
results = []
# Calcular similitud para cada imagen de comparaci贸n
for i, comp_image_path in enumerate(comparison_images):
try:
# Preprocesar imagen de comparaci贸n
comp_image = self.preprocess_image(comp_image_path)
# Preprocesar texto de comparaci贸n si se proporciona
comp_text_inputs = None
if comparison_texts is not None and i < len(comparison_texts) and self.config.USE_TEXT_EMBEDDINGS:
comp_text_inputs = self.preprocess_text(comparison_texts[i])
# Calcular similitud
with torch.no_grad():
if ref_text_inputs is not None and comp_text_inputs is not None:
similarity = self.model.calculate_similarity(
image1_pixel_values=ref_image,
image2_pixel_values=comp_image,
text1_input_ids=ref_text_inputs['input_ids'],
text2_input_ids=comp_text_inputs['input_ids'],
text1_attention_mask=ref_text_inputs['attention_mask'],
text2_attention_mask=comp_text_inputs['attention_mask']
)
else:
similarity = self.model.calculate_similarity(
image1_pixel_values=ref_image,
image2_pixel_values=comp_image
)
results.append({
'image_path': comp_image_path,
'similarity': similarity.item()
})
except Exception as e:
print(f"Error al procesar {comp_image_path}: {e}")
# Ordenar resultados por similitud (de mayor a menor)
results.sort(key=lambda x: x['similarity'], reverse=True)
return results
def visualize_similarities(self, reference_image, comparison_results, num_images=5, figsize=(15, 10)):
"""
Visualiza las similitudes entre una imagen de referencia y las im谩genes m谩s similares.
Args:
reference_image: Ruta a la imagen de referencia
comparison_results: Resultados de calculate_batch_similarities
num_images: N煤mero de im谩genes similares a mostrar
figsize: Tama帽o de la figura
"""
# Limitar el n煤mero de im谩genes a mostrar
num_images = min(num_images, len(comparison_results))
# Crear figura
fig, axes = plt.subplots(1, num_images + 1, figsize=figsize)
# Mostrar imagen de referencia
ref_img = Image.open(reference_image).convert('RGB')
axes[0].imshow(ref_img)
axes[0].set_title("Imagen de referencia")
axes[0].axis('off')
# Mostrar im谩genes similares
for i in range(num_images):
img_path = comparison_results[i]['image_path']
similarity = comparison_results[i]['similarity']
img = Image.open(img_path).convert('RGB')
axes[i+1].imshow(img)
axes[i+1].set_title(f"Sim: {similarity:.4f}")
axes[i+1].axis('off')
plt.tight_layout()
plt.show()