# Hugging Face Space (status when captured: running on CPU Upgrade).
| import gradio as gr | |
| import kornia as K | |
| from kornia.core import Tensor | |
| from kornia.contrib import ImageStitcher | |
| import kornia.feature as KF | |
| import torch | |
| import numpy as np | |
def _to_nchw(tensor):
    """Reshape a 2D/3D/4D image tensor of unknown layout to (B, C, H, W).

    The layout is guessed from the axis sizes:

    * 2D is taken as a single-channel (H, W) image.
    * 3D with a leading axis of 1 or 3 is taken as (C, H, W); otherwise, if
      the trailing axis looks like a channel axis, as (H, W, C); otherwise
      as a batch of single-channel images (B, H, W).
    * 4D is taken as (B, C, H, W) unless the trailing axis looks like a
      channel axis, in which case it is treated as (B, H, W, C).

    Raises:
        ValueError: if the tensor is not 2D, 3D or 4D.
    """

    def trailing_axis_is_channels(lead, last):
        # A channel axis is either a typical channel count or simply the
        # smaller of the two candidate axes (channels << spatial size).
        return last in (1, 3, 4) or last < lead

    if tensor.ndim == 2:
        return tensor[None, None]

    if tensor.ndim == 3:
        lead, last = tensor.shape[0], tensor.shape[-1]
        if lead in (1, 3):
            return tensor.unsqueeze(0)
        if trailing_axis_is_channels(lead, last):
            return tensor.permute(2, 0, 1).unsqueeze(0)
        return tensor.unsqueeze(1)

    if tensor.ndim == 4:
        lead, last = tensor.shape[1], tensor.shape[-1]
        if lead not in (1, 3) and trailing_axis_is_channels(lead, last):
            return tensor.permute(0, 3, 1, 2)
        # Already channels-first; extra channels (e.g. alpha) are trimmed
        # by the caller instead of being mistaken for a spatial axis.
        return tensor

    raise ValueError(f"Expected a 2D, 3D or 4D image, got {tensor.ndim}D")


def preprocess_image(img: "np.ndarray | torch.Tensor") -> torch.Tensor:
    """Convert an image to a float, 3-channel (B, 3, H, W) tensor.

    Args:
        img: A NumPy array (converted with ``K.image_to_tensor`` and divided
            by 255) or a ``torch.Tensor`` (cast to float and divided by 255
            only when its maximum exceeds 1.0).

    Returns:
        A float tensor of shape (B, 3, H, W). Single-channel inputs are
        repeated across three channels; inputs with more than three channels
        keep only the first three.

    Raises:
        ValueError: if ``img`` is neither a NumPy array nor a tensor, or if a
            tensor input is not 2D, 3D or 4D.
    """
    print(f"Input image type: {type(img)}")
    print(f"Input image shape: {img.shape if hasattr(img, 'shape') else 'No shape attribute'}")

    layout_is_nchw = False
    if isinstance(img, np.ndarray):
        # With keepdim=False kornia returns a 4D channels-first tensor, so no
        # layout guessing is needed (guessing would mangle RGBA input).
        img = K.image_to_tensor(img, keepdim=False).float() / 255.0
        layout_is_nchw = True
    elif isinstance(img, torch.Tensor):
        img = img.float()
        # Heuristic: values above 1.0 mean the data is in the 0-255 range.
        if img.max() > 1.0:
            img = img / 255.0
    else:
        raise ValueError(f"Unsupported image type: {type(img)}")
    print(f"After conversion to tensor - shape: {img.shape}")

    # Ensure 4D tensor (B, C, H, W)
    if not layout_is_nchw:
        img = _to_nchw(img)
    print(f"After ensuring 4D - shape: {img.shape}")

    # Ensure 3 channel image
    channels = img.shape[1]
    if channels == 1:
        img = img.repeat(1, 3, 1, 1)
    elif channels > 3:
        img = img[:, :3]  # Take only the first 3 channels if more than 3
    print(f"Final tensor shape: {img.shape}")
    return img
# Shared stitcher instance, created on first use by _get_stitcher().
_STITCHER = None


def _get_stitcher():
    """Return the shared LoFTR-based ImageStitcher, building it on first use."""
    global _STITCHER
    if _STITCHER is None:
        # Constructing LoFTR with pretrained weights is expensive; do it once
        # instead of on every request.
        _STITCHER = ImageStitcher(KF.LoFTR(pretrained='outdoor'), estimator='ransac')
    return _STITCHER


def inference(img_1, img_2):
    """Stitch two images together and return the result as an image.

    Args:
        img_1: First input image, in any form ``preprocess_image`` accepts.
        img_2: Second input image, in any form ``preprocess_image`` accepts.

    Returns:
        The first element of the stitcher output, converted with
        ``K.tensor_to_image``.

    Raises:
        ValueError: propagated from ``preprocess_image`` for unsupported
            inputs.
    """
    # Preprocess images
    img_1 = preprocess_image(img_1)
    img_2 = preprocess_image(img_2)

    stitcher = _get_stitcher()
    # Inference only: no gradients needed.
    with torch.no_grad():
        result = stitcher(img_1, img_2)
    return K.tensor_to_image(result[0])
# Example image pairs offered in the UI (paths relative to the app directory).
examples = [['examples/foto1B.jpg', 'examples/foto1A.jpg']]

# Build the Gradio interface: two inputs side by side, one output, a button
# that runs `inference`, and the example picker.
with gr.Blocks(theme='huggingface') as demo_app:
    gr.Markdown("# Image Stitching using Kornia and LoFTR")

    with gr.Row():
        input_image1 = gr.Image(label="Input Image 1")
        input_image2 = gr.Image(label="Input Image 2")

    output_image = gr.Image(label="Output Image")

    stitch_button = gr.Button("Stitch Images")
    stitch_button.click(
        fn=inference,
        inputs=[input_image1, input_image2],
        outputs=output_image,
    )

    gr.Examples(
        examples=examples,
        inputs=[input_image1, input_image2],
    )

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo_app.launch(share=True)