set device in inference_state manually

- .gitignore +1 -0
- app.py +51 -43
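The change keeps the predictor on the CPU during the interactive steps (video preprocessing and point selection) and moves it to the GPU only for full-video propagation, writing the active device into inference_state["device"] at every transition so the session state always matches where the model actually lives. A minimal sketch of the pattern (the function names here are illustrative; only predictor.to(), predictor.device, and the dict-like inference_state come from app.py):

    import torch

    def interactive_step(predictor, inference_state):
        # Click handling is cheap: keep the model on CPU between requests.
        predictor.to("cpu")
        inference_state["device"] = predictor.device  # state follows the model

    def propagation_step(predictor, inference_state):
        # The heavy phase: move to GPU and run under bfloat16 autocast.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            predictor.to("cuda")
            inference_state["device"] = predictor.device
            # ... predictor.propagate_in_video(inference_state) ...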
.gitignore CHANGED
@@ -1,2 +1,3 @@
 *.egg-info/
 __pycache__/
+*.DS_Store
app.py CHANGED
@@ -174,6 +174,8 @@ def preprocess_video_in(
     input_labels,
     inference_state,
 ):
+    predictor.to("cpu")
+    inference_state["device"] = predictor.device
     if video_path is None:
         return (
             gr.update(open=True),  # video_in_drawer
@@ -255,6 +257,8 @@ def segment_with_points(
     inference_state,
     evt: gr.SelectData,
 ):
+    predictor.to("cpu")
+    inference_state["device"] = predictor.device
     input_points.append(evt.index)
     print(f"TRACKING INPUT POINT: {input_points}")
 
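The propagate_to_all hunk below replaces a commented-out torch.autocast(...).__enter__() hack with a real with block, and moves the predictor to the GPU inside it. The difference matters; an illustrative sketch, not from app.py:

    # Entering autocast by hand never exits it, so bfloat16 autocasting
    # would leak into every subsequent op on this thread:
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    # The `with` form scopes autocast to the block:
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        ...  # mixed precision ends when the block exits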
@@ -336,55 +340,59 @@ def propagate_to_all(
     input_points,
     inference_state,
 ):
-    # torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
-
-    if len(input_points) == 0 or video_in is None or inference_state is None:
-        return None
-    # run propagation throughout the video and collect the results in a dict
-    video_segments = (
-        {}
-    )  # video_segments contains the per-frame segmentation results
-    print("starting propagate_in_video")
-    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-        inference_state
-    ):
-        video_segments[out_frame_idx] = {
-            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-            for i, out_obj_id in enumerate(out_obj_ids)
-        }
-
-    # obtain the segmentation results every few frames
-    vis_frame_stride = 1
-
-    output_frames = []
-    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
-        transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
-            "RGBA"
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        predictor.to("cuda")
+        inference_state["device"] = predictor.device
+
+        if len(input_points) == 0 or video_in is None or inference_state is None:
+            return None
+        # run propagation throughout the video and collect the results in a dict
+        video_segments = (
+            {}
+        )  # video_segments contains the per-frame segmentation results
+        print("starting propagate_in_video")
+        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
+            inference_state
+        ):
+            video_segments[out_frame_idx] = {
+                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                for i, out_obj_id in enumerate(out_obj_ids)
+            }
+
+        # obtain the segmentation results every few frames
+        vis_frame_stride = 1
+
+        output_frames = []
+        for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
+            transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
+                "RGBA"
+            )
+            out_mask = video_segments[out_frame_idx][OBJ_ID]
+            mask_image = show_mask(out_mask)
+            output_frame = Image.alpha_composite(transparent_background, mask_image)
+            output_frame = np.array(output_frame)
+            output_frames.append(output_frame)
+
+        torch.cuda.empty_cache()
+
+        # Create a video clip from the image sequence
+        original_fps = get_video_fps(video_in)
+        fps = original_fps  # Frames per second
+        clip = ImageSequenceClip(output_frames, fps=fps)
+        # Write the result to a file
+        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
+        final_vid_output_path = f"output_video_{unique_id}.mp4"
+        final_vid_output_path = os.path.join(
+            tempfile.gettempdir(), final_vid_output_path
         )
-        out_mask = video_segments[out_frame_idx][OBJ_ID]
-        mask_image = show_mask(out_mask)
-        output_frame = Image.alpha_composite(transparent_background, mask_image)
-        output_frame = np.array(output_frame)
-        output_frames.append(output_frame)
-
-    torch.cuda.empty_cache()
-
-    # Create a video clip from the image sequence
-    original_fps = get_video_fps(video_in)
-    fps = original_fps  # Frames per second
-    clip = ImageSequenceClip(output_frames, fps=fps)
-    # Write the result to a file
-    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
-    final_vid_output_path = f"output_video_{unique_id}.mp4"
-    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
 
-    # Write the result to a file
-    clip.write_videofile(final_vid_output_path, codec="libx264")
+        # Write the result to a file
+        clip.write_videofile(final_vid_output_path, codec="libx264")
 
-    return gr.update(value=final_vid_output_path)
+        return gr.update(value=final_vid_output_path)
 
 
 def update_ui():
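Design note: the TF32 toggles are gated on torch.cuda.get_device_properties(0).major >= 8, i.e. compute capability 8.x (Ampere) or newer, where TF32 matmuls are supported. As written, the check assumes a CUDA device exists; a defensive variant (hypothetical, not part of this commit) would be:

    # Hypothetical guard: skip the device query entirely on CPU-only hosts,
    # since torch.cuda.get_device_properties(0) raises without a CUDA device.
    if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True  # TF32 matmuls on Ampere+
        torch.backends.cudnn.allow_tf32 = True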