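"""Inference utilities for the Siamese CLIP similarity model.

Loads a trained SiameseCLIPModel checkpoint and computes cosine similarity
between pairs of images (optionally combined with text descriptions), with
helpers for batch comparison and visualization of the most similar results.
"""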
import torch
import torch.nn.functional as F
from PIL import Image
import os
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from transformers import CLIPProcessor

from model import SiameseCLIPModel
from config import Config


class SimilarityInference:
    def __init__(self, checkpoint_path, device=None):
        """
        Initialize the model for inference.

        Args:
            checkpoint_path: Path to the trained model checkpoint
            device: Device to run inference on ("cuda" or "cpu"); auto-detected if None
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"Using device: {self.device}")

        self.config = Config

        # Build the model and load the trained weights
        self.model = SiameseCLIPModel(self.config)
        self._load_checkpoint(checkpoint_path)

        self.model.to(self.device)
        self.model.eval()

        # The CLIP processor handles both image and text preprocessing
        self.processor = CLIPProcessor.from_pretrained(self.config.CLIP_MODEL_NAME)

    def _load_checkpoint(self, checkpoint_path):
        """
        Load model weights from a checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint file
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        if 'model_state_dict' in checkpoint:
            # Full training checkpoint: weights plus training metadata
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Model loaded from {checkpoint_path}")
            print(f"Epoch: {checkpoint.get('epoch', 'N/A')}")
            print(f"Best validation loss: {checkpoint.get('best_val_loss', 'N/A')}")
        else:
            # Plain state_dict checkpoint
            self.model.load_state_dict(checkpoint)
            print(f"Model loaded from {checkpoint_path} (plain state_dict)")

    def preprocess_image(self, image_path):
        """
        Preprocess an image for inference.

        Args:
            image_path: Path to the image

        Returns:
            Processed pixel-value tensor on the inference device
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")

        image = Image.open(image_path).convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt")

        return inputs.pixel_values.to(self.device)

    def preprocess_text(self, text):
        """
        Preprocess text for inference.

        Args:
            text: Text to process

        Returns:
            Dict with input_ids and attention_mask tensors
        """
        inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True)

        return {
            'input_ids': inputs.input_ids.to(self.device),
            'attention_mask': inputs.attention_mask.to(self.device)
        }

    def calculate_similarity(self, image1_path, image2_path, text1=None, text2=None):
        """
        Compute the similarity between two images (and optionally their texts).

        Args:
            image1_path: Path to the first image
            image2_path: Path to the second image
            text1: Description of the first image (optional)
            text2: Description of the second image (optional)

        Returns:
            Cosine similarity between the embeddings, in [-1, 1]
        """
        image1 = self.preprocess_image(image1_path)
        image2 = self.preprocess_image(image2_path)

        # Text is only used when both descriptions are given and the config enables it
        text1_inputs = None
        text2_inputs = None
        if text1 is not None and text2 is not None and self.config.USE_TEXT_EMBEDDINGS:
            text1_inputs = self.preprocess_text(text1)
            text2_inputs = self.preprocess_text(text2)

        with torch.no_grad():
            if text1_inputs is not None and text2_inputs is not None:
                similarity = self.model.calculate_similarity(
                    image1_pixel_values=image1,
                    image2_pixel_values=image2,
                    text1_input_ids=text1_inputs['input_ids'],
                    text2_input_ids=text2_inputs['input_ids'],
                    text1_attention_mask=text1_inputs['attention_mask'],
                    text2_attention_mask=text2_inputs['attention_mask']
                )
            else:
                similarity = self.model.calculate_similarity(
                    image1_pixel_values=image1,
                    image2_pixel_values=image2
                )

        return similarity.item()

    def calculate_batch_similarities(self, reference_image, comparison_images, reference_text=None, comparison_texts=None):
        """
        Compute similarities between a reference image and multiple comparison images.

        Args:
            reference_image: Path to the reference image
            comparison_images: List of paths to images to compare against
            reference_text: Description of the reference image (optional)
            comparison_texts: List of descriptions for the comparison images (optional)

        Returns:
            List of results sorted by similarity, highest first
        """
        ref_image = self.preprocess_image(reference_image)

        ref_text_inputs = None
        if reference_text is not None and self.config.USE_TEXT_EMBEDDINGS:
            ref_text_inputs = self.preprocess_text(reference_text)

        results = []

        for i, comp_image_path in enumerate(comparison_images):
            try:
                comp_image = self.preprocess_image(comp_image_path)

                comp_text_inputs = None
                if comparison_texts is not None and i < len(comparison_texts) and self.config.USE_TEXT_EMBEDDINGS:
                    comp_text_inputs = self.preprocess_text(comparison_texts[i])

                with torch.no_grad():
                    if ref_text_inputs is not None and comp_text_inputs is not None:
                        similarity = self.model.calculate_similarity(
                            image1_pixel_values=ref_image,
                            image2_pixel_values=comp_image,
                            text1_input_ids=ref_text_inputs['input_ids'],
                            text2_input_ids=comp_text_inputs['input_ids'],
                            text1_attention_mask=ref_text_inputs['attention_mask'],
                            text2_attention_mask=comp_text_inputs['attention_mask']
                        )
                    else:
                        similarity = self.model.calculate_similarity(
                            image1_pixel_values=ref_image,
                            image2_pixel_values=comp_image
                        )

                results.append({
                    'image_path': comp_image_path,
                    'similarity': similarity.item()
                })
            except Exception as e:
                # Skip images that fail to load or process, but report them
                print(f"Error processing {comp_image_path}: {e}")

        # Sort from most to least similar
        results.sort(key=lambda x: x['similarity'], reverse=True)

        return results

    def visualize_similarities(self, reference_image, comparison_results, num_images=5, figsize=(15, 10)):
        """
        Visualize the reference image alongside the most similar comparison images.

        Args:
            reference_image: Path to the reference image
            comparison_results: Results from calculate_batch_similarities
            num_images: Number of similar images to display
            figsize: Figure size
        """
        num_images = min(num_images, len(comparison_results))

        fig, axes = plt.subplots(1, num_images + 1, figsize=figsize)

        # Reference image in the first panel
        ref_img = Image.open(reference_image).convert('RGB')
        axes[0].imshow(ref_img)
        axes[0].set_title("Reference image")
        axes[0].axis('off')

        # Most similar images, in descending order of similarity
        for i in range(num_images):
            img_path = comparison_results[i]['image_path']
            similarity = comparison_results[i]['similarity']

            img = Image.open(img_path).convert('RGB')
            axes[i + 1].imshow(img)
            axes[i + 1].set_title(f"Sim: {similarity:.4f}")
            axes[i + 1].axis('off')

        plt.tight_layout()
        plt.show()
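

# Example usage (a minimal sketch; the checkpoint and image paths below are
# hypothetical placeholders, not files shipped with this project):
#
#   inference = SimilarityInference("checkpoints/best_model.pt")
#
#   # Single pair, optionally with text descriptions
#   score = inference.calculate_similarity("query.jpg", "candidate.jpg")
#   print(f"Cosine similarity: {score:.4f}")
#
#   # Rank a set of candidates against one reference and plot the top matches
#   results = inference.calculate_batch_similarities("query.jpg", ["a.jpg", "b.jpg", "c.jpg"])
#   inference.visualize_similarities("query.jpg", results, num_images=3)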