Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import mlflow | |
| import numpy as np | |
| from PIL import Image | |
| from skimage.color import rgb2lab, lab2rgb | |
| from torchvision import transforms | |
| from model import Generator | |
import os

# MLflow experiment the colorizer runs are logged under. Can be overridden via
# the COLORIZER_EXPERIMENT_NAME environment variable.
EXPERIMENT_NAME = os.environ.get("COLORIZER_EXPERIMENT_NAME", "Colorizer_Experiment")
# MLflow run ID to load the trained generator from. The literal default is a
# placeholder: set COLORIZER_RUN_ID (e.g. as a deployment secret) or replace it
# with your actual run ID.
RUN_ID = os.environ.get("COLORIZER_RUN_ID", "your_run_id_here")
def setup_mlflow():
    """Return the ID of the MLflow experiment named ``EXPERIMENT_NAME``.

    Looks the experiment up by name and, only if no such experiment exists,
    creates it.

    Returns:
        The experiment ID (existing or newly created).
    """
    existing = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if existing is not None:
        return existing.experiment_id
    # No experiment under this name yet -> create one and hand back its ID.
    return mlflow.create_experiment(EXPERIMENT_NAME)
def load_model(run_id, device):
    """Load the generator logged under an MLflow run and prepare it for inference.

    Args:
        run_id: MLflow run ID whose ``generator_model`` artifact holds the model.
        device: Passed as ``map_location``; ``mlflow.pytorch.load_model``
            forwards it to ``torch.load`` so tensors land on this device
            (e.g. ``"cpu"`` or a ``torch.device``).

    Returns:
        The loaded model, switched to evaluation mode.
    """
    print(f"Loading model from run: {run_id}")
    model_uri = f"runs:/{run_id}/generator_model"
    model = mlflow.pytorch.load_model(model_uri, map_location=device)
    # The model is unpickled with whatever train/eval flag it had when it was
    # logged (often mid-training). Force eval mode so any dropout/batch-norm
    # layers behave as they should at inference time.
    model.eval()
    return model
def preprocess_image(image, size=(256, 256)):
    """Convert an input image into the normalized L-channel tensor for the model.

    Args:
        image: Image as a NumPy array (what ``gr.Image(type="numpy")``
            delivers) or a ``PIL.Image.Image``. Any mode is converted to RGB
            first, so grayscale and RGBA inputs are accepted.
        size: ``(height, width)`` the image is resized to before conversion.
            Defaults to the original fixed 256x256.

    Returns:
        Float32 tensor of shape ``(1, 1, height, width)`` holding the Lab
        lightness channel rescaled from ``[0, 100]`` to ``[-1, 1]``.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image provided to colorize.")
    img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    img = img.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),  # float32 CHW in [0, 1]
    ])
    # skimage's colour conversions expect channels-last arrays.
    rgb = transform(img).permute(1, 2, 0).numpy()
    lightness = rgb2lab(rgb)[:, :, 0]
    # L* spans [0, 100]; centre and scale it into [-1, 1].
    lightness = (lightness - 50) / 50
    return torch.from_numpy(lightness).unsqueeze(0).unsqueeze(0).float()
def postprocess_output(L, ab):
    """Merge the input lightness with the predicted ab channels into an RGB image.

    Args:
        L: Lightness tensor normalized to ``[-1, 1]``. It may be ``(H, W)`` or
            carry leading singleton dims, such as the ``(1, 1, H, W)`` produced
            by ``preprocess_image``.
        ab: Predicted chroma tensor for the same image. Channels-first
            (``(1, 2, H, W)`` / ``(2, H, W)``, the usual PyTorch layout) and
            channels-last (``(..., H, W, 2)``) are both accepted.

    Returns:
        ``uint8`` RGB array of shape ``(H, W, 3)``.
    """
    L = L.detach().cpu().numpy()
    ab = ab.detach().cpu().numpy()
    # Drop batch/channel singleton dims explicitly. squeeze() would also
    # collapse a spatial dimension of size 1.
    L = L.reshape(L.shape[-2:])
    # skimage needs channels-last Lab, but a PyTorch generator emits
    # channels-first ab. Concatenating (H, W, 1) with (2, H, W) on axis 2
    # raises, so move the channel axis to the end when needed.
    if ab.shape[-2:] == L.shape:
        ab = np.moveaxis(ab, -3, -1)
    ab = ab.reshape(L.shape + (2,))

    L = (L + 1.0) * 50.0  # inverse of preprocess_image's (L - 50) / 50
    ab = ab * 128.0  # assumes training scaled ab by 1/128 -- TODO confirm
    lab = np.concatenate([L[..., np.newaxis], ab], axis=2)
    rgb = lab2rgb(lab)  # float RGB in [0, 1]
    # Round rather than truncate, so pixel values are not biased down by up
    # to one level; clip defensively before the uint8 cast.
    return np.clip(np.round(rgb * 255.0), 0, 255).astype(np.uint8)
def colorize_image(image, model, device):
    """Colorize ``image`` with ``model`` and return the result as a uint8 RGB array.

    The image is turned into a normalized L-channel tensor on ``device``. The
    model then predicts the ab channels with autograd disabled, and both are
    merged back into RGB.
    """
    lightness = preprocess_image(image).to(device)
    with torch.no_grad():
        predicted_ab = model(lightness)
    return postprocess_output(lightness, predicted_ab)
def setup_gradio_app(run_id, device):
    """Build the Gradio interface that colorizes uploaded images.

    The model is loaded once here and shared by every request through the
    ``gradio_colorize`` closure.

    Args:
        run_id: MLflow run ID to load the generator from (see ``load_model``).
        device: Device the model and inputs are placed on.

    Returns:
        A ``gr.Interface`` that has not been launched; call ``.launch()`` on it
        to serve the app.
    """
    model = load_model(run_id, device)

    def gradio_colorize(input_image):
        # Submitting without an upload passes None. Show a readable message in
        # the UI instead of a stack trace from the preprocessing code.
        if input_image is None:
            raise gr.Error("Please upload an image first.")
        colorized = colorize_image(input_image, model, device)
        return Image.fromarray(colorized)

    return gr.Interface(
        fn=gradio_colorize,
        # preprocess_image expects a NumPy array by default; pin the input
        # type so it does not depend on the component's default.
        inputs=gr.Image(type="numpy", label="Upload a grayscale image"),
        outputs=gr.Image(label="Colorized Image"),
        title="Image Colorizer",
        description="Upload a grayscale image and get a colorized version!",
    )