Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| from streamlit_image_comparison import image_comparison | |
| # from src.envs.new_edit_photo import PhotoEditor | |
| from src.sac.sac_inference import InferenceAgent | |
| import yaml | |
| import os | |
| from src.envs.photo_env import PhotoEnhancementEnvTest | |
| from tensordict import TensorDict | |
| import torchvision.transforms.v2.functional as F | |
| from streamlit import cache_resource | |
| import pandas as pd | |
| from bokeh.plotting import figure | |
| from bokeh.models import ColumnDataSource | |
| from bokeh.palettes import Spectral3 | |
| from src.envs.edit_photo_opt import PhotoEditor | |
| import io | |
| import cv2 | |
| # Set page config to wide mode | |
| st.set_page_config(layout="wide") | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # DEVICE = torch.device("cpu") | |
| MODEL_PATH = os.path.join("experiments",'ResNet_10_sliders__224_128_aug__2024-07-23_21-23-35') | |
| SLIDERS = ['temp','tint','exposure', 'contrast','highlights','shadows', 'whites', 'blacks','vibrance','saturation'] | |
| SLIDERS_ORD = ['contrast','exposure','temp','tint','whites','blacks','highlights','shadows','vibrance','saturation'] | |
| class Config(object): | |
| def __init__(self,dictionary): | |
| self.__dict__.update(dictionary) | |
| def load_preprocessor_agent(preprocessor_agent_path,device): | |
| with open(os.path.join(preprocessor_agent_path,"configs/sac_config.yaml")) as f: | |
| sac_config_dict = yaml.load(f, Loader=yaml.FullLoader) | |
| with open(os.path.join(preprocessor_agent_path,"configs/env_config.yaml")) as f: | |
| env_config_dict = yaml.load(f, Loader=yaml.FullLoader) | |
| with open(os.path.join("src/configs/inference_config.yaml")) as f: | |
| inf_config_dict = yaml.load(f, Loader=yaml.FullLoader) | |
| inference_config = Config(inf_config_dict) | |
| sac_config = Config(sac_config_dict) | |
| env_config = Config(env_config_dict) | |
| inference_env = PhotoEnhancementEnvTest( | |
| batch_size=env_config.train_batch_size, | |
| imsize=env_config.imsize, | |
| training_mode=None, | |
| done_threshold=env_config.threshold_psnr, | |
| edit_sliders=env_config.sliders_to_use, | |
| features_size=env_config.features_size, | |
| discretize=env_config.discretize, | |
| discretize_step=env_config.discretize_step, | |
| use_txt_features=env_config.use_txt_features if hasattr(env_config,'use_txt_features') else False, | |
| augment_data=False, | |
| pre_encoding_device=device, | |
| pre_load_images=False, | |
| logger=None | |
| ) | |
| inference_config.device = device | |
| preprocessor_agent = InferenceAgent(inference_env, inference_config) | |
| preprocessor_agent.device = device | |
| preprocessor_agent.load_backbone(os.path.join(preprocessor_agent_path,'models','backbone.pth')) | |
| preprocessor_agent.load_actor_weights(os.path.join(preprocessor_agent_path,'models','actor_head.pth')) | |
| preprocessor_agent.load_critics_weights(os.path.join(preprocessor_agent_path,'models','qf1_head.pth'), | |
| os.path.join(preprocessor_agent_path,'models','qf2_head.pth')) | |
| return preprocessor_agent | |
| enhancer_agent = load_preprocessor_agent(MODEL_PATH,DEVICE) | |
| photo_editor = PhotoEditor(SLIDERS) | |
| def enhance_image(image:np.array, params:dict): | |
| input_image = image.unsqueeze(0).to(DEVICE) | |
| parameters = [params[param_name]/100.0 for param_name in SLIDERS_ORD] | |
| parameters = torch.tensor(parameters).unsqueeze(0).to(DEVICE) | |
| if st.session_state.photopro_image is None: | |
| enhanced_image,photopro_image = photo_editor(input_image,parameters,use_photopro_image=False) | |
| st.session_state.photopro_image = photopro_image | |
| else: | |
| enhanced_image = photo_editor(st.session_state.photopro_image,parameters,use_photopro_image=True) | |
| enhanced_image = enhanced_image.squeeze(0).cpu().detach().numpy() | |
| enhanced_image = np.clip(enhanced_image, 0, 1) | |
| enhanced_image = (enhanced_image*255).astype(np.uint8) | |
| return enhanced_image | |
| def auto_enhance(image,deterministic=True): | |
| input_image = image.unsqueeze(0).to(DEVICE) | |
| input_image = input_image.permute(0,3,1,2) | |
| IMAGE_SIZE = enhancer_agent.env.imsize | |
| input_image = F.resize(input_image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=F.InterpolationMode.BICUBIC) | |
| batch_observation = TensorDict( | |
| { | |
| "batch_images":input_image, | |
| }, | |
| batch_size = [input_image.shape[0]], | |
| ) | |
| parameters = enhancer_agent.act(batch_observation,deterministic=deterministic,n_samples=0) | |
| parameters = parameters.squeeze(0)*100.0 | |
| parameters = torch.round(parameters) | |
| output_parameters = [] | |
| index = 0 | |
| for slider in SLIDERS_ORD: | |
| if slider in enhancer_agent.env.edit_sliders: | |
| output_parameters.append(parameters[index].item()) | |
| index += 1 | |
| else: | |
| output_parameters.append(0) | |
| return output_parameters | |
| def slider_callback(): | |
| st.session_state.apply_button_enabled = True | |
| def apply_button_callback(): | |
| for name in SLIDERS: | |
| st.session_state.params[name] = st.session_state[f"slider_{name}"] | |
| image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0 | |
| st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params) | |
| st.session_state.apply_button_enabled = False | |
| def auto_random_enhance_callback(): | |
| image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0 | |
| auto_params = auto_enhance(image_tensor,deterministic=False) | |
| for i, name in enumerate(SLIDERS_ORD): | |
| st.session_state[f"slider_{name}"] = int(auto_params[i]) | |
| st.session_state.params[name] = int(auto_params[i]) | |
| st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params) | |
| def auto_enhance_callback(): | |
| image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0 | |
| auto_params = auto_enhance(image_tensor) | |
| for i, name in enumerate(SLIDERS_ORD): | |
| st.session_state[f"slider_{name}"] = int(auto_params[i]) | |
| st.session_state.params[name] = int(auto_params[i]) | |
| st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params) | |
| def reset_sliders(): | |
| for name in SLIDERS: | |
| st.session_state[f"slider_{name}"] = 0 | |
| st.session_state.params[name] = 0 | |
| # st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params) | |
| st.session_state.enhanced_image = st.session_state.original_image | |
| def reset_on_upload(): | |
| st.session_state.original_image = None | |
| st.session_state.photopro_image = None | |
| st.session_state.file_extension = None | |
| st.session_state.mime_type = None | |
| reset_sliders() | |
| def create_smooth_histogram(image): | |
| # Compute histograms for each channel | |
| bins = np.linspace(0, 255, 256) | |
| hist_r, _ = np.histogram(image[..., 0], bins=bins) | |
| hist_g, _ = np.histogram(image[..., 1], bins=bins) | |
| hist_b, _ = np.histogram(image[..., 2], bins=bins) | |
| # Normalize the histograms | |
| def normalize_histogram(hist): | |
| hist_central = hist[1:-1] | |
| hist_max = np.max(hist_central) | |
| hist_min = np.min(hist_central) | |
| hist_normalized = (hist_central - hist_min) / (hist_max - hist_min) | |
| hist[0] = min(hist[0] / hist_max, 1) | |
| hist[-1] = min(hist[-1] / hist_max, 1) | |
| return np.concatenate(([hist[0]], hist_normalized, [hist[-1]])) | |
| hist_r_norm = normalize_histogram(hist_r) | |
| hist_g_norm = normalize_histogram(hist_g) | |
| hist_b_norm = normalize_histogram(hist_b) | |
| # Create Bokeh figure with transparent background | |
| p = figure(width=300, height=150, toolbar_location=None, | |
| x_range=(0, 255), y_range=(0, 1.1), | |
| background_fill_color=None, | |
| border_fill_color=None, | |
| outline_line_color=None) | |
| # Remove all axes, labels, and grids | |
| p.axis.visible = False | |
| p.xgrid.grid_line_color = None | |
| p.ygrid.grid_line_color = None | |
| # Create ColumnDataSource for each channel | |
| source_r = ColumnDataSource(data=dict(left=bins[:-1], right=bins[1:], top=hist_r_norm)) | |
| source_g = ColumnDataSource(data=dict(left=bins[:-1], right=bins[1:], top=hist_g_norm)) | |
| source_b = ColumnDataSource(data=dict(left=bins[:-1], right=bins[1:], top=hist_b_norm)) | |
| # Plot the histograms | |
| p.quad(bottom=0, top='top', left='left', right='right', source=source_r, | |
| fill_color="red", fill_alpha=0.9, line_color=None) | |
| p.quad(bottom=0, top='top', left='left', right='right', source=source_g, | |
| fill_color="green", fill_alpha=0.9, line_color=None) | |
| p.quad(bottom=0, top='top', left='left', right='right', source=source_b, | |
| fill_color="blue", fill_alpha=0.9, line_color=None) | |
| # Remove padding | |
| p.min_border_left = 0 | |
| p.min_border_right = 0 | |
| p.min_border_top = 0 | |
| p.min_border_bottom = 0 | |
| return p | |
| # In your Streamlit app | |
| def plot_histogram_streamlit(image): | |
| histogram = create_smooth_histogram(image) | |
| st.sidebar.bokeh_chart(histogram, use_container_width=True) | |
| # Initialize session state | |
| if 'enhanced_image' not in st.session_state: | |
| st.session_state.enhanced_image = None | |
| if 'original_image' not in st.session_state: | |
| st.session_state.original_image = None | |
| if 'photopro_image' not in st.session_state: | |
| st.session_state.photopro_image = None | |
| if 'params' not in st.session_state: | |
| st.session_state.params = {name: 0 for name in SLIDERS} | |
| if "apply_button_enabled" not in st.session_state: | |
| st.session_state.apply_button_enabled = False | |
| if "uploaded_file" not in st.session_state: | |
| st.session_state.uploaded_file = None | |
| if 'file_extension' not in st.session_state: | |
| st.session_state.file_extension = None | |
| for name in SLIDERS: | |
| if f"slider_{name}" not in st.session_state: | |
| st.session_state[f"slider_{name}"] = 0 | |
| for name in SLIDERS: | |
| if f"slider_{name}" not in st.session_state: | |
| st.session_state[f"slider_{name}"] = 0 | |
| # Set up the Streamlit app | |
| # File uploader in the main area | |
| _, center_col,_ = st.columns([1, 2, 1]) | |
| if st.session_state.original_image is None: | |
| with center_col: | |
| uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "tif", "tiff"], on_change=reset_on_upload) | |
| # Show welcome content only if no file is uploaded | |
| if uploaded_file is None: | |
| st.title("Welcome to the AI Photo Enhancement App") | |
| st.write("This application uses Reinforcement Learning to enhance your photos with ease and flexibility. Below, you'll find a short tutorial video on how to use the app.") | |
| st.write("You can find the project repository here: [AI Photo Enhancer GitHub](https://github.com/zakaria-narjis/ai-photo-enhancer).") | |
| # st.info("Please note that this app is currently limited to 8-bit images. I'm working on extending it to handle 16-bit images.", icon="ℹ️") | |
| st.video("demo.mp4", muted=True) | |
| st.subheader("How to Use the App:") | |
| st.markdown(""" | |
| 1. **<span style='color: #FF6347;'>Upload</span> your image** to get started. | |
| 2. Proceed to the **<span style='color: #FF6347;'>editing</span> phase**, where you can: | |
| - View the color histograms and adjust sliders for various tonal values. | |
| - Use the **<span style='color: #FF6347;'>Auto Enhance</span>** button to let the agent automatically enhance your image. | |
| - Use the **<span style='color: #FF6347;'>Auto Random Enhance</span>** button to receive different suggested editing parameters each time. | |
| - **<span style='color: #FF6347;'>Manually refine</span>** your image by adjusting slider values and clicking the **<span style='color: #FF6347;'>Apply Edits</span>** button at the bottom of the sidebar. | |
| 3. **<span style='color: #FF6347;'>Download</span>** your enhanced image once you're satisfied with the results. | |
| """, unsafe_allow_html=True) | |
| st.write("Enjoy! :smile:") | |
| elif st.session_state.original_image is None: # Process the uploaded file | |
| # Get and store the file extension | |
| file_ext = uploaded_file.name.split('.')[-1].lower() | |
| st.session_state.file_extension = file_ext | |
| # Convert tif/tiff to appropriate mime type | |
| if file_ext in ['tif', 'tiff']: | |
| st.session_state.mime_type = 'image/tiff' | |
| elif file_ext == 'jpg': | |
| st.session_state.file_extension = 'jpeg' | |
| st.session_state.mime_type = 'image/jpeg' | |
| else: | |
| st.session_state.mime_type = f'image/{file_ext}' | |
| st.session_state.original_image = np.array(Image.open(uploaded_file).convert('RGB'), dtype=np.uint16) | |
| if st.session_state.original_image is not None: | |
| # Load the original image | |
| # st.session_state.original_image = np.array(Image.open(st.session_state.original_image).convert('RGB'),dtype=np.uint16) | |
| # Enhance the image initially | |
| if st.session_state.enhanced_image is None: | |
| st.session_state.enhanced_image = st.session_state.original_image | |
| # Sidebar for controls | |
| st.sidebar.title("Controls") | |
| # Display histogram | |
| st.sidebar.subheader("Colors Histogram") | |
| plot_histogram_streamlit(st.session_state.enhanced_image) | |
| # Select box to choose which image to display | |
| display_option = st.sidebar.selectbox( | |
| "Select view mode", | |
| ("Comparison", "Enhanced") | |
| ) | |
| # Create two columns for the buttons | |
| col1, col2,col3 = st.sidebar.columns(3) | |
| # Button for auto-enhancement | |
| with col1: | |
| st.button("Auto Enhance", on_click=auto_enhance_callback, key="auto_enhance_button",use_container_width=True) | |
| with col2: | |
| st.button("Auto Random Enhance", on_click=auto_random_enhance_callback, key="auto_random_enhance_button",use_container_width=True) | |
| # Button for resetting sliders | |
| with col3: | |
| st.button("Reset", on_click=reset_sliders, key="reset_button",use_container_width=True) | |
| st.sidebar.subheader("Adjustments") | |
| slider_names = SLIDERS | |
| for name in slider_names: | |
| if f"slider_{name}" not in st.session_state: | |
| st.session_state[f"slider_{name}"] = 0 | |
| st.sidebar.slider( | |
| name.capitalize(), | |
| min_value=-100, | |
| max_value=100, | |
| value=st.session_state[f"slider_{name}"], | |
| key=f"slider_{name}", | |
| on_change=slider_callback | |
| ) | |
| st.sidebar.button("Apply manual edit", on_click=apply_button_callback, key="apply_button",use_container_width=True,disabled=not st.session_state.apply_button_enabled) | |
| # Create a single column to maximize width | |
| left_spacer, content_column, right_spacer = st.columns([1, 3, 1]) | |
| with content_column: | |
| if display_option == "Enhanced": | |
| if st.session_state.enhanced_image is not None: | |
| st.image(st.session_state.enhanced_image.astype(np.uint8), caption="Enhanced Image", use_column_width=True) | |
| else: | |
| st.warning("Enhanced image is not available. Try adjusting the sliders or clicking 'Auto Enhance'.") | |
| else: # Comparison view | |
| if st.session_state.enhanced_image is not None: | |
| image_comparison( | |
| img1=Image.fromarray(st.session_state.original_image.astype(np.uint8)), | |
| img2=Image.fromarray(st.session_state.enhanced_image.astype(np.uint8)), | |
| label1="Original", | |
| label2="Enhanced", | |
| width=850, # You might want to adjust this value | |
| starting_position=50, | |
| show_labels=True, | |
| make_responsive=True, | |
| ) | |
| else: | |
| st.warning("Enhanced image is not available for comparison. Try adjusting the sliders or clicking 'Auto Enhance'.") | |
| with io.BytesIO() as img_bytes: | |
| enhanced_img = Image.fromarray(st.session_state.enhanced_image.astype(np.uint8)) | |
| enhanced_img.save(img_bytes, format=f"{st.session_state.file_extension}") | |
| st.download_button( | |
| label="Download Enhanced Image", | |
| data=img_bytes.getvalue(), | |
| file_name=f"enhanced.{st.session_state.file_extension}".lower(), | |
| mime=st.session_state.mime_type, | |
| use_container_width=True | |
| ) | |
| # Add custom CSS to make the image comparison component responsive | |
| st.markdown(""" | |
| <style> | |
| .stImageComparison { | |
| width: 100% !important; | |
| } | |
| .stImageComparison > figure > div { | |
| width: 100% !important; | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |