AlessioChenn committed
Commit f2d0a8e · verified · Parent: 4b4be8a
__pycache__/explainer.cpython-312.pyc ADDED
Binary file (7.36 kB).
 
config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "auto_map": {
+     "AutoConfig": "explainer.ExplainerConfig",
+     "AutoModel": "explainer.Explainer"
+   },
+   "base_model_name": "google/siglip2-giant-opt-patch16-384",
+   "giant": true,
+   "hidden_dim": 768,
+   "model_type": "explainer",
+   "torch_dtype": "float32"
+ }
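
The `auto_map` block above is what lets the generic Auto classes resolve to the custom code in explainer.py. A minimal loading sketch, assuming the files live in a Hub repo (the repo id below is a placeholder, not taken from this commit):

from transformers import AutoModel

# trust_remote_code=True is required so transformers will import
# explainer.py from the repo and resolve auto_map to Explainer.
model = AutoModel.from_pretrained(
    "AlessioChenn/explainer",  # placeholder repo id
    trust_remote_code=True,
)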
explainer.py ADDED
@@ -0,0 +1,122 @@
+ from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, SiglipModel
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import hf_hub_download
+
+
+ class ExplainerConfig(PretrainedConfig):
+     model_type = "explainer"
+
+     def __init__(self, base_model_name='google/siglip2-giant-opt-patch16-384',
+                  hidden_dim=768, giant=True, **kwargs):
+         self.base_model_name = base_model_name
+         self.hidden_dim = hidden_dim
+         self.giant = giant
+         super().__init__(**kwargs)
+
+
+ class SigLIPBBoxRegressor(nn.Module):
+     """Regresses a bounding box from frozen SigLIP image/text embeddings."""
+
+     def __init__(self, siglip_model, hidden_dim=768, giant=True):
+         super().__init__()
+         self.siglip = siglip_model
+
+         vision_dim = self.siglip.vision_model.config.hidden_size
+         text_dim = self.siglip.text_model.config.hidden_size
+         if giant:
+             # For the giant variant, the text embeddings come out at 1536 dims.
+             text_dim = 1536
+
+         self.vision_projector = nn.Sequential(
+             nn.Linear(vision_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+         )
+         self.text_projector = nn.Sequential(
+             nn.Linear(text_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+         )
+         self.fusion_layer = nn.Sequential(
+             nn.Linear(hidden_dim * 2, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(0.2),
+             nn.Linear(hidden_dim, hidden_dim // 2),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+         )
+         # Two heads: (x1, y1) for the top-left corner, (x2, y2) for the bottom-right.
+         self.topleft_regressor = nn.Sequential(
+             nn.Linear(hidden_dim // 2, 256),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(256, 128),
+             nn.ReLU(),
+             nn.Linear(128, 2),
+         )
+         self.bottomright_regressor = nn.Sequential(
+             nn.Linear(hidden_dim // 2, 256),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(256, 128),
+             nn.ReLU(),
+             nn.Linear(128, 2),
+         )
+
+     def forward(self, pixel_values, input_ids):
+         # The SigLIP backbone stays frozen; only the heads above receive gradients.
+         with torch.no_grad():
+             outputs = self.siglip(pixel_values=pixel_values, input_ids=input_ids, return_dict=True)
+             vision_features = outputs.image_embeds.float()
+             text_features = outputs.text_embeds.float()
+
+         vision_proj = self.vision_projector(vision_features)
+         text_proj = self.text_projector(text_features)
+         fused = torch.cat([vision_proj, text_proj], dim=1)
+         fused_features = self.fusion_layer(fused)
+
+         topleft_pred = self.topleft_regressor(fused_features)
+         bottomright_pred = self.bottomright_regressor(fused_features)
+         return torch.cat([topleft_pred, bottomright_pred], dim=1)
+
+
+ class Explainer(PreTrainedModel):
+     config_class = ExplainerConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.siglip_model = SiglipModel.from_pretrained(config.base_model_name)
+         # Pass the config's head hyperparameters through instead of silently
+         # falling back to the regressor's defaults.
+         self.bbox_regressor = SigLIPBBoxRegressor(
+             self.siglip_model, hidden_dim=config.hidden_dim, giant=config.giant
+         )
+         self.processor = AutoProcessor.from_pretrained(config.base_model_name, use_fast=True)
+
+     def forward(self, pixel_values=None, input_ids=None):
+         return self.bbox_regressor(pixel_values, input_ids)
+
+     def predict(self, image, text, device="cuda"):
+         self.to(device)
+         self.eval()
+         inputs = self.processor(
+             text=text,
+             images=image,
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=64,
+         )
+         # Match the backbone's dtype rather than hard-coding .half(), which
+         # raises a dtype mismatch when the model is loaded in float32.
+         pixel_values = inputs["pixel_values"].to(device=device, dtype=self.siglip_model.dtype)
+         input_ids = inputs["input_ids"].to(device)
+         with torch.no_grad():
+             pred_bbox = self.forward(pixel_values, input_ids)
+         return pred_bbox[0].cpu().numpy().tolist()
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         # HF sometimes passes `config` here; otherwise load our own config class
+         # (the generic PretrainedConfig would drop the custom fields).
+         config = kwargs.pop("config", None)
+         if config is None:
+             config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
+
+         model = cls(config)
+
+         checkpoint_path = hf_hub_download(
+             repo_id=pretrained_model_name_or_path,
+             filename="pytorch_model.bin",
+         )
+
+         # The checkpoint stores the two state dicts under separate keys.
+         checkpoint = torch.load(checkpoint_path, map_location="cpu")
+         model.siglip_model.load_state_dict(checkpoint["siglip_model"])
+         model.bbox_regressor.load_state_dict(checkpoint["bbox_regressor"])
+         return model
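
With a model loaded as in the earlier sketch, the `predict` helper runs the processor and returns the first batch element as a plain list. A short usage sketch (the image path and prompt are illustrative; given the two heads, the output is [x1, y1, x2, y2] for the top-left and bottom-right corners, presumably in the normalized coordinates used at training time):

from PIL import Image

image = Image.open("example.jpg").convert("RGB")  # any RGB image
bbox = model.predict(image, "the dog on the left", device="cuda")
x1, y1, x2, y2 = bbox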
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b18398c9a5ff1226de68d7cc50080e7aa0efc1ea1370cf816b1a994984afd15
+ size 3760831559
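
The custom `from_pretrained` in explainer.py expects this file to be a plain torch checkpoint holding two state dicts under the keys "siglip_model" and "bbox_regressor". A sketch of writing a compatible checkpoint after training (the `model` variable is assumed to be a trained Explainer):

import torch

torch.save(
    {
        "siglip_model": model.siglip_model.state_dict(),
        "bbox_regressor": model.bbox_regressor.state_dict(),
    },
    "pytorch_model.bin",
)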