Upload 71 files

Files changed:
- Amodal3R/pipelines/image_to_3d.py  +2 -1
- app.py  +81 -22
Amodal3R/pipelines/image_to_3d.py
CHANGED

```diff
@@ -377,6 +377,7 @@ class Amodal3RImageTo3DPipeline(Pipeline):
         slat_sampler_params: dict = {},
         formats: List[str] = ['mesh', 'gaussian'],
         mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
+        erode_kernel_size: int = 3,
     ) -> dict:
         """
         Run the pipeline with multiple images as condition
@@ -388,7 +389,7 @@ class Amodal3RImageTo3DPipeline(Pipeline):
             slat_sampler_params (dict): Additional parameters for the structured latent sampler.
             preprocess_image (bool): Whether to preprocess the image.
         """
-        images, masks, masks_occ = zip(*[self.preprocess_image_w_mask(image, mask) for image, mask in zip(images, masks)])
+        images, masks, masks_occ = zip(*[self.preprocess_image_w_mask(image, mask, erode_kernel_size) for image, mask in zip(images, masks)])
        images = list(images)
         masks = list(masks)
         masks_occ = list(masks_occ)
```
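The body of `preprocess_image_w_mask` is not touched by this commit, so the diff does not show how the new `erode_kernel_size` argument is consumed. Below is a minimal illustrative sketch of what a mask-erosion step of this kind usually looks like (assuming OpenCV-style morphology; the helper name and details are hypothetical, not the repository's actual code):

```python
# Illustrative only: one plausible way an erode_kernel_size can be applied to a
# binary visibility mask (OpenCV-style morphology). Not the repository's code.
import cv2
import numpy as np

def erode_visibility_mask(mask: np.ndarray, erode_kernel_size: int = 3) -> np.ndarray:
    """Shrink the visible region by a few pixels; a size of 0 leaves the mask unchanged."""
    if erode_kernel_size <= 0:
        return mask
    kernel = np.ones((erode_kernel_size, erode_kernel_size), dtype=np.uint8)
    return cv2.erode(mask.astype(np.uint8), kernel, iterations=1)
```

Shrinking the visible region this way treats uncertain segmentation boundaries as occluded (to be completed) rather than as trusted visible pixels, which matches the "erode the visible area" note added to the app.py instructions below.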
app.py
CHANGED

```diff
@@ -103,6 +103,7 @@ def image_to_3d(
     ss_sampling_steps: int,
     slat_guidance_strength: float,
     slat_sampling_steps: int,
+    erode_kernel_size: int,
     req: gr.Request,
 ) -> Tuple[dict, str]:
     """
@@ -136,8 +137,9 @@ def image_to_3d(
             "cfg_strength": slat_guidance_strength,
         },
         mode="stochastic",
+        erode_kernel_size=erode_kernel_size,
     )
-    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120, bg_color=(1,1,1))['color']
     video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
     video_path = os.path.join(user_dir, 'sample.mp4')
@@ -323,7 +325,7 @@ def delete_mask(mask_list):
     mask_list.pop()
     return mask_list
 
-def check_combined_mask(image, visibility_mask, mask_list, scale=0.
+def check_combined_mask(image, visibility_mask, mask_list, scale=0.65):
     updated_image = image.copy()
     # combine all the mask:
     combined_mask = np.zeros_like(updated_image[:, :, 0])
@@ -394,13 +396,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
 
 
     with gr.Row():
-        gr.Markdown("""
+        gr.Markdown("""
+        ### Step 1 - Generate Visibility Mask and Occlusion Mask.
         * Please wait for a few seconds after uploading the image. The 2D segmenter is getting ready.
-        * Add the point prompts to indicate the target object
-        * "Render Point", see the position of the point to be added.
-        * "Add Point", the point will be added to the list.
-        * "Generate mask", see the segmented area corresponding to current point list.
-        * "Add mask", current mask will be added for 3D amodal completion.
+        * Add the point prompts to indicate the target object.
+        * "Render Point", see the position of the point to be added. "Add Point", the point will be added to the list.
+        * "Generate mask", see the segmented area corresponding to current point list. "Add mask", current mask will be added for 3D amodal completion.
+        * The target object need to be put in the center of the image, the scale can be adjusted for better reconstruction.
         """)
     with gr.Row():
         with gr.Column():
@@ -434,11 +436,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
                     undo_vis_mask = gr.Button("Undo Last Mask")
                 vis_input = gr.Image(label='Visible Input', interactive=False, height=300)
                 with gr.Row():
-                    zoom_scale = gr.Slider(0.3, 1.0, label="Target Object Scale", value=0.
+                    zoom_scale = gr.Slider(0.3, 1.0, label="Target Object Scale", value=0.68, step=0.1)
                     check_visible_input = gr.Button("Generate Occluded Input")
                 with gr.Row():
-                    gr.Markdown("""
+                    gr.Markdown("""
+                    ### Step 2 - 3D Amodal Completion.
                     * Different random seeds can be tried in "Generation Settings", if you think the results are not ideal.
+                    * The boundary of the segmentation may not be accurate, so here we provide the option to erode the visible area.
                     * If the reconstruction 3D asset is satisfactory, you can extract the GLB file and download it.
                     """)
                 with gr.Row():
@@ -446,6 +450,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
             with gr.Accordion(label="Generation Settings", open=True):
                 seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                erode_kernel_size = gr.Slider(0, 5, label="Erode Kernel Size", value=0, step=1)
                 gr.Markdown("Stage 1: Sparse Structure Generation")
                 with gr.Row():
                     ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
@@ -454,10 +459,37 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
                 with gr.Row():
                     slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                     slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-            generate_btn = gr.Button("
+            generate_btn = gr.Button("Amodal 3D Reconstruction")
+            with gr.Accordion(label="GLB Extraction Settings", open=False):
+                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
+                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+            with gr.Row():
+                extract_glb_btn = gr.Button("Extract GLB")
+                extract_gs_btn = gr.Button("Extract Gaussian")
+            gr.Markdown("""
+            *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
+            """)
         with gr.Column():
             video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+            model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
+
+            with gr.Row():
+                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+                download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
 
+    with gr.Row() as single_image_example:
+        examples = gr.Examples(
+            examples=[
+                f'assets/example_image/{image}'
+                for image in os.listdir("assets/example_image")
+            ],
+            inputs=[input_image],
+            fn=lambda image: input_image.upload(image),
+            outputs=[predictor, original_image, message],
+            run_on_click=True,
+            examples_per_page=12,
+        )
+
     # # Handlers
     demo.load(start_session)
     demo.unload(end_session)
@@ -536,21 +568,48 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
 
 
     # 3D Amodal Reconstruction
-    # generate_btn.click(
-    #     get_seed,
-    #     inputs=[randomize_seed, seed],
-    #     outputs=[seed],
-    # ).then(
-    #     image_to_3d,
-    #     inputs=[vis_input, occluded_mask, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
-    #     outputs=[output_buf, video_output],
-    # )
-
     generate_btn.click(
+        get_seed,
+        inputs=[randomize_seed, seed],
+        outputs=[seed],
+    ).then(
         image_to_3d,
-        inputs=[vis_input, occluded_mask, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
+        inputs=[vis_input, occluded_mask, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, erode_kernel_size],
         outputs=[output_buf, video_output],
+    ).then(
+        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    video_output.clear(
+        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
+        outputs=[extract_glb_btn, extract_gs_btn],
+    )
+
+    extract_glb_btn.click(
+        extract_glb,
+        inputs=[output_buf, mesh_simplify, texture_size],
+        outputs=[model_output, download_glb],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_glb],
+    )
+
+    extract_gs_btn.click(
+        extract_gaussian,
+        inputs=[output_buf],
+        outputs=[model_output, download_gs],
+    ).then(
+        lambda: gr.Button(interactive=True),
+        outputs=[download_gs],
     )
+
+    model_output.clear(
+        lambda: gr.Button(interactive=False),
+        outputs=[download_glb],
+    )
+
+
 
 
     # Launch the Gradio app
```
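The re-enabled handler block uses Gradio's event chaining: each `.then(...)` step runs only after the previous step finishes, so the seed is resolved before `image_to_3d` runs, and the extraction buttons become interactive only once an output exists. Below is a minimal, self-contained sketch of the same pattern (component and function names here are placeholders, not the objects defined in app.py):

```python
# Minimal, self-contained illustration of the .click().then() chaining pattern
# used by the new handlers. Names below are placeholders, not app.py's objects.
import random

import gradio as gr

def get_seed(randomize: bool, seed: int) -> int:
    # Optionally replace the user-selected seed with a random one.
    return random.randint(0, 2**31 - 1) if randomize else int(seed)

def fake_generate(seed: int) -> str:
    # Stand-in for the real generation call.
    return f"generated asset with seed {int(seed)}"

with gr.Blocks() as demo:
    randomize = gr.Checkbox(label="Randomize Seed", value=True)
    seed = gr.Slider(0, 2**31 - 1, step=1, label="Seed", value=1)
    run_btn = gr.Button("Run")
    result = gr.Textbox(label="Result")
    extract_btn = gr.Button("Extract", interactive=False)

    # Step 1 resolves the seed, step 2 generates, step 3 enables the next button.
    run_btn.click(
        get_seed, inputs=[randomize, seed], outputs=[seed]
    ).then(
        fake_generate, inputs=[seed], outputs=[result]
    ).then(
        lambda: gr.Button(interactive=True), outputs=[extract_btn]
    )

if __name__ == "__main__":
    demo.launch()
```

In recent Gradio versions, returning a component instance such as `gr.Button(interactive=True)` from a chained step acts as a property update, which is how this commit toggles the Extract and Download buttons as results appear and are cleared.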