import streamlit as st
import torch
from PIL import Image
import numpy as np
from streamlit_image_comparison import image_comparison
# from src.envs.new_edit_photo import PhotoEditor
from src.sac.sac_inference import InferenceAgent
import yaml
import os
from src.envs.photo_env import PhotoEnhancementEnvTest
from tensordict import TensorDict
import torchvision.transforms.v2.functional as F
from streamlit import cache_resource
import pandas as pd
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource
from bokeh.palettes import Spectral3
from src.envs.edit_photo_opt import PhotoEditor
import io
import cv2
# Lay the page out across the full browser width.
st.set_page_config(layout="wide")

# Prefer CUDA when torch reports it as available; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Experiment directory that holds the agent's configs and model weights.
MODEL_PATH = os.path.join(
    "experiments",
    'ResNet_10_sliders__224_128_aug__2024-07-23_21-23-35',
)

# Slider names in the order the sidebar widgets are created.
SLIDERS = [
    'temp', 'tint', 'exposure', 'contrast', 'highlights',
    'shadows', 'whites', 'blacks', 'vibrance', 'saturation',
]

# Slider names in the order used to build/read parameter vectors.
SLIDERS_ORD = [
    'contrast', 'exposure', 'temp', 'tint', 'whites',
    'blacks', 'highlights', 'shadows', 'vibrance', 'saturation',
]
class Config:
    """Attribute-style view over a configuration mapping.

    Every key of ``dictionary`` becomes an instance attribute, so a loaded
    YAML mapping can be read as ``cfg.key`` instead of ``cfg["key"]``.
    """

    def __init__(self, dictionary):
        # Copy the mapping straight into the instance namespace.
        vars(self).update(dictionary)
@cache_resource
def load_preprocessor_agent(preprocessor_agent_path, device):
    """Build the inference agent for an experiment and load its weights.

    Reads ``configs/sac_config.yaml`` and ``configs/env_config.yaml`` from
    ``preprocessor_agent_path`` plus ``src/configs/inference_config.yaml``,
    creates a ``PhotoEnhancementEnvTest`` from the environment config, wraps
    it in an ``InferenceAgent`` and loads the backbone, actor and critic
    weights from the experiment's ``models`` directory.

    Decorated with ``cache_resource``, so Streamlit reuses the returned agent
    across reruns.

    Args:
        preprocessor_agent_path: Experiment directory with ``configs`` and
            ``models`` inside.
        device: Device stored on the inference config and on the agent, and
            passed to the environment as ``pre_encoding_device``.

    Returns:
        The ``InferenceAgent`` with all weights loaded.
    """

    def read_yaml(path):
        # NOTE(review): FullLoader is broader than safe_load; acceptable only
        # because these are local, trusted config files.
        with open(path) as handle:
            return yaml.load(handle, Loader=yaml.FullLoader)

    def weights_file(filename):
        return os.path.join(preprocessor_agent_path, 'models', filename)

    sac_settings = read_yaml(os.path.join(preprocessor_agent_path, "configs/sac_config.yaml"))
    env_settings = read_yaml(os.path.join(preprocessor_agent_path, "configs/env_config.yaml"))
    inference_settings = read_yaml("src/configs/inference_config.yaml")

    inference_config = Config(inference_settings)
    # The SAC config is parsed but not consumed anywhere in this function.
    sac_config = Config(sac_settings)
    env_config = Config(env_settings)

    inference_env = PhotoEnhancementEnvTest(
        batch_size=env_config.train_batch_size,
        imsize=env_config.imsize,
        training_mode=None,
        done_threshold=env_config.threshold_psnr,
        edit_sliders=env_config.sliders_to_use,
        features_size=env_config.features_size,
        discretize=env_config.discretize,
        discretize_step=env_config.discretize_step,
        # Configs without this key fall back to False.
        use_txt_features=getattr(env_config, 'use_txt_features', False),
        augment_data=False,
        pre_encoding_device=device,
        pre_load_images=False,
        logger=None,
    )

    inference_config.device = device
    agent = InferenceAgent(inference_env, inference_config)
    agent.device = device

    agent.load_backbone(weights_file('backbone.pth'))
    agent.load_actor_weights(weights_file('actor_head.pth'))
    agent.load_critics_weights(weights_file('qf1_head.pth'), weights_file('qf2_head.pth'))
    return agent
# Agent that predicts slider values; the loader is wrapped in cache_resource.
enhancer_agent = load_preprocessor_agent(
    preprocessor_agent_path=MODEL_PATH,
    device=DEVICE,
)
# Editor that applies slider parameters to an image; rebuilt on every script run.
photo_editor = PhotoEditor(SLIDERS)
def enhance_image(image: torch.Tensor, params: dict) -> np.ndarray:
    """Apply the slider values in ``params`` to ``image`` with ``photo_editor``.

    The first call after an upload runs the editor on the input image and
    stores the second value it returns in ``st.session_state.photopro_image``.
    Later calls pass that cached value back with ``use_photopro_image=True``
    and do not touch ``image`` at all.

    Args:
        image: Image tensor; a batch axis is prepended before editing.
            Callers in this file pass a float tensor scaled to [0, 1].
        params: Mapping from slider name to slider value. It must contain
            every name in ``SLIDERS_ORD``; values are divided by 100.

    Returns:
        The edited image as a ``uint8`` array, clipped to [0, 1] and scaled
        by 255.

    Raises:
        KeyError: If ``params`` lacks a name listed in ``SLIDERS_ORD``.
    """
    # NOTE(review): photo_editor is constructed from SLIDERS while this vector
    # follows SLIDERS_ORD - confirm PhotoEditor expects this ordering.
    scaled = [params[slider] / 100.0 for slider in SLIDERS_ORD]
    parameters = torch.tensor(scaled).unsqueeze(0).to(DEVICE)

    cached = st.session_state.photopro_image
    if cached is None:
        # Only move the image to the device when it is actually needed.
        input_image = image.unsqueeze(0).to(DEVICE)
        enhanced_image, photopro_image = photo_editor(
            input_image, parameters, use_photopro_image=False
        )
        st.session_state.photopro_image = photopro_image
    else:
        enhanced_image = photo_editor(cached, parameters, use_photopro_image=True)

    enhanced_image = enhanced_image.squeeze(0).cpu().detach().numpy()
    enhanced_image = np.clip(enhanced_image, 0, 1)
    return (enhanced_image * 255).astype(np.uint8)
def auto_enhance(image, deterministic=True):
    """Ask the agent for slider values for ``image``.

    The image gets a batch axis, its axes are reordered with
    ``permute(0, 3, 1, 2)`` (presumably channels-last to channels-first -
    confirm against the caller), and it is resized with bicubic interpolation
    to the square size the agent's environment uses. The agent's output is
    multiplied by 100 and rounded.

    Args:
        image: Image tensor without a batch axis.
        deterministic: Forwarded to ``enhancer_agent.act``.

    Returns:
        A list with one value per entry of ``SLIDERS_ORD``. Sliders that are
        not in the environment's ``edit_sliders`` get 0.
    """
    side = enhancer_agent.env.imsize
    batch = image.unsqueeze(0).to(DEVICE).permute(0, 3, 1, 2)
    batch = F.resize(batch, (side, side), interpolation=F.InterpolationMode.BICUBIC)

    observation = TensorDict(
        {"batch_images": batch},
        batch_size=[batch.shape[0]],
    )
    raw = enhancer_agent.act(observation, deterministic=deterministic, n_samples=0)
    rounded = torch.round(raw.squeeze(0) * 100.0)

    # The agent only emits values for the sliders it edits, in SLIDERS_ORD
    # order, so walk its output with a separate cursor.
    output_parameters = []
    cursor = 0
    for slider in SLIDERS_ORD:
        if slider not in enhancer_agent.env.edit_sliders:
            output_parameters.append(0)
            continue
        output_parameters.append(rounded[cursor].item())
        cursor += 1
    return output_parameters
def slider_callback():
    """Enable the manual-apply button once any slider has been moved."""
    st.session_state["apply_button_enabled"] = True
def apply_button_callback():
    """Apply the current slider positions to the original image.

    Copies each slider widget value into ``st.session_state.params``,
    re-renders the enhanced image from the original, then disables the apply
    button again.
    """
    state = st.session_state
    for slider in SLIDERS:
        state.params[slider] = state[f"slider_{slider}"]

    # Scale the stored pixel values by 1/255 before handing them to the editor.
    source = torch.from_numpy(state.original_image).float().div(255.0)
    state.enhanced_image = enhance_image(source, state.params)
    state.apply_button_enabled = False
def _apply_auto_enhance(deterministic):
    """Shared body of both auto-enhance buttons.

    Gets slider values from ``auto_enhance``, writes them (as ints) to the
    slider widgets and to ``st.session_state.params``, then re-renders the
    enhanced image from the original.

    Args:
        deterministic: Forwarded to ``auto_enhance``.
    """
    image_tensor = torch.from_numpy(st.session_state.original_image).float() / 255.0
    auto_params = auto_enhance(image_tensor, deterministic=deterministic)
    # auto_enhance returns one value per SLIDERS_ORD entry, in that order.
    for name, value in zip(SLIDERS_ORD, auto_params):
        st.session_state[f"slider_{name}"] = int(value)
        st.session_state.params[name] = int(value)
    st.session_state.enhanced_image = enhance_image(image_tensor, st.session_state.params)


def auto_random_enhance_callback():
    """Button callback: enhance with ``deterministic=False``."""
    _apply_auto_enhance(deterministic=False)


def auto_enhance_callback():
    """Button callback: enhance with ``deterministic=True``."""
    _apply_auto_enhance(deterministic=True)
def reset_sliders():
    """Zero every slider and show the original image as the enhanced one."""
    state = st.session_state
    for slider in SLIDERS:
        state.params[slider] = 0
        state[f"slider_{slider}"] = 0
    # With all sliders at zero the enhanced view is just the original image.
    state.enhanced_image = state.original_image
def reset_on_upload():
    """Uploader ``on_change`` hook: clear per-image state and reset sliders."""
    per_image_keys = ("original_image", "photopro_image", "file_extension", "mime_type")
    for key in per_image_keys:
        st.session_state[key] = None
    reset_sliders()
def normalize_histogram(hist):
    """Scale one channel's bin counts into [0, 1] for plotting.

    The interior bins (all but the first and last) are min-max normalised.
    The first and last bins are expressed relative to the interior maximum
    and capped at 1, so a large spike at either extreme cannot flatten the
    rest of the curve.

    The input is not modified. Counts are converted to float first, so the
    end-bin ratios are kept as fractions instead of being truncated to 0 or 1
    by an integer array.

    Args:
        hist: 1-D sequence of bin counts with at least three entries.

    Returns:
        A float array of the same length with values in [0, 1].
    """
    counts = np.asarray(hist, dtype=np.float64)
    central = counts[1:-1]
    central_max = central.max()
    central_min = central.min()
    span = central_max - central_min

    if span > 0:
        central_scaled = (central - central_min) / span
    else:
        # Flat interior: avoid 0/0 and draw it at zero height.
        central_scaled = np.zeros_like(central)

    def scale_end(count):
        if central_max > 0:
            return min(count / central_max, 1.0)
        # Empty interior: any count at the extreme is shown at full height.
        return 1.0 if count > 0 else 0.0

    return np.concatenate(
        ([scale_end(counts[0])], central_scaled, [scale_end(counts[-1])])
    )


def create_smooth_histogram(image):
    """Build a Bokeh figure with overlaid red, green and blue histograms.

    Args:
        image: Array whose last axis holds at least three channels, read as
            R, G, B. Assumes values in 0..255 - TODO confirm for callers
            other than the ones in this file.

    Returns:
        A Bokeh figure without axes, grid, toolbar or padding, containing one
        quad glyph series per channel.
    """
    # 257 edges give one bin per value 0..255. With only 256 edges the last
    # bin is closed on both sides and would merge values 254 and 255.
    bins = np.arange(257)

    # Create Bokeh figure with transparent background
    p = figure(width=300, height=150, toolbar_location=None,
               x_range=(0, 256), y_range=(0, 1.1),
               background_fill_color=None,
               border_fill_color=None,
               outline_line_color=None)

    # Remove all axes, labels, and grids
    p.axis.visible = False
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None

    # One histogram per channel, drawn in R, G, B order.
    for channel, color in enumerate(("red", "green", "blue")):
        counts, _ = np.histogram(image[..., channel], bins=bins)
        source = ColumnDataSource(data=dict(
            left=bins[:-1],
            right=bins[1:],
            top=normalize_histogram(counts),
        ))
        p.quad(bottom=0, top='top', left='left', right='right', source=source,
               fill_color=color, fill_alpha=0.9, line_color=None)

    # Remove padding
    p.min_border_left = 0
    p.min_border_right = 0
    p.min_border_top = 0
    p.min_border_bottom = 0
    return p
def plot_histogram_streamlit(image):
    """Render the RGB histogram of ``image`` in the Streamlit sidebar."""
    st.sidebar.bokeh_chart(
        create_smooth_histogram(image),
        use_container_width=True,
    )
# Initialize session state
# Keys that start out as a plain constant. 'mime_type' is included because
# reset_on_upload and the download button both use it.
for _state_key, _state_default in (
    ('enhanced_image', None),
    ('original_image', None),
    ('photopro_image', None),
    ('apply_button_enabled', False),
    ('uploaded_file', None),
    ('file_extension', None),
    ('mime_type', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Slider values last applied to the image, one entry per slider.
if 'params' not in st.session_state:
    st.session_state.params = {name: 0 for name in SLIDERS}

# Backing values for the slider widgets (keyed "slider_<name>").
for name in SLIDERS:
    if f"slider_{name}" not in st.session_state:
        st.session_state[f"slider_{name}"] = 0
# Set up the Streamlit app
# File uploader in the main area. It is only shown while no image is loaded.
# NOTE(review): SOURCE indentation was lost; the welcome content is assumed to
# live inside the centre column - confirm.
_, center_col, _ = st.columns([1, 2, 1])
if st.session_state.original_image is None:
    with center_col:
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "tif", "tiff"], on_change=reset_on_upload)
        # Show welcome content only if no file is uploaded
        if uploaded_file is None:
            st.title("Welcome to the AI Photo Enhancement App")
            st.write("This application uses Reinforcement Learning to enhance your photos with ease and flexibility. Below, you'll find a short tutorial video on how to use the app.")
            st.write("You can find the project repository here: [AI Photo Enhancer GitHub](https://github.com/zakaria-narjis/ai-photo-enhancer).")
            # st.info("Please note that this app is currently limited to 8-bit images. I'm working on extending it to handle 16-bit images.", icon="ℹ️")
            st.video("demo.mp4", muted=True)
            st.subheader("How to Use the App:")
            st.markdown("""
1. **Upload your image** to get started.
2. Proceed to the **editing phase**, where you can:
- View the color histograms and adjust sliders for various tonal values.
- Use the **Auto Enhance** button to let the agent automatically enhance your image.
- Use the **Auto Random Enhance** button to receive different suggested editing parameters each time.
- **Manually refine** your image by adjusting slider values and clicking the **Apply Edits** button at the bottom of the sidebar.
3. **Download** your enhanced image once you're satisfied with the results.
""", unsafe_allow_html=True)
            st.write("Enjoy! :smile:")
        else:  # Process the uploaded file
            file_ext = uploaded_file.name.split('.')[-1].lower()
            # The stored extension is later passed to PIL's Image.save(format=...),
            # which only accepts registered format names ("JPEG", "TIFF", ...).
            # Short aliases such as "jpg" or "tif" raise KeyError there, so map
            # them to the canonical name up front.
            file_ext = {'jpg': 'jpeg', 'tif': 'tiff'}.get(file_ext, file_ext)
            st.session_state.file_extension = file_ext
            st.session_state.mime_type = f'image/{file_ext}'
            st.session_state.original_image = np.array(Image.open(uploaded_file).convert('RGB'), dtype=np.uint16)
# Editing view: shown once an image has been loaded.
# NOTE(review): SOURCE indentation was lost; nesting below (in particular the
# download button living inside the content column) is inferred - confirm.
if st.session_state.original_image is not None:
    # Until an edit is applied, the "enhanced" image is the original itself.
    if st.session_state.enhanced_image is None:
        st.session_state.enhanced_image = st.session_state.original_image

    sidebar = st.sidebar
    sidebar.title("Controls")

    # Histogram of the image currently shown as enhanced.
    sidebar.subheader("Colors Histogram")
    plot_histogram_streamlit(st.session_state.enhanced_image)

    # View-mode selector.
    display_option = sidebar.selectbox(
        "Select view mode",
        ("Comparison", "Enhanced"),
    )

    # Row of three action buttons.
    auto_col, random_col, reset_col = sidebar.columns(3)
    with auto_col:
        st.button("Auto Enhance", on_click=auto_enhance_callback, key="auto_enhance_button", use_container_width=True)
    with random_col:
        st.button("Auto Random Enhance", on_click=auto_random_enhance_callback, key="auto_random_enhance_button", use_container_width=True)
    with reset_col:
        st.button("Reset", on_click=reset_sliders, key="reset_button", use_container_width=True)

    # One slider per adjustable parameter, backed by session state.
    sidebar.subheader("Adjustments")
    for name in SLIDERS:
        slider_key = f"slider_{name}"
        if slider_key not in st.session_state:
            st.session_state[slider_key] = 0
        sidebar.slider(
            name.capitalize(),
            min_value=-100,
            max_value=100,
            value=st.session_state[slider_key],
            key=slider_key,
            on_change=slider_callback,
        )

    # Enabled only after a slider has changed since the last apply.
    sidebar.button("Apply manual edit", on_click=apply_button_callback, key="apply_button", use_container_width=True, disabled=not st.session_state.apply_button_enabled)

    # Centre column for the image, with a spacer on each side.
    _pad_left, content_column, _pad_right = st.columns([1, 3, 1])
    with content_column:
        enhanced_only = display_option == "Enhanced"
        if st.session_state.enhanced_image is None:
            if enhanced_only:
                st.warning("Enhanced image is not available. Try adjusting the sliders or clicking 'Auto Enhance'.")
            else:
                st.warning("Enhanced image is not available for comparison. Try adjusting the sliders or clicking 'Auto Enhance'.")
        elif enhanced_only:
            st.image(st.session_state.enhanced_image.astype(np.uint8), caption="Enhanced Image", use_column_width=True)
        else:
            # Side-by-side slider comparison of original and enhanced.
            image_comparison(
                img1=Image.fromarray(st.session_state.original_image.astype(np.uint8)),
                img2=Image.fromarray(st.session_state.enhanced_image.astype(np.uint8)),
                label1="Original",
                label2="Enhanced",
                width=850,
                starting_position=50,
                show_labels=True,
                make_responsive=True,
            )

        # Encode the enhanced image in memory, in the uploaded file's format,
        # and offer it for download.
        with io.BytesIO() as img_bytes:
            downloadable = Image.fromarray(st.session_state.enhanced_image.astype(np.uint8))
            downloadable.save(img_bytes, format=f"{st.session_state.file_extension}")
            st.download_button(
                label="Download Enhanced Image",
                data=img_bytes.getvalue(),
                file_name=f"enhanced.{st.session_state.file_extension}".lower(),
                mime=st.session_state.mime_type,
                use_container_width=True,
            )
# Hook for custom CSS intended to make the image comparison component responsive.
# NOTE: the markup string below contains only a newline, so no CSS is
# injected at the moment.
st.markdown("""
""", unsafe_allow_html=True)