Spaces:
Runtime error
Runtime error
Ahsen Khaliq
commited on
Commit
·
178d84c
1
Parent(s):
d0b8090
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
os.system("gdown https://drive.google.com/uc?id=14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT")
|
|
|
|
| 3 |
import clip
|
| 4 |
import os
|
| 5 |
from torch import nn
|
|
@@ -32,7 +33,6 @@ TA = Union[T, ARRAY]
|
|
| 32 |
D = torch.device
|
| 33 |
CPU = torch.device('cpu')
|
| 34 |
|
| 35 |
-
model_path = 'conceptual_weights.pt'
|
| 36 |
|
| 37 |
def get_device(device_id: int) -> D:
|
| 38 |
if not torch.cuda.is_available():
|
|
@@ -234,14 +234,16 @@ prefix_length = 10
|
|
| 234 |
|
| 235 |
model = ClipCaptionModel(prefix_length)
|
| 236 |
|
| 237 |
-
model.load_state_dict(torch.load(model_path, map_location=CPU))
|
| 238 |
-
|
| 239 |
-
model = model.eval()
|
| 240 |
-
device = CUDA(0) if is_gpu else "cpu"
|
| 241 |
model = model.to(device)
|
| 242 |
|
| 243 |
|
| 244 |
-
def inference(img):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
use_beam_search = False
|
| 246 |
image = io.imread(img.name)
|
| 247 |
pil_image = PIL.Image.fromarray(image)
|
|
@@ -262,7 +264,7 @@ article = "<p style='text-align: center'><a href='https://github.com/rmokady/CLI
|
|
| 262 |
examples=[['water.jpeg']]
|
| 263 |
gr.Interface(
|
| 264 |
inference,
|
| 265 |
-
gr.inputs.Image(type="file", label="Input"),
|
| 266 |
gr.outputs.Textbox(label="Output"),
|
| 267 |
title=title,
|
| 268 |
description=description,
|
|
|
|
| 1 |
import os
|
| 2 |
os.system("gdown https://drive.google.com/uc?id=14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT")
|
| 3 |
+
os.system("gdown https://drive.google.com/uc?id=1IdaBtMSvtyzF0ByVaBHtvM0JYSXRExRX")
|
| 4 |
import clip
|
| 5 |
import os
|
| 6 |
from torch import nn
|
|
|
|
| 33 |
D = torch.device
|
| 34 |
CPU = torch.device('cpu')
|
| 35 |
|
|
|
|
| 36 |
|
| 37 |
def get_device(device_id: int) -> D:
|
| 38 |
if not torch.cuda.is_available():
|
|
|
|
| 234 |
|
| 235 |
model = ClipCaptionModel(prefix_length)
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
model = model.to(device)
|
| 238 |
|
| 239 |
|
| 240 |
+
def inference(img,model):
|
| 241 |
+
if model == "COCO":
|
| 242 |
+
model_path = 'coco_weights.pt'
|
| 243 |
+
else:
|
| 244 |
+
model_path = 'conceptual_weights.pt'
|
| 245 |
+
model.load_state_dict(torch.load(model_path, map_location=CPU))
|
| 246 |
+
|
| 247 |
use_beam_search = False
|
| 248 |
image = io.imread(img.name)
|
| 249 |
pil_image = PIL.Image.fromarray(image)
|
|
|
|
| 264 |
examples=[['water.jpeg']]
|
| 265 |
gr.Interface(
|
| 266 |
inference,
|
| 267 |
+
[gr.inputs.Image(type="file", label="Input"),gr.inputs.Radio(choices["COCO","Conceptual captions"], type="value", default="COCO", label="Model")],
|
| 268 |
gr.outputs.Textbox(label="Output"),
|
| 269 |
title=title,
|
| 270 |
description=description,
|