zakaria-narjis commited on
Commit
46e8495
·
1 Parent(s): 5576428

add photopro image caching

Browse files
Files changed (2) hide show
  1. demo.py +13 -5
  2. src/envs/edit_photo_opt.py +591 -0
demo.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  from PIL import Image
4
  import numpy as np
5
  from streamlit_image_comparison import image_comparison
6
- from src.envs.new_edit_photo import PhotoEditor
7
  from src.sac.sac_inference import InferenceAgent
8
  import yaml
9
  import os
@@ -15,12 +15,13 @@ import pandas as pd
15
  from bokeh.plotting import figure
16
  from bokeh.models import ColumnDataSource
17
  from bokeh.palettes import Spectral3
 
18
  # Set page config to wide mode
19
  st.set_page_config(layout="wide")
20
 
21
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
- # DEVICE = torch.device("cpu")
23
- MODEL_PATH = "experiments/runs/ResNet_10_sliders__224_128_aug__2024-07-23_21-23-35"
24
  SLIDERS = ['temp','tint','exposure', 'contrast','highlights','shadows', 'whites', 'blacks','vibrance','saturation']
25
  SLIDERS_ORD = ['contrast','exposure','temp','tint','whites','blacks','highlights','shadows','vibrance','saturation']
26
 
@@ -73,7 +74,11 @@ def enhance_image(image:np.array, params:dict):
73
  input_image = image.unsqueeze(0).to(DEVICE)
74
  parameters = [params[param_name]/100.0 for param_name in SLIDERS_ORD]
75
  parameters = torch.tensor(parameters).unsqueeze(0).to(DEVICE)
76
- enhanced_image = photo_editor(input_image,parameters)
 
 
 
 
77
  enhanced_image = enhanced_image.squeeze(0).cpu().detach().numpy()
78
  enhanced_image = np.clip(enhanced_image, 0, 1)
79
  enhanced_image = (enhanced_image*255).astype(np.uint8)
@@ -134,6 +139,7 @@ def reset_sliders():
134
 
135
  def reset_on_upload():
136
  st.session_state.original_image = None
 
137
  reset_sliders()
138
 
139
  def create_smooth_histogram(image):
@@ -202,6 +208,8 @@ if 'enhanced_image' not in st.session_state:
202
  st.session_state.enhanced_image = None
203
  if 'original_image' not in st.session_state:
204
  st.session_state.original_image = None
 
 
205
  if 'params' not in st.session_state:
206
  st.session_state.params = {name: 0 for name in SLIDERS}
207
  for name in SLIDERS:
 
3
  from PIL import Image
4
  import numpy as np
5
  from streamlit_image_comparison import image_comparison
6
+ # from src.envs.new_edit_photo import PhotoEditor
7
  from src.sac.sac_inference import InferenceAgent
8
  import yaml
9
  import os
 
15
  from bokeh.plotting import figure
16
  from bokeh.models import ColumnDataSource
17
  from bokeh.palettes import Spectral3
18
+ from src.envs.edit_photo_opt import PhotoEditor
19
  # Set page config to wide mode
20
  st.set_page_config(layout="wide")
21
 
22
+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ DEVICE = torch.device("cpu")
24
+ MODEL_PATH = "experiments/ResNet_10_sliders__224_128_aug__2024-07-23_21-23-35"
25
  SLIDERS = ['temp','tint','exposure', 'contrast','highlights','shadows', 'whites', 'blacks','vibrance','saturation']
26
  SLIDERS_ORD = ['contrast','exposure','temp','tint','whites','blacks','highlights','shadows','vibrance','saturation']
27
 
 
74
  input_image = image.unsqueeze(0).to(DEVICE)
75
  parameters = [params[param_name]/100.0 for param_name in SLIDERS_ORD]
76
  parameters = torch.tensor(parameters).unsqueeze(0).to(DEVICE)
77
+ if st.session_state.photopro_image is None:
78
+ enhanced_image,photopro_image = photo_editor(input_image,parameters,use_photopro_image=False)
79
+ st.session_state.photopro_image = photopro_image
80
+ else:
81
+ enhanced_image = photo_editor(st.session_state.photopro_image,parameters,use_photopro_image=True)
82
  enhanced_image = enhanced_image.squeeze(0).cpu().detach().numpy()
83
  enhanced_image = np.clip(enhanced_image, 0, 1)
84
  enhanced_image = (enhanced_image*255).astype(np.uint8)
 
139
 
140
  def reset_on_upload():
141
  st.session_state.original_image = None
142
+ st.session_state.photopro_image = None
143
  reset_sliders()
144
 
145
  def create_smooth_histogram(image):
 
208
  st.session_state.enhanced_image = None
209
  if 'original_image' not in st.session_state:
210
  st.session_state.original_image = None
211
+ if 'photopro_image' not in st.session_state:
212
+ st.session_state.photopro_image = None
213
  if 'params' not in st.session_state:
214
  st.session_state.params = {name: 0 for name in SLIDERS}
215
  for name in SLIDERS:
src/envs/edit_photo_opt.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ import cv2
5
+ try:
6
+ from .dehaze.src import dehaze
7
+ except:
8
+ from dehaze.src import dehaze
9
+ import streamlit as st
10
+ # def numpy_sigmoid(x):
11
+ # return 1/(1+np.exp(-x))
12
+
13
+ def sigmoid_inverse(y):
14
+ epsilon = 10**(-3)
15
+ y = F.relu(y-epsilon)+epsilon
16
+ y = 1-epsilon-F.relu((1-epsilon)-y)
17
+ y = (1/y)-1
18
+ output = -torch.log(y)
19
+ return output
20
+ class Sigmoid():
21
+ def __init__(self):
22
+ self.num_parameters = 0
23
+ def __call__(self,images):
24
+ return torch.sigmoid(images)
25
+ class SigmoidInverse():
26
+
27
+ def __init__(self):
28
+ self.num_parameters = 0
29
+
30
+ def __call__(self, images):
31
+ return sigmoid_inverse(images)
32
+
33
+
34
+ new_sig_inv = SigmoidInverse()
35
+
36
+ class AdjustContrast():
37
+ def __init__(self):
38
+ self.num_parameters = 1
39
+ self.window_names = ["parameter"]
40
+ self.slider_names = ["contrast"]
41
+
42
+ def __call__(self, images:torch.Tensor, parameters:torch.Tensor):
43
+
44
+ assert images.dim()==4
45
+ assert images.shape[0]==parameters.shape[0]
46
+
47
+ batch_size = parameters.shape[0]
48
+ mean = images.view(batch_size,-1).mean(1)
49
+ mean = mean.view(batch_size, 1, 1, 1)
50
+ parameters = parameters.view(batch_size, 1, 1, 1)
51
+ editted = (images-mean)*(parameters+1)+mean
52
+ editted = F.relu(editted)
53
+ editted = 1-F.relu(1-editted)
54
+ return editted
55
+
56
+
57
+ class AdjustDehaze():
58
+
59
+ def __init__(self):
60
+ self.num_parameters = 1
61
+ self.window_names = ["parameter"]
62
+ self.slider_names = ["dehaze"]
63
+
64
+ def __call__(self, images, parameters):
65
+ """
66
+ Takes a batch of images where B (the last dim) is the batch size
67
+ args:
68
+ images: torch.Tensor # B H W C
69
+ parameters :torch.Tensor # N
70
+ return:
71
+ output: torch.Tensor # B H W C
72
+ """
73
+ assert images.dim()==4
74
+ batch_size = parameters.shape[0]
75
+ output = []
76
+ for image_index in range(batch_size):
77
+ image = images[image_index].numpy()
78
+ scale = max((image.shape[:2])) / 512.0
79
+ omega = float(parameters[image_index])
80
+ editted= dehaze.DarkPriorChannelDehaze(
81
+ wsize=int(15*scale), radius=int(80*scale), omega=omega,
82
+ t_min=0.25, refine=True)(image * 255.0) / 255.0
83
+ editted = torch.tensor(editted)
84
+ editted = F.relu(editted)
85
+ editted= 1-F.relu(1-editted)
86
+ output.append(editted)
87
+ output = torch.stack(output)
88
+ return output
89
+
90
+ class AdjustClarity():
91
+ def __init__(self):
92
+ self.num_parameters = 1
93
+ self.window_names = ["parameter"]
94
+ self.slider_names = ["clarity"]
95
+
96
+ def __call__(self, images, parameters):
97
+ """
98
+ Takes a batch of images where B (the last dim) is the batch size
99
+ args:
100
+ images: torch.Tensor # B H W C
101
+ parameters :torch.Tensor # N
102
+ return:
103
+ output: torch.Tensor # B H W C
104
+ """
105
+ assert images.dim()==4
106
+ batch_size = parameters.shape[0]
107
+ output = []
108
+ clarity = parameters.view(batch_size, 1, 1, 1)
109
+ for image in images:
110
+ input = image.numpy()
111
+ scale = max((input.shape[:2])) / 512.0
112
+ unsharped = cv2.bilateralFilter((input*255.0).astype(np.uint8),
113
+ int(32*scale), 50, 10*scale)/255.0
114
+ output.append(torch.tensor(unsharped))
115
+ output = torch.stack(output)
116
+ editted_images = images + (images-output) * clarity
117
+
118
+ return editted_images
119
+
120
+ class AdjustExposure():
121
+ def __init__(self):
122
+ self.num_parameters = 1
123
+ self.window_names = ["parameter"]
124
+ self.slider_names = ["exposure"]
125
+
126
+ def __call__(self, images, parameters):
127
+ batch_size = parameters.shape[0]
128
+ exposure = parameters.view(batch_size, 1, 1, 1)
129
+ output = images+exposure*5
130
+ return output
131
+
132
+ class AdjustTemp():
133
+ def __init__(self):
134
+ self.num_parameters = 1
135
+ self.window_names = ["parameter"]
136
+ self.slider_names = ["temp"]
137
+
138
+ def __call__(self, images, parameters):
139
+ batch_size = parameters.shape[0]
140
+ temp = parameters.view(batch_size, 1, 1, 1)
141
+ editted = torch.clone(images)
142
+
143
+ index_high = (temp>0).view(-1)
144
+ index_low = (temp<=0).view(-1)
145
+
146
+ editted[index_high,:,:,1] += temp[index_high,:,:,0]*1.6
147
+ editted[index_high,:,:,2] += temp[index_high,:,:,0]*2
148
+ editted[index_low,:,:,0] -= temp[index_low,:,:,0]*2.0
149
+ editted[index_low,:,:,1] -= temp[index_low,:,:,0]*1.0
150
+
151
+ return editted
152
+ class AdjustTint():
153
+ def __init__(self):
154
+ self.num_parameters = 1
155
+ self.window_names = ["parameter"]
156
+ self.slider_names = ["tint"]
157
+
158
+ def __call__(self, images, parameters):
159
+ batch_size = parameters.shape[0]
160
+ tint = parameters.view(batch_size, 1, 1, 1)
161
+ editted = torch.clone(images)
162
+
163
+ index_high = (tint>0).view(-1)
164
+ index_low = (tint<=0).view(-1)
165
+
166
+ editted[index_high,:,:,0] += tint[index_high,:,:,0]*2
167
+ editted[index_high,:,:,2] += tint[index_high,:,:,0]*1
168
+ editted[index_low,:,:,1] -= tint[index_low,:,:,0]*2
169
+ editted[index_low,:,:,2] -= tint[index_low,:,:,0]*1
170
+
171
+ return editted
172
+ class AdjustShadows:
173
+ def __init__(self):
174
+ self.num_parameters = 1
175
+ self.window_names = ["parameter"]
176
+ self.slider_names = ["shadows"]
177
+
178
+ def __call__(self, list_hsv, parameters):
179
+ batch_size = parameters.shape[0]
180
+ shadows = parameters.view(batch_size, 1, 1)
181
+
182
+ v = list_hsv[2]
183
+
184
+ # Calculate shadows mask
185
+
186
+ shadows_mask = 1 - torch.sigmoid((v - 0.0) * 5.0)
187
+ # Adjust v channel based on shadows mask
188
+ adjusted_v = v * (1 + shadows_mask * shadows * 5.0)
189
+
190
+ return [list_hsv[0], list_hsv[1], adjusted_v]
191
+
192
+ class AdjustHighlights: # I should change the sigmoid to torch.sigmoid
193
+ def __init__(self):
194
+ self.num_parameters = 1
195
+ self.window_names = ["parameter"]
196
+ self.slider_names = ["highlights"]
197
+
198
+ # def custom_sigmoid(self, x):
199
+ # return 1 / (1 + torch.exp(-x))
200
+
201
+ def __call__(self, list_hsv, parameters):
202
+ batch_size = parameters.shape[0]
203
+ highlights = parameters.view(batch_size, 1, 1)
204
+
205
+ v = list_hsv[2]
206
+
207
+ # Calculate highlights mask using custom sigmoid function
208
+ highlights_mask = torch.sigmoid((v - 1) * 5)
209
+
210
+ # Adjust v channel based on highlights mask
211
+ adjusted_v = 1 - (1 - v) * (1 - highlights_mask * highlights * 5)
212
+
213
+ return [list_hsv[0], list_hsv[1], adjusted_v]
214
+
215
+
216
+ class AdjustBlacks:
217
+ def __init__(self):
218
+ self.num_parameters = 1
219
+ self.window_names = ["parameter"]
220
+ self.slider_names = ["blacks"]
221
+
222
+ def __call__(self, list_hsv, parameters):
223
+ batch_size = parameters.shape[0]
224
+ blacks = parameters.view(batch_size, 1, 1)
225
+ blacks = blacks + 1
226
+ v = list_hsv[2]
227
+
228
+ # Calculate the adjustment factor
229
+ adjustment_factor = (torch.sqrt(blacks) - 1) * 0.2
230
+
231
+ # Adjust the v channel
232
+ adjusted_v = v + (1 - v) * adjustment_factor
233
+
234
+ return [list_hsv[0], list_hsv[1], adjusted_v]
235
+
236
+ class AdjustWhites:
237
+ def __init__(self):
238
+ self.num_parameters = 1
239
+ self.window_names = ["parameter"]
240
+ self.slider_names = ["whites"]
241
+
242
+ def __call__(self, list_hsv, parameters):
243
+ batch_size = parameters.shape[0]
244
+ whites= parameters.view(batch_size, 1, 1)
245
+ whites= whites+ 1
246
+ v = list_hsv[2]
247
+
248
+ # Calculate the adjustment factor
249
+ adjustment_factor = (torch.sqrt(whites) - 1) * 0.2
250
+
251
+ # Adjust the v channel
252
+ adjusted_v = v + v * adjustment_factor
253
+
254
+ return [list_hsv[0], list_hsv[1], adjusted_v]
255
+
256
+ class Bgr2Hsv:
257
+ def __init__(self):
258
+ self.num_parameters = 0
259
+
260
+ def __call__(self, images):
261
+ editted = images
262
+
263
+ max_bgr, _ = editted.max(dim=-1, keepdim=True)
264
+ min_bgr, _ = editted.min(dim=-1, keepdim=True)
265
+
266
+ b = editted[..., 0]
267
+ g = editted[..., 1]
268
+ r = editted[..., 2]
269
+
270
+ b_g = b - g
271
+ g_r = g - r
272
+ r_b = r - b
273
+
274
+ b_min_flg = (1 - F.relu(torch.sign(b_g))) * F.relu(torch.sign(r_b))
275
+ g_min_flg = (1 - F.relu(torch.sign(g_r))) * F.relu(torch.sign(b_g))
276
+ r_min_flg = (1 - F.relu(torch.sign(r_b))) * F.relu(torch.sign(g_r))
277
+
278
+ epsilon = 10**(-5)
279
+ h1 = 60 * g_r / (max_bgr.squeeze() - min_bgr.squeeze() + epsilon) + 60
280
+ h2 = 60 * b_g / (max_bgr.squeeze() - min_bgr.squeeze() + epsilon) + 180
281
+ h3 = 60 * r_b / (max_bgr.squeeze() - min_bgr.squeeze() + epsilon) + 300
282
+ h = h1 * b_min_flg + h2 * r_min_flg + h3 * g_min_flg
283
+
284
+ v = max_bgr.squeeze()
285
+ s = (max_bgr.squeeze() - min_bgr.squeeze()) / (max_bgr.squeeze() + epsilon)
286
+
287
+ return [h, s, v]
288
+
289
+ class AdjustVibrance:
290
+ def __init__(self):
291
+ self.num_parameters = 1
292
+ self.window_names = ["parameter"]
293
+ self.slider_names = ["vibrance"]
294
+
295
+ def __call__(self, list_hsv, parameters):
296
+ batch_size = parameters.shape[0]
297
+ vibrance= parameters.view(batch_size, 1, 1)
298
+ vibrance = vibrance + 1
299
+ s = list_hsv[1]
300
+
301
+ # Calculate vibrance flag using custom sigmoid function
302
+ vibrance_flg = -torch.sigmoid((s - 0.5) * 10) + 1
303
+
304
+ # Adjust the s channel
305
+ adjusted_s = s * vibrance * vibrance_flg + s * (1 - vibrance_flg)
306
+
307
+ return [list_hsv[0], adjusted_s, list_hsv[2]]
308
+
309
+ class AdjustSaturation:
310
+ def __init__(self):
311
+ self.num_parameters = 1
312
+ self.window_names = ["parameter"]
313
+ self.slider_names = ["saturation"]
314
+
315
+ def __call__(self, list_hsv, parameters):
316
+ batch_size = parameters.shape[0]
317
+ saturation = parameters.view(batch_size, 1, 1)
318
+ saturation = saturation+ 1
319
+ s = list_hsv[1]
320
+
321
+ # Adjust the saturation
322
+ s_ = s * saturation
323
+ s_ = F.relu(s_)
324
+ s_ = 1 - F.relu(1 - s_)
325
+
326
+ return [list_hsv[0], s_, list_hsv[2]]
327
+
328
+ class Hsv2Bgr:
329
+ def __init__(self):
330
+ self.num_parameters = 0
331
+
332
+ def __call__(self, list_hsv):
333
+ h, s, v = list_hsv
334
+
335
+ # Adjust h values
336
+ h = h * torch.relu(torch.sign(h-0)) * (1 - torch.relu(torch.sign(h-360))) + \
337
+ (h-360) * torch.relu(torch.sign(h-360)) * (1 - torch.relu(torch.sign(h-720))) + \
338
+ (h+360) * torch.relu(torch.sign(h+360)) * (1 - torch.relu(torch.sign(h-0)))
339
+
340
+ # Calculate h flags
341
+ h60_flg = torch.relu(torch.sign(h-0)) * (1 - torch.relu(torch.sign(h-60)))
342
+ h120_flg = torch.relu(torch.sign(h-60)) * (1 - torch.relu(torch.sign(h-120)))
343
+ h180_flg = torch.relu(torch.sign(h-120)) * (1 - torch.relu(torch.sign(h-180)))
344
+ h240_flg = torch.relu(torch.sign(h-180)) * (1 - torch.relu(torch.sign(h-240)))
345
+ h300_flg = torch.relu(torch.sign(h-240)) * (1 - torch.relu(torch.sign(h-300)))
346
+ h360_flg = torch.relu(torch.sign(h-300)) * (1 - torch.relu(torch.sign(h-360)))
347
+
348
+ C = v * s
349
+ b = v - C + C * (h240_flg + h300_flg) + C * ((h / 60 - 2) * h180_flg + (6 - h / 60) * h360_flg)
350
+ g = v - C + C * (h120_flg + h180_flg) + C * ((h / 60) * h60_flg + (4 - h / 60) * h240_flg)
351
+ r = v - C + C * (h60_flg + h360_flg) + C * ((h / 60 - 4) * h300_flg + (2 - h / 60) * h120_flg)
352
+
353
+ # Add an extra dimension to b, g, r to concatenate them correctly
354
+ b = b.unsqueeze(-1)
355
+ g = g.unsqueeze(-1)
356
+ r = r.unsqueeze(-1)
357
+
358
+ bgr = torch.cat([b, g, r], dim=-1)
359
+
360
+ return bgr
361
+
362
+ # class Srgb2Photopro:
363
+ # def __init__(self):
364
+ # self.num_parameters = 0
365
+
366
+ # def __call__(self, images):
367
+ # srgb = images.clone()
368
+ # k = 0.055
369
+ # thre_srgb = 0.04045
370
+
371
+ # a = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
372
+ # [0.2126729, 0.7151522, 0.0721750],
373
+ # [0.0193339, 0.1191920, 0.9503041]], dtype=torch.float32)
374
+ # b = torch.tensor([[1.3459433, -0.2556075, -0.0511118],
375
+ # [-0.5445989, 1.5081673, 0.0205351],
376
+ # [0.0000000, 0.0000000, 1.2118128]], dtype=torch.float32)
377
+
378
+ # M = torch.matmul(b, a)
379
+ # M = M / M.sum(dim=1, keepdim=True)
380
+
381
+ # thre_photopro = 1 / 512.0
382
+
383
+ # # sRGB to linear RGB
384
+ # srgb = torch.where(srgb <= thre_srgb, srgb / 12.92, ((srgb + k) / (1 + k)) ** 2.4)
385
+
386
+ # sb = srgb[..., 0:1]
387
+ # sg = srgb[..., 1:2]
388
+ # sr = srgb[..., 2:3]
389
+
390
+ # photopror = sr * M[0][0] + sg * M[0][1] + sb * M[0][2]
391
+ # photoprog = sr * M[1][0] + sg * M[1][1] + sb * M[1][2]
392
+ # photoprob = sr * M[2][0] + sg * M[2][1] + sb * M[2][2]
393
+
394
+ # photopro = torch.cat((photoprob, photoprog, photopror), dim=-1)
395
+ # photopro = torch.clamp(photopro, 0, 1)
396
+ # photopro = torch.where(photopro >= thre_photopro, photopro ** (1 / 1.8), photopro * 16)
397
+
398
+ # return photopro
399
+
400
+ class Srgb2Photopro:
401
+ def __init__(self):
402
+ self.num_parameters = 0
403
+ k = 0.055
404
+ thre_srgb = 0.04045
405
+
406
+ self.k = k
407
+ self.thre_srgb = thre_srgb
408
+ self.thre_photopro = 1 / 512.0
409
+
410
+ # Transformation matrices
411
+ a = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
412
+ [0.2126729, 0.7151522, 0.0721750],
413
+ [0.0193339, 0.1191920, 0.9503041]], dtype=torch.float32)
414
+ b = torch.tensor([[1.3459433, -0.2556075, -0.0511118],
415
+ [-0.5445989, 1.5081673, 0.0205351],
416
+ [0.0000000, 0.0000000, 1.2118128]], dtype=torch.float32)
417
+
418
+ self.M = torch.matmul(b, a)
419
+ self.M = self.M / self.M.sum(dim=1, keepdim=True)
420
+ def __call__(self, images):
421
+ srgb = images.clone()
422
+
423
+ with torch.no_grad(): # Disable gradient computation for inference
424
+ # sRGB to linear RGB
425
+ srgb = torch.where(srgb <= self.thre_srgb, srgb / 12.92, ((srgb + self.k) / (1 + self.k)) ** 2.4)
426
+
427
+ sb = srgb[..., 0:1]
428
+ sg = srgb[..., 1:2]
429
+ sr = srgb[..., 2:3]
430
+
431
+ # Apply the transformation matrix
432
+ photopror = sr * self.M[0][0] + sg * self.M[0][1] + sb * self.M[0][2]
433
+ photoprog = sr * self.M[1][0] + sg * self.M[1][1] + sb * self.M[1][2]
434
+ photoprob = sr * self.M[2][0] + sg * self.M[2][1] + sb * self.M[2][2]
435
+
436
+ photopro = torch.cat((photoprob, photoprog, photopror), dim=-1)
437
+ photopro = torch.clamp(photopro, 0, 1)
438
+
439
+ # Apply the Photopro gamma correction
440
+ photopro = torch.where(photopro >= self.thre_photopro, photopro ** (1 / 1.8), photopro * 16)
441
+
442
+ # Clear intermediate tensors
443
+
444
+ return photopro
445
+
446
+ # class Photopro2Srgb:
447
+ # def __init__(self):
448
+ # self.num_parameters = 0
449
+
450
+ # def __call__(self, photopro_tensor):
451
+ # photopro = photopro_tensor.clone() # Make a copy to avoid modifying the input tensor
452
+ # thre_photopro = 1/512.0*16
453
+
454
+ # a = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
455
+ # [0.2126729, 0.7151522, 0.0721750],
456
+ # [0.0193339, 0.1191920, 0.9503041]], dtype=torch.float32)
457
+ # b = torch.tensor([[1.3459433, -0.2556075, -0.0511118],
458
+ # [-0.5445989, 1.5081673, 0.0205351],
459
+ # [0.0000000, 0.0000000, 1.2118128]], dtype=torch.float32)
460
+ # M = torch.matmul(b, a)
461
+ # M = M / M.sum(dim=1, keepdim=True)
462
+ # M = torch.linalg.inv(M)
463
+ # k = 0.055
464
+ # thre_srgb = 0.04045 / 12.92
465
+
466
+ # # Apply transformations
467
+ # mask = photopro < thre_photopro
468
+ # photopro[mask] *= 1.0 / 16
469
+ # photopro[~mask] = photopro[~mask] ** 1.8
470
+
471
+ # photoprob = photopro[:, :, :, 0:1]
472
+ # photoprog = photopro[:, :, :, 1:2]
473
+ # photopror = photopro[:, :, :, 2:3]
474
+
475
+ # sr = photopror * M[0, 0] + photoprog * M[0, 1] + photoprob * M[0, 2]
476
+ # sg = photopror * M[1, 0] + photoprog * M[1, 1] + photoprob * M[1, 2]
477
+ # sb = photopror * M[2, 0] + photoprog * M[2, 1] + photoprob * M[2, 2]
478
+
479
+ # srgb = torch.cat((sb, sg, sr), dim=-1)
480
+
481
+ # # Clip and apply final transformations
482
+ # srgb = torch.clamp(srgb, 0, 1)
483
+ # mask = srgb > thre_srgb
484
+ # srgb[mask] = (1 + k) * srgb[mask] ** (1 / 2.4) - k
485
+ # srgb[~mask] *= 12.92
486
+
487
+ # return srgb
488
+
489
+ class Photopro2Srgb:
490
+ def __init__(self):
491
+ self.num_parameters = 0
492
+ self.k = 0.055
493
+ self.thre_srgb = 0.04045 / 12.92
494
+ self.thre_photopro = 1 / 512.0 * 16
495
+
496
+ # Transformation matrices
497
+ a = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
498
+ [0.2126729, 0.7151522, 0.0721750],
499
+ [0.0193339, 0.1191920, 0.9503041]], dtype=torch.float32)
500
+ b = torch.tensor([[1.3459433, -0.2556075, -0.0511118],
501
+ [-0.5445989, 1.5081673, 0.0205351],
502
+ [0.0000000, 0.0000000, 1.2118128]], dtype=torch.float32)
503
+
504
+ self.M = torch.matmul(b, a)
505
+ self.M = self.M / self.M.sum(dim=1, keepdim=True)
506
+ self.M_inv = torch.linalg.inv(self.M)
507
+
508
+ def __call__(self, photopro_tensor):
509
+ with torch.no_grad(): # Disable gradient computation for inference
510
+ photopro = photopro_tensor.clone() # Make a copy to avoid modifying the input tensor
511
+ # photopro = photopro.to(torch.float16)
512
+ # Apply gamma correction
513
+ mask = photopro < self.thre_photopro
514
+ photopro[mask] *= 1.0 / 16
515
+ photopro[~mask] = photopro[~mask] ** 1.8
516
+
517
+ # Separate channels
518
+ photoprob = photopro[..., 0:1]
519
+ photoprog = photopro[..., 1:2]
520
+ photopror = photopro[..., 2:3]
521
+
522
+ # Apply the inverse transformation matrix
523
+ sr = photopror * self.M_inv[0, 0] + photoprog * self.M_inv[0, 1] + photoprob * self.M_inv[0, 2]
524
+ sg = photopror * self.M_inv[1, 0] + photoprog * self.M_inv[1, 1] + photoprob * self.M_inv[1, 2]
525
+ sb = photopror * self.M_inv[2, 0] + photoprog * self.M_inv[2, 1] + photoprob * self.M_inv[2, 2]
526
+ del photopror, photoprog, photoprob
527
+ srgb = torch.cat((sb, sg, sr), dim=-1)
528
+ del sr, sg, sb
529
+ # Apply sRGB transformation
530
+ srgb = torch.clamp(srgb, 0, 1)
531
+ mask = srgb > self.thre_srgb
532
+ srgb[mask] = (1 + self.k) * srgb[mask] ** (1 / 2.4) - self.k
533
+ srgb[~mask] *= 12.92
534
+
535
+ # Clear intermediate tensors
536
+ return srgb
537
+
538
+ class PhotoEditor():
539
+ def __init__(self,sliders= 'all'):
540
+ self.edit_funcs = [Srgb2Photopro(), AdjustDehaze(), AdjustClarity(), AdjustContrast(),
541
+ SigmoidInverse(), AdjustExposure(), AdjustTemp(), AdjustTint(),
542
+ Sigmoid(), Bgr2Hsv(), AdjustWhites(), AdjustBlacks(), AdjustHighlights(),
543
+ AdjustShadows(), AdjustVibrance(), AdjustSaturation(), Hsv2Bgr(), Photopro2Srgb()]
544
+ self.sliders = sliders
545
+ self.num_parameters = 0
546
+ if sliders=='all':
547
+ for edit_func in self.edit_funcs:
548
+ self.num_parameters += edit_func.num_parameters
549
+ else:
550
+ for edit_func in self.edit_funcs:
551
+ if edit_func.num_parameters==0:
552
+ self.num_parameters += edit_func.num_parameters
553
+ elif edit_func.slider_names[0] in sliders:
554
+ self.num_parameters += edit_func.num_parameters
555
+
556
+ def __call__(self, images, parameters,use_photopro_image=False):
557
+ editted_images = images.clone()
558
+ num_parameters = 0
559
+ photopro_image = None
560
+ assert images.shape[-1]==3 #make sure that the image shape is (B,H,W,C)
561
+ assert images.dim()==4 #make sure that the image is batched
562
+ for edit_func in self.edit_funcs:
563
+ if use_photopro_image and type(edit_func)==Srgb2Photopro:
564
+ continue
565
+ if self.sliders=='all':
566
+
567
+ if edit_func.num_parameters == 0:
568
+ editted_images = edit_func(editted_images)
569
+ else:
570
+ editted_images = edit_func(editted_images,
571
+ parameters[:,num_parameters : num_parameters + edit_func.num_parameters])
572
+ num_parameters = num_parameters + edit_func.num_parameters
573
+
574
+ else:
575
+
576
+ if edit_func.num_parameters == 0:
577
+ editted_images = edit_func(editted_images)
578
+ else:
579
+ if edit_func.slider_names[0] in self.sliders:
580
+ editted_images = edit_func(editted_images,
581
+ parameters[:,num_parameters : num_parameters + edit_func.num_parameters])
582
+ num_parameters = num_parameters + edit_func.num_parameters
583
+ if type(edit_func)==Srgb2Photopro and use_photopro_image==False:
584
+ photopro_image = editted_images
585
+
586
+ editted_images = editted_images.type(torch.float32)
587
+
588
+ if use_photopro_image:
589
+ return editted_images
590
+ else:
591
+ return editted_images, photopro_image