Update main.py
Browse files
main.py
CHANGED
|
@@ -20,9 +20,9 @@ def under_max(image):
|
|
| 20 |
return image
|
| 21 |
|
| 22 |
class Model(object):
|
| 23 |
-
def __init__(self, gpu=
|
| 24 |
config = Config()
|
| 25 |
-
config.device = 'cuda:{}'.format(gpu)
|
| 26 |
model, _ = caption_model.build_model(config)
|
| 27 |
checkpoint = torch.load('./checkpoint.pth', map_location='cpu')
|
| 28 |
model.load_state_dict(checkpoint['model'])
|
|
|
|
| 20 |
return image
|
| 21 |
|
| 22 |
class Model(object):
|
| 23 |
+
def __init__(self, gpu=None):
|
| 24 |
config = Config()
|
| 25 |
+
config.device = 'cpu' if gpu is None else 'cuda:{}'.format(gpu)
|
| 26 |
model, _ = caption_model.build_model(config)
|
| 27 |
checkpoint = torch.load('./checkpoint.pth', map_location='cpu')
|
| 28 |
model.load_state_dict(checkpoint['model'])
|