File size: 9,833 Bytes
0b18c05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
import torch
import torch.nn.functional as F
from PIL import Image
import os
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from transformers import CLIPProcessor
from model import SiameseCLIPModel
from config import Config
class SimilarityInference:
def __init__(self, checkpoint_path, device=None):
"""
Inicializa el modelo para inferencia.
Args:
checkpoint_path: Ruta al checkpoint del modelo entrenado
device: Dispositivo para inferencia (cuda o cpu)
"""
# Configurar dispositivo
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)
print(f"Utilizando dispositivo: {self.device}")
# Cargar configuraci贸n
self.config = Config
# Inicializar modelo
self.model = SiameseCLIPModel(self.config)
# Cargar checkpoint
self._load_checkpoint(checkpoint_path)
# Mover modelo al dispositivo
self.model.to(self.device)
self.model.eval()
# Inicializar procesador de CLIP
self.processor = CLIPProcessor.from_pretrained(self.config.CLIP_MODEL_NAME)
def _load_checkpoint(self, checkpoint_path):
"""
Carga los pesos del modelo desde un checkpoint.
Args:
checkpoint_path: Ruta al checkpoint
"""
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"No se encontr贸 el checkpoint en {checkpoint_path}")
# Cargar checkpoint
checkpoint = torch.load(checkpoint_path, map_location=self.device)
# Cargar estado del modelo
if 'model_state_dict' in checkpoint:
self.model.load_state_dict(checkpoint['model_state_dict'])
print(f"Modelo cargado desde {checkpoint_path}")
print(f"脡poca: {checkpoint.get('epoch', 'N/A')}")
print(f"Mejor p茅rdida de validaci贸n: {checkpoint.get('best_val_loss', 'N/A')}")
else:
# Si el checkpoint solo contiene los pesos del modelo
self.model.load_state_dict(checkpoint)
print(f"Modelo cargado desde {checkpoint_path} (formato simple)")
def preprocess_image(self, image_path):
"""
Preprocesa una imagen para inferencia.
Args:
image_path: Ruta a la imagen
Returns:
Tensor de la imagen procesada
"""
if not os.path.exists(image_path):
raise FileNotFoundError(f"No se encontr贸 la imagen en {image_path}")
# Cargar imagen
image = Image.open(image_path).convert('RGB')
# Procesar imagen con el procesador de CLIP
inputs = self.processor(images=image, return_tensors="pt")
return inputs.pixel_values.to(self.device)
def preprocess_text(self, text):
"""
Preprocesa un texto para inferencia.
Args:
text: Texto a procesar
Returns:
Tensores de input_ids y attention_mask
"""
# Procesar texto con el procesador de CLIP
inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True)
return {
'input_ids': inputs.input_ids.to(self.device),
'attention_mask': inputs.attention_mask.to(self.device)
}
def calculate_similarity(self, image1_path, image2_path, text1=None, text2=None):
"""
Calcula la similitud entre dos im谩genes (y opcionalmente sus textos).
Args:
image1_path: Ruta a la primera imagen
image2_path: Ruta a la segunda imagen
text1: Descripci贸n de la primera imagen (opcional)
text2: Descripci贸n de la segunda imagen (opcional)
Returns:
Similitud del coseno entre los embeddings [-1, 1]
"""
# Preprocesar im谩genes
image1 = self.preprocess_image(image1_path)
image2 = self.preprocess_image(image2_path)
# Preprocesar textos si se proporcionan
text1_inputs = None
text2_inputs = None
if text1 is not None and text2 is not None and self.config.USE_TEXT_EMBEDDINGS:
text1_inputs = self.preprocess_text(text1)
text2_inputs = self.preprocess_text(text2)
# Calcular similitud
with torch.no_grad():
if text1_inputs is not None and text2_inputs is not None:
similarity = self.model.calculate_similarity(
image1_pixel_values=image1,
image2_pixel_values=image2,
text1_input_ids=text1_inputs['input_ids'],
text2_input_ids=text2_inputs['input_ids'],
text1_attention_mask=text1_inputs['attention_mask'],
text2_attention_mask=text2_inputs['attention_mask']
)
else:
similarity = self.model.calculate_similarity(
image1_pixel_values=image1,
image2_pixel_values=image2
)
return similarity.item()
def calculate_batch_similarities(self, reference_image, comparison_images, reference_text=None, comparison_texts=None):
"""
Calcula similitudes entre una imagen de referencia y m煤ltiples im谩genes de comparaci贸n.
Args:
reference_image: Ruta a la imagen de referencia
comparison_images: Lista de rutas a im谩genes para comparar
reference_text: Descripci贸n de la imagen de referencia (opcional)
comparison_texts: Lista de descripciones para las im谩genes de comparaci贸n (opcional)
Returns:
Lista de similitudes ordenadas de mayor a menor
"""
# Preprocesar imagen de referencia
ref_image = self.preprocess_image(reference_image)
# Preprocesar texto de referencia si se proporciona
ref_text_inputs = None
if reference_text is not None and self.config.USE_TEXT_EMBEDDINGS:
ref_text_inputs = self.preprocess_text(reference_text)
results = []
# Calcular similitud para cada imagen de comparaci贸n
for i, comp_image_path in enumerate(comparison_images):
try:
# Preprocesar imagen de comparaci贸n
comp_image = self.preprocess_image(comp_image_path)
# Preprocesar texto de comparaci贸n si se proporciona
comp_text_inputs = None
if comparison_texts is not None and i < len(comparison_texts) and self.config.USE_TEXT_EMBEDDINGS:
comp_text_inputs = self.preprocess_text(comparison_texts[i])
# Calcular similitud
with torch.no_grad():
if ref_text_inputs is not None and comp_text_inputs is not None:
similarity = self.model.calculate_similarity(
image1_pixel_values=ref_image,
image2_pixel_values=comp_image,
text1_input_ids=ref_text_inputs['input_ids'],
text2_input_ids=comp_text_inputs['input_ids'],
text1_attention_mask=ref_text_inputs['attention_mask'],
text2_attention_mask=comp_text_inputs['attention_mask']
)
else:
similarity = self.model.calculate_similarity(
image1_pixel_values=ref_image,
image2_pixel_values=comp_image
)
results.append({
'image_path': comp_image_path,
'similarity': similarity.item()
})
except Exception as e:
print(f"Error al procesar {comp_image_path}: {e}")
# Ordenar resultados por similitud (de mayor a menor)
results.sort(key=lambda x: x['similarity'], reverse=True)
return results
def visualize_similarities(self, reference_image, comparison_results, num_images=5, figsize=(15, 10)):
"""
Visualiza las similitudes entre una imagen de referencia y las im谩genes m谩s similares.
Args:
reference_image: Ruta a la imagen de referencia
comparison_results: Resultados de calculate_batch_similarities
num_images: N煤mero de im谩genes similares a mostrar
figsize: Tama帽o de la figura
"""
# Limitar el n煤mero de im谩genes a mostrar
num_images = min(num_images, len(comparison_results))
# Crear figura
fig, axes = plt.subplots(1, num_images + 1, figsize=figsize)
# Mostrar imagen de referencia
ref_img = Image.open(reference_image).convert('RGB')
axes[0].imshow(ref_img)
axes[0].set_title("Imagen de referencia")
axes[0].axis('off')
# Mostrar im谩genes similares
for i in range(num_images):
img_path = comparison_results[i]['image_path']
similarity = comparison_results[i]['similarity']
img = Image.open(img_path).convert('RGB')
axes[i+1].imshow(img)
axes[i+1].set_title(f"Sim: {similarity:.4f}")
axes[i+1].axis('off')
plt.tight_layout()
plt.show()
|