File size: 9,833 Bytes
0b18c05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import torch
import torch.nn.functional as F
from PIL import Image
import os
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from transformers import CLIPProcessor

from model import SiameseCLIPModel
from config import Config

class SimilarityInference:
    def __init__(self, checkpoint_path, device=None):
        """
        Inicializa el modelo para inferencia.
        
        Args:
            checkpoint_path: Ruta al checkpoint del modelo entrenado
            device: Dispositivo para inferencia (cuda o cpu)
        """
        # Configurar dispositivo
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
            
        print(f"Utilizando dispositivo: {self.device}")
        
        # Cargar configuraci贸n
        self.config = Config
        
        # Inicializar modelo
        self.model = SiameseCLIPModel(self.config)
        
        # Cargar checkpoint
        self._load_checkpoint(checkpoint_path)
        
        # Mover modelo al dispositivo
        self.model.to(self.device)
        self.model.eval()
        
        # Inicializar procesador de CLIP
        self.processor = CLIPProcessor.from_pretrained(self.config.CLIP_MODEL_NAME)
        
    def _load_checkpoint(self, checkpoint_path):
        """
        Carga los pesos del modelo desde un checkpoint.
        
        Args:
            checkpoint_path: Ruta al checkpoint
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"No se encontr贸 el checkpoint en {checkpoint_path}")
        
        # Cargar checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        
        # Cargar estado del modelo
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Modelo cargado desde {checkpoint_path}")
            print(f"脡poca: {checkpoint.get('epoch', 'N/A')}")
            print(f"Mejor p茅rdida de validaci贸n: {checkpoint.get('best_val_loss', 'N/A')}")
        else:
            # Si el checkpoint solo contiene los pesos del modelo
            self.model.load_state_dict(checkpoint)
            print(f"Modelo cargado desde {checkpoint_path} (formato simple)")
    
    def preprocess_image(self, image_path):
        """
        Preprocesa una imagen para inferencia.
        
        Args:
            image_path: Ruta a la imagen
            
        Returns:
            Tensor de la imagen procesada
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"No se encontr贸 la imagen en {image_path}")
        
        # Cargar imagen
        image = Image.open(image_path).convert('RGB')
        
        # Procesar imagen con el procesador de CLIP
        inputs = self.processor(images=image, return_tensors="pt")
        
        return inputs.pixel_values.to(self.device)
    
    def preprocess_text(self, text):
        """
        Preprocesa un texto para inferencia.
        
        Args:
            text: Texto a procesar
            
        Returns:
            Tensores de input_ids y attention_mask
        """
        # Procesar texto con el procesador de CLIP
        inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True)
        
        return {
            'input_ids': inputs.input_ids.to(self.device),
            'attention_mask': inputs.attention_mask.to(self.device)
        }
    
    def calculate_similarity(self, image1_path, image2_path, text1=None, text2=None):
        """
        Calcula la similitud entre dos im谩genes (y opcionalmente sus textos).
        
        Args:
            image1_path: Ruta a la primera imagen
            image2_path: Ruta a la segunda imagen
            text1: Descripci贸n de la primera imagen (opcional)
            text2: Descripci贸n de la segunda imagen (opcional)
            
        Returns:
            Similitud del coseno entre los embeddings [-1, 1]
        """
        # Preprocesar im谩genes
        image1 = self.preprocess_image(image1_path)
        image2 = self.preprocess_image(image2_path)
        
        # Preprocesar textos si se proporcionan
        text1_inputs = None
        text2_inputs = None
        
        if text1 is not None and text2 is not None and self.config.USE_TEXT_EMBEDDINGS:
            text1_inputs = self.preprocess_text(text1)
            text2_inputs = self.preprocess_text(text2)
        
        # Calcular similitud
        with torch.no_grad():
            if text1_inputs is not None and text2_inputs is not None:
                similarity = self.model.calculate_similarity(
                    image1_pixel_values=image1,
                    image2_pixel_values=image2,
                    text1_input_ids=text1_inputs['input_ids'],
                    text2_input_ids=text2_inputs['input_ids'],
                    text1_attention_mask=text1_inputs['attention_mask'],
                    text2_attention_mask=text2_inputs['attention_mask']
                )
            else:
                similarity = self.model.calculate_similarity(
                    image1_pixel_values=image1,
                    image2_pixel_values=image2
                )
        
        return similarity.item()
    
    def calculate_batch_similarities(self, reference_image, comparison_images, reference_text=None, comparison_texts=None):
        """
        Calcula similitudes entre una imagen de referencia y m煤ltiples im谩genes de comparaci贸n.
        
        Args:
            reference_image: Ruta a la imagen de referencia
            comparison_images: Lista de rutas a im谩genes para comparar
            reference_text: Descripci贸n de la imagen de referencia (opcional)
            comparison_texts: Lista de descripciones para las im谩genes de comparaci贸n (opcional)
            
        Returns:
            Lista de similitudes ordenadas de mayor a menor
        """
        # Preprocesar imagen de referencia
        ref_image = self.preprocess_image(reference_image)
        
        # Preprocesar texto de referencia si se proporciona
        ref_text_inputs = None
        if reference_text is not None and self.config.USE_TEXT_EMBEDDINGS:
            ref_text_inputs = self.preprocess_text(reference_text)
        
        results = []
        
        # Calcular similitud para cada imagen de comparaci贸n
        for i, comp_image_path in enumerate(comparison_images):
            try:
                # Preprocesar imagen de comparaci贸n
                comp_image = self.preprocess_image(comp_image_path)
                
                # Preprocesar texto de comparaci贸n si se proporciona
                comp_text_inputs = None
                if comparison_texts is not None and i < len(comparison_texts) and self.config.USE_TEXT_EMBEDDINGS:
                    comp_text_inputs = self.preprocess_text(comparison_texts[i])
                
                # Calcular similitud
                with torch.no_grad():
                    if ref_text_inputs is not None and comp_text_inputs is not None:
                        similarity = self.model.calculate_similarity(
                            image1_pixel_values=ref_image,
                            image2_pixel_values=comp_image,
                            text1_input_ids=ref_text_inputs['input_ids'],
                            text2_input_ids=comp_text_inputs['input_ids'],
                            text1_attention_mask=ref_text_inputs['attention_mask'],
                            text2_attention_mask=comp_text_inputs['attention_mask']
                        )
                    else:
                        similarity = self.model.calculate_similarity(
                            image1_pixel_values=ref_image,
                            image2_pixel_values=comp_image
                        )
                
                results.append({
                    'image_path': comp_image_path,
                    'similarity': similarity.item()
                })
            except Exception as e:
                print(f"Error al procesar {comp_image_path}: {e}")
        
        # Ordenar resultados por similitud (de mayor a menor)
        results.sort(key=lambda x: x['similarity'], reverse=True)
        
        return results
    
    def visualize_similarities(self, reference_image, comparison_results, num_images=5, figsize=(15, 10)):
        """
        Visualiza las similitudes entre una imagen de referencia y las im谩genes m谩s similares.
        
        Args:
            reference_image: Ruta a la imagen de referencia
            comparison_results: Resultados de calculate_batch_similarities
            num_images: N煤mero de im谩genes similares a mostrar
            figsize: Tama帽o de la figura
        """
        # Limitar el n煤mero de im谩genes a mostrar
        num_images = min(num_images, len(comparison_results))
        
        # Crear figura
        fig, axes = plt.subplots(1, num_images + 1, figsize=figsize)
        
        # Mostrar imagen de referencia
        ref_img = Image.open(reference_image).convert('RGB')
        axes[0].imshow(ref_img)
        axes[0].set_title("Imagen de referencia")
        axes[0].axis('off')
        
        # Mostrar im谩genes similares
        for i in range(num_images):
            img_path = comparison_results[i]['image_path']
            similarity = comparison_results[i]['similarity']
            
            img = Image.open(img_path).convert('RGB')
            axes[i+1].imshow(img)
            axes[i+1].set_title(f"Sim: {similarity:.4f}")
            axes[i+1].axis('off')
        
        plt.tight_layout()
        plt.show()