xinyiW915 committed on
Commit 17f753b · verified · 1 Parent(s): eec997d

Upload 5 files

Files changed (5)
  1. app.py +87 -0
  2. demo_test_gpu.py +249 -0
  3. model_regression.py +693 -0
  4. relax_vqa.py +159 -0
  5. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,87 @@
1
+ from spaces import GPU
2
+ import gradio as gr
3
+ import torch
4
+ import os
5
+ import time
6
+ from torchvision import models
7
+ from joblib import load
8
+ from extractor.visualise_vit_layer import VitGenerator
9
+ from relax_vqa import get_deep_feature, process_video_feature, process_patches, get_frame_patches, flow_to_rgb, merge_fragments, concatenate_features
10
+ from extractor.vf_extract import process_video_residual
11
+ from model_regression import Mlp, preprocess_data
12
+ from demo_test_gpu import evaluate_video_quality, load_model
13
+
14
+
15
+ @GPU
16
+ def run_relax_vqa(video_path, is_finetune, framerate, video_type):
17
+ if not os.path.exists(video_path):
18
+ return "❌ No video uploaded or the uploaded file has expired. Please upload again."
19
+
20
+ config = {
21
+ 'is_finetune': is_finetune,
22
+ 'framerate': framerate,
23
+ 'video_type': video_type,
24
+ 'save_path': 'model/',
25
+ 'train_data_name': 'lsvq_train',
26
+ 'select_criteria': 'byrmse',
27
+ 'video_path': video_path,
28
+ 'video_name': os.path.splitext(os.path.basename(video_path))[0]
29
+ }
30
+
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ resnet50 = models.resnet50(pretrained=True).to(device)
33
+ vit = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=False)
34
+ model_mlp = load_model(config, device)
35
+
36
+ try:
37
+ score, runtime = evaluate_video_quality(config, resnet50, vit, model_mlp, device)
38
+ return f"Predicted Quality Score: {score:.4f} (in {runtime:.2f}s)"
39
+ except Exception as e:
40
+ return f"❌ Error: {str(e)}"
41
+ finally:
42
+ if "gradio" in video_path and os.path.exists(video_path):
43
+ os.remove(video_path)
44
+
45
+
46
+ def toggle_dataset_visibility(is_finetune):
47
+ return gr.update(visible=is_finetune)
48
+
49
+
50
+ with gr.Blocks() as demo:
51
+ gr.Markdown("## 🎬 ReLaX-VQA Online Demo")
52
+ gr.Markdown(
53
+ "Upload a short video and get the predicted perceptual quality score using the ReLaX-VQA model. "
54
+ "You can try our demo video from the "
55
+ "<a href='https://huggingface.co/spaces/xinyiW915/ReLaX-VQA/blob/main/ugc_original_videos/5636101558_540p.mp4' target='_blank'>demo video</a> "
56
+ "(fps = 24, dataset = konvid_1k).<br><br>"
57
+ "⚙️ This demo is currently running on <strong>Hugging Face ZeroGPU Space</strong>: Dynamic resources (NVIDIA A100)."
58
+ )
59
+
60
+ with gr.Row():
61
+ with gr.Column(scale=2):
62
+ video_input = gr.Video(label="Upload a Video (e.g. mp4)")
63
+ framerate_slider = gr.Slider(label="Source Video Framerate (fps)", minimum=1, maximum=60, step=1, value=24)
64
+ is_finetune_checkbox = gr.Checkbox(label="Use Finetuning?", value=False)
65
+ dataset_dropdown = gr.Dropdown(
66
+ label="Source Video Dataset for Finetuning",
67
+ choices=["konvid_1k", "youtube_ugc", "live_vqc", "cvd_2014"],
68
+ value="konvid_1k",
69
+ visible=False
70
+ )
71
+ run_button = gr.Button("Run Prediction")
72
+ with gr.Column(scale=1):
73
+ output_box = gr.Textbox(label="Predicted Quality Score", lines=5)
74
+
75
+ is_finetune_checkbox.change(
76
+ fn=toggle_dataset_visibility,
77
+ inputs=is_finetune_checkbox,
78
+ outputs=dataset_dropdown
79
+ )
80
+
81
+ run_button.click(
82
+ fn=run_relax_vqa,
83
+ inputs=[video_input, is_finetune_checkbox, framerate_slider, dataset_dropdown],
84
+ outputs=output_box
85
+ )
86
+
87
+ demo.launch()
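
Note on the @GPU decorator used above: on ZeroGPU Spaces, CUDA is only available inside the function decorated with spaces.GPU, and each call runs within a limited time slot. For longer videos that exceed the default allocation, the decorator also accepts a duration hint; a minimal sketch of that pattern (the 120-second value is illustrative and not part of this commit):

from spaces import GPU

@GPU(duration=120)  # illustrative: request a longer GPU slot for high-frame-count videos
def run_relax_vqa(video_path, is_finetune, framerate, video_type):
    ...  # same body as in app.py above
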
demo_test_gpu.py ADDED
@@ -0,0 +1,249 @@
1
+ import argparse
2
+ import time
3
+ import math
4
+ import os
5
+ import shutil
6
+ from joblib import load
7
+ import cv2
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from thop import profile
12
+ from torchvision import models, transforms
13
+
14
+ from extractor.visualise_vit_layer import VitGenerator
15
+ from relax_vqa import get_deep_feature, process_video_feature, process_patches, get_frame_patches, flow_to_rgb, merge_fragments, concatenate_features
16
+ from extractor.vf_extract import process_video_residual
17
+ from model_regression import Mlp  # the tensor-based preprocess_data defined below replaces the training-time version
18
+
19
+ def fix_state_dict(state_dict):
20
+ new_state_dict = {}
21
+ for k, v in state_dict.items():
22
+ if k.startswith('module.'):
23
+ name = k[7:]
24
+ elif k == 'n_averaged':
25
+ continue
26
+ else:
27
+ name = k
28
+ new_state_dict[name] = v
29
+ return new_state_dict
30
+
31
+ def preprocess_data(X, y=None, imp=None, scaler=None):
32
+ if not isinstance(X, torch.Tensor):
33
+ X = torch.tensor(X, device='cuda' if torch.cuda.is_available() else 'cpu')
34
+ X = torch.where(torch.isnan(X) | torch.isinf(X), torch.tensor(0.0, device=X.device), X)
35
+
36
+ if imp is not None or scaler is not None:
37
+ X_np = X.cpu().numpy()
38
+ if imp is not None:
39
+ X_np = imp.transform(X_np)
40
+ if scaler is not None:
41
+ X_np = scaler.transform(X_np)
42
+ X = torch.from_numpy(X_np).to(X.device)
43
+
44
+ if y is not None and y.size > 0:
45
+ if not isinstance(y, torch.Tensor):
46
+ y = torch.tensor(y, device=X.device)
47
+ y = y.reshape(-1).squeeze()
48
+ else:
49
+ y = None
50
+
51
+ return X, y, imp, scaler
52
+
53
+ def load_model(config, device, input_features=35203):
54
+ network_name = 'relaxvqa'
55
+ # input_features = X_test_processed.shape[1]
56
+ model = Mlp(input_features=input_features, out_features=1, drop_rate=0.2, act_layer=nn.GELU).to(device)
57
+ if config['is_finetune']:
58
+ model_path = os.path.join(config['save_path'], f"fine_tune_model/{config['video_type']}_{network_name}_{config['select_criteria']}_fine_tuned_model.pth")
59
+ else:
60
+ model_path = os.path.join(config['save_path'], f"{config['train_data_name']}_{network_name}_{config['select_criteria']}_trained_median_model_param_onLSVQ_TEST.pth")
61
+ print("Loading model from:", model_path)
62
+ state_dict = torch.load(model_path, map_location=device)
63
+ fixed_state_dict = fix_state_dict(state_dict)
64
+ try:
65
+ model.load_state_dict(fixed_state_dict)
66
+ except RuntimeError as e:
67
+ print(e)
68
+ return model
69
+
70
+ def evaluate_video_quality(config, resnet50, vit, model_mlp, device):
71
+ is_finetune = config['is_finetune']
72
+ save_path = config['save_path']
73
+ video_type = config['video_type']
74
+ video_name = config['video_name']
75
+ framerate = config['framerate']
76
+ sampled_fragment_path = os.path.join("../video_sampled_frame/sampled_frame/", "test_sampled_fragment")
77
+
78
+ video_path = config.get("video_path")
79
+ if video_path is None:
80
+ if video_type == 'youtube_ugc':
81
+ video_path = f'./ugc_original_videos/{video_name}.mkv'
82
+ else:
83
+ video_path = f'./ugc_original_videos/{video_name}.mp4'
84
+ target_size = 224
85
+ patch_size = 16
86
+ top_n = int((target_size / patch_size) * (target_size / patch_size))
87
+
88
+ # sampled video frames
89
+ start_time = time.time()
90
+ frames, frames_next = process_video_residual(video_type, video_name, framerate, video_path, sampled_fragment_path)
91
+
92
+ # get ResNet50 layer-stack features and ViT pooling features
93
+ all_frame_activations_resnet = []
94
+ all_frame_activations_vit = []
95
+ # get fragments ResNet50 features and ViT features
96
+ all_frame_activations_sampled_resnet = []
97
+ all_frame_activations_merged_resnet = []
98
+ all_frame_activations_sampled_vit = []
99
+ all_frame_activations_merged_vit = []
100
+
101
+ batch_size = 64 # Define the number of frames to process in parallel
102
+ for i in range(0, len(frames_next), batch_size):
103
+ batch_frames = frames[i:i + batch_size]
104
+ batch_rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in batch_frames]
105
+ batch_frames_next = frames_next[i:i + batch_size]
106
+ batch_tensors = torch.stack([transforms.ToTensor()(frame) for frame in batch_frames]).to(device)
107
+ batch_rgb_tensors = torch.stack([transforms.ToTensor()(frame_rgb) for frame_rgb in batch_rgb_frames]).to(device)
108
+ batch_tensors_next = torch.stack([transforms.ToTensor()(frame_next) for frame_next in batch_frames_next]).to(device)
109
+
110
+ # compute residuals
111
+ residuals = torch.abs(batch_tensors_next - batch_tensors)
112
+
113
+ # calculate optical flows
114
+ batch_gray_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in batch_frames]
115
+ batch_gray_frames_next = [cv2.cvtColor(frame_next, cv2.COLOR_BGR2GRAY) for frame_next in batch_frames_next]
116
+ batch_gray_frames = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames]
117
+ batch_gray_frames_next = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames_next]
118
+ flows = [cv2.calcOpticalFlowFarneback(batch_gray_frames[j], batch_gray_frames_next[j], None, 0.5, 3, 15, 3, 5, 1.2, 0) for j in range(len(batch_gray_frames))]
119
+
120
+ for j in range(batch_tensors.size(0)):
121
+ '''sampled video frames'''
122
+ frame_tensor = batch_tensors[j].unsqueeze(0)
123
+ frame_rgb_tensor = batch_rgb_tensors[j].unsqueeze(0)
124
+ # frame_next_tensor = batch_tensors_next[j].unsqueeze(0)
125
+ frame_number = i + j + 1
126
+
127
+ # ResNet50 layer-stack features
128
+ activations_dict_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_rgb_tensor, frame_number, resnet50, device, 'layerstack')
129
+ all_frame_activations_resnet.append(activations_dict_resnet)
130
+ # ViT pooling features
131
+ activations_dict_vit, _, _ = get_deep_feature('vit', video_name, frame_rgb_tensor, frame_number, vit, device, 'pool')
132
+ all_frame_activations_vit.append(activations_dict_vit)
133
+
134
+
135
+ '''residual video frames'''
136
+ residual = residuals[j].unsqueeze(0)
137
+ flow = flows[j]
138
+ original_path = os.path.join(sampled_fragment_path, f'{video_name}_{frame_number}.png')
139
+
140
+ # Frame Differencing
141
+ residual_frag_path, diff_frag, positions = process_patches(original_path, 'frame_diff', residual, patch_size, target_size, top_n)
142
+ # Frame fragment
143
+ frame_patches = get_frame_patches(frame_tensor, positions, patch_size, target_size)
144
+ # Optical Flow
145
+ opticalflow_rgb = flow_to_rgb(flow)
146
+ opticalflow_rgb_tensor = transforms.ToTensor()(opticalflow_rgb).unsqueeze(0).to(device)
147
+ opticalflow_frag_path, flow_frag, _ = process_patches(original_path, 'optical_flow', opticalflow_rgb_tensor, patch_size, target_size, top_n)
148
+
149
+ merged_frag = merge_fragments(diff_frag, flow_frag)
150
+
151
+ # fragments ResNet50 features
152
+ sampled_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_patches, frame_number, resnet50, device, 'layerstack')
153
+ merged_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, merged_frag, frame_number, resnet50, device, 'pool')
154
+ all_frame_activations_sampled_resnet.append(sampled_frag_activations_resnet)
155
+ all_frame_activations_merged_resnet.append(merged_frag_activations_resnet)
156
+ # fragments ViT features
157
+ sampled_frag_activations_vit,_, _ = get_deep_feature('vit', video_name, frame_patches, frame_number, vit, device, 'pool')
158
+ merged_frag_activations_vit, _, _ = get_deep_feature('vit', video_name, merged_frag, frame_number, vit, device, 'pool')
159
+ all_frame_activations_sampled_vit.append(sampled_frag_activations_vit)
160
+ all_frame_activations_merged_vit.append(merged_frag_activations_vit)
161
+
162
+ print(f'video frame number: {len(all_frame_activations_resnet)}')
163
+ averaged_frames_resnet = process_video_feature(all_frame_activations_resnet, 'resnet50', 'layerstack')
164
+ averaged_frames_vit = process_video_feature(all_frame_activations_vit, 'vit', 'pool')
165
+ # print("ResNet50 layer-stacking feature shape:", averaged_frames_resnet.shape)
166
+ # print("ViT pooling feature shape:", averaged_frames_vit.shape)
167
+ averaged_frames_sampled_resnet = process_video_feature(all_frame_activations_sampled_resnet, 'resnet50', 'layerstack')
168
+ averaged_frames_merged_resnet = process_video_feature(all_frame_activations_merged_resnet, 'resnet50', 'pool')
169
+ averaged_combined_feature_resnet = concatenate_features(averaged_frames_sampled_resnet, averaged_frames_merged_resnet)
170
+ # print("Sampled fragments ResNet50 features shape:", averaged_frames_sampled_resnet.shape)
171
+ # print("Merged fragments ResNet50 features shape:", averaged_frames_merged_resnet.shape)
172
+ averaged_frames_sampled_vit = process_video_feature(all_frame_activations_sampled_vit, 'vit', 'pool')
173
+ averaged_frames_merged_vit = process_video_feature(all_frame_activations_merged_vit, 'vit', 'pool')
174
+ averaged_combined_feature_vit = concatenate_features(averaged_frames_sampled_vit, averaged_frames_merged_vit)
175
+ # print("Sampled fragments ViT features shape:", averaged_frames_sampled_vit.shape)
176
+ # print("Merged fragments ResNet50 features shape:", averaged_frames_merged_vit.shape)
177
+
178
+ # remove tmp folders
179
+ shutil.rmtree(sampled_fragment_path)
180
+
181
+ # concatenate features
182
+ combined_features = torch.cat([torch.mean(averaged_frames_resnet, dim=0), torch.mean(averaged_frames_vit, dim=0),
183
+ torch.mean(averaged_combined_feature_resnet, dim=0), torch.mean(averaged_combined_feature_vit, dim=0)], dim=0).view(1, -1)
184
+ imputer = load(f'{save_path}/scaler/{video_type}_imputer.pkl')
185
+ scaler = load(f'{save_path}/scaler/{video_type}_scaler.pkl')
186
+ X_test_processed, _, _, _ = preprocess_data(combined_features, None, imp=imputer, scaler=scaler)
187
+ feature_tensor = X_test_processed
188
+
189
+ # evaluation for test video
190
+ model_mlp.eval()
191
+
192
+ with torch.no_grad():
193
+ with torch.cuda.amp.autocast():
194
+ prediction = model_mlp(feature_tensor)
195
+ predicted_score = prediction.item()
196
+ # print(f"Raw Predicted Quality Score: {predicted_score}")
197
+ run_time = time.time() - start_time
198
+
199
+ if not is_finetune:
200
+ if video_type in ['konvid_1k', 'youtube_ugc']:
201
+ scaled_prediction = ((predicted_score - 1) / (99 / 4)) + 1.0
202
+ # print(f"Scaled Predicted Quality Score (1-5): {scaled_prediction}")
203
+ return scaled_prediction, run_time
204
+ else:
205
+ scaled_prediction = predicted_score
206
+ return scaled_prediction, run_time
207
+ else:
208
+ return predicted_score, run_time
209
+
210
+ def parse_arguments():
211
+ parser = argparse.ArgumentParser()
212
+ parser.add_argument('-device', type=str, default='gpu', help='cpu or gpu')
213
+ parser.add_argument('-model_name', type=str, default='Mlp', help='Name of the regression model')
214
+ parser.add_argument('-select_criteria', type=str, default='byrmse', help='Selection criteria')
215
+ parser.add_argument('-train_data_name', type=str, default='lsvq_train', help='Name of the training data')
216
+ parser.add_argument('-is_finetune', action='store_true', help='With or without finetune')  # boolean flag; type=bool would treat any non-empty string as True
217
+ parser.add_argument('-save_path', type=str, default='model/', help='Path to save models')
218
+ parser.add_argument('-video_type', type=str, default='konvid_1k', help='Type of video')
219
+ parser.add_argument('-video_name', type=str, default='5636101558_540p', help='Name of the video')
220
+ parser.add_argument('-framerate', type=float, default=24, help='Frame rate of the video')
221
+
222
+ args = parser.parse_args()
223
+ return args
224
+
225
+ if __name__ == '__main__':
226
+ args = parse_arguments()
227
+ config = vars(args)
228
+ if config['device'] == "gpu":
229
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
230
+ else:
231
+ device = torch.device("cpu")
232
+ print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
233
+
234
+ # load models to device
235
+ resnet50 = models.resnet50(pretrained=True).to(device)
236
+ vit = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=True)
237
+ model_mlp = load_model(config, device)
238
+
239
+ total_time = 0
240
+ num_runs = 1
241
+ for i in range(num_runs):
242
+ quality_prediction, run_time = evaluate_video_quality(config, resnet50, vit, model_mlp, device)
243
+ print(f"Run {i + 1} - Time taken: {run_time:.4f} seconds")
244
+
245
+ total_time += run_time
246
+ average_time = total_time / num_runs
247
+
248
+ print(f"Average running time over {num_runs} runs: {average_time:.4f} seconds")
249
+ print("Predicted Quality Score:", quality_prediction)
model_regression.py ADDED
@@ -0,0 +1,693 @@
1
+ import logging
2
+ import time
3
+ import os
4
+ import pandas as pd
5
+ import numpy as np
6
+ import math
7
+ import scipy.io
8
+ import scipy.stats
9
+ from sklearn.impute import SimpleImputer
10
+ from sklearn.preprocessing import StandardScaler, MinMaxScaler
11
+ from sklearn.metrics import mean_squared_error
12
+ from scipy.optimize import curve_fit
13
+ import joblib
14
+
15
+ import seaborn as sns
16
+ import matplotlib.pyplot as plt
17
+ import copy
18
+ import argparse
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.optim as optim
24
+ from torch.optim.lr_scheduler import CosineAnnealingLR
25
+ from torch.optim.swa_utils import AveragedModel, SWALR
26
+ from torch.utils.data import DataLoader, TensorDataset
27
+ from sklearn.model_selection import KFold
28
+ from sklearn.model_selection import train_test_split
29
+
30
+ from data_processing import split_train_test
31
+
32
+ # ignore all warnings
33
+ import warnings
34
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
35
+
36
+
37
+ class Mlp(nn.Module):
38
+ def __init__(self, input_features, hidden_features=256, out_features=1, drop_rate=0.2, act_layer=nn.GELU):
39
+ super().__init__()
40
+ self.fc1 = nn.Linear(input_features, hidden_features)
41
+ self.bn1 = nn.BatchNorm1d(hidden_features)
42
+ self.act1 = act_layer()
43
+ self.drop1 = nn.Dropout(drop_rate)
44
+ self.fc2 = nn.Linear(hidden_features, hidden_features // 2)
45
+ self.act2 = act_layer()
46
+ self.drop2 = nn.Dropout(drop_rate)
47
+ self.fc3 = nn.Linear(hidden_features // 2, out_features)
48
+
49
+ def forward(self, input_feature):
50
+ x = self.fc1(input_feature)
51
+ x = self.bn1(x)
52
+ x = self.act1(x)
53
+ x = self.drop1(x)
54
+ x = self.fc2(x)
55
+ x = self.act2(x)
56
+ x = self.drop2(x)
57
+ output = self.fc3(x)
58
+ return output
59
+
60
+
61
+ class MAEAndRankLoss(nn.Module):
62
+ def __init__(self, l1_w=1.0, rank_w=1.0, margin=0.0, use_margin=False):
63
+ super(MAEAndRankLoss, self).__init__()
64
+ self.l1_w = l1_w
65
+ self.rank_w = rank_w
66
+ self.margin = margin
67
+ self.use_margin = use_margin
68
+
69
+ def forward(self, y_pred, y_true):
70
+ # L1 loss/MAE loss
71
+ l_mae = F.l1_loss(y_pred, y_true, reduction='mean') * self.l1_w
72
+ # Rank loss
73
+ n = y_pred.size(0)
74
+ pred_diff = y_pred.unsqueeze(1) - y_pred.unsqueeze(0)
75
+ true_diff = y_true.unsqueeze(1) - y_true.unsqueeze(0)
76
+
77
+ # e(ytrue_i, ytrue_j)
78
+ masks = torch.sign(true_diff)
79
+
80
+ if self.use_margin and self.margin > 0:
81
+ true_diff = true_diff.abs() - self.margin
82
+ true_diff = F.relu(true_diff)
83
+ masks = true_diff.sign()
84
+
85
+ l_rank = F.relu(true_diff - masks * pred_diff)
86
+ l_rank = l_rank.sum() / (n * (n - 1))
87
+
88
+ loss = l_mae + l_rank * self.rank_w
89
+ return loss
90
+
91
+ def load_data(csv_file, mat_file, features, data_name, set_name):
92
+ try:
93
+ df = pd.read_csv(csv_file, skiprows=[], header=None)
94
+ except Exception as e:
95
+ logging.error(f'Read CSV file error: {e}')
96
+ raise
97
+
98
+ try:
99
+ if data_name == 'lsvq_train':
100
+ X_mat = features
101
+ else:
102
+ X_mat = scipy.io.loadmat(mat_file)
103
+ except Exception as e:
104
+ logging.error(f'Read MAT file error: {e}')
105
+ raise
106
+
107
+ y_data = df.values[1:, 2]
108
+ y = np.array(list(y_data), dtype=float)
109
+
110
+ if data_name == 'cross_dataset': # or data_name == 'lsvq_train':
111
+ y[y > 5] = 5
112
+ if set_name == 'test':
113
+ print(f"Modified y_true: {y}")
114
+ if data_name == 'lsvq_train':
115
+ X = np.asarray(X_mat, dtype=float)
116
+ else:
117
+ data_name = f'{data_name}_{set_name}_features'
118
+ X = np.asarray(X_mat[data_name], dtype=float)
119
+
120
+ return X, y
121
+
122
+ def preprocess_data(X, y):
123
+ X[np.isnan(X)] = 0
124
+ X[np.isinf(X)] = 0
125
+ imp = SimpleImputer(missing_values=np.nan, strategy='mean').fit(X)
126
+ X = imp.transform(X)
127
+
128
+ # scaler = StandardScaler()
129
+ scaler = MinMaxScaler().fit(X)
130
+ X = scaler.transform(X)
131
+ logging.info(f'Scaler: {scaler}')
132
+
133
+ y = y.reshape(-1, 1).squeeze()
134
+ return X, y, imp, scaler
135
+
136
+ # define 4-parameter logistic regression
137
+ def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
138
+ logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4))))
139
+ yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart)
140
+ return yhat
141
+
142
+ def fit_logistic_regression(y_pred, y_true):
143
+ beta = [np.max(y_true), np.min(y_true), np.mean(y_pred), 0.5]
144
+ popt, _ = curve_fit(logistic_func, y_pred, y_true, p0=beta, maxfev=100000000)
145
+ y_pred_logistic = logistic_func(y_pred, *popt)
146
+ return y_pred_logistic, beta, popt
147
+
148
+ def compute_correlation_metrics(y_true, y_pred):
149
+ y_pred_logistic, beta, popt = fit_logistic_regression(y_pred, y_true)
150
+
151
+ plcc = scipy.stats.pearsonr(y_true, y_pred_logistic)[0]
152
+ rmse = np.sqrt(mean_squared_error(y_true, y_pred_logistic))
153
+ srcc = scipy.stats.spearmanr(y_true, y_pred)[0]
154
+
155
+ try:
156
+ krcc = scipy.stats.kendalltau(y_true, y_pred)[0]
157
+ except Exception as e:
158
+ logging.error(f'krcc calculation: {e}')
159
+ krcc = scipy.stats.kendalltau(y_true, y_pred, method='asymptotic')[0]
160
+ return y_pred_logistic, plcc, rmse, srcc, krcc
161
+
162
+ def plot_results(y_test, y_test_pred_logistic, df_pred_score, model_name, data_name, network_name, select_criteria):
163
+ # nonlinear logistic fitted curve / logistic regression
164
+ mos1 = y_test
165
+ y1 = y_test_pred_logistic
166
+
167
+ try:
168
+ beta = [np.max(mos1), np.min(mos1), np.mean(y1), 0.5]
169
+ popt, pcov = curve_fit(logistic_func, y1, mos1, p0=beta, maxfev=100000000)
170
+ sigma = np.sqrt(np.diag(pcov))
171
+ except:
172
+ raise Exception('Fitting logistic function time-out!!')
173
+ x_values1 = np.linspace(np.min(y1), np.max(y1), len(y1))
174
+ plt.plot(x_values1, logistic_func(x_values1, *popt), '-', color='#c72e29', label='Fitted f(x)')
175
+
176
+ fig1 = sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df_pred_score, markers='o', color='steelblue', label=network_name)
177
+ plt.legend(loc='upper left')
178
+ if data_name == 'live_vqc' or data_name == 'live_qualcomm' or data_name == 'cvd_2014' or data_name == 'lsvq_train':
179
+ plt.ylim(0, 100)
180
+ plt.xlim(0, 100)
181
+ else:
182
+ plt.ylim(1, 5)
183
+ plt.xlim(1, 5)
184
+ plt.title(f"Algorithm {network_name} with {model_name} on dataset {data_name}", fontsize=10)
185
+ plt.xlabel('Predicted Score')
186
+ plt.ylabel('MOS')
187
+ reg_fig1 = fig1.get_figure()
188
+
189
+ fig_path = f'../figs/{data_name}/'
190
+ os.makedirs(fig_path, exist_ok=True)
191
+ reg_fig1.savefig(fig_path + f"{network_name}_{model_name}_{data_name}_by{select_criteria}_kfold.png", dpi=300)
192
+ plt.clf()
193
+ plt.close()
194
+
195
+ def plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, test_vids, i):
196
+ plt.figure(figsize=(10, 6))
197
+
198
+ plt.plot(avg_train_losses, label='Average Training Loss')
199
+ plt.plot(avg_val_losses, label='Average Validation Loss')
200
+
201
+ plt.xlabel('Epoch')
202
+ plt.ylabel('Loss')
203
+ plt.title(f'Average Training and Validation Loss Across Folds - {network_name} with {model_name} (test_vids: {test_vids})', fontsize=10)
204
+
205
+ plt.legend()
206
+ fig_par_path = f'../log/result/{data_name}/'
207
+ os.makedirs(fig_par_path, exist_ok=True)
208
+ plt.savefig(f'{fig_par_path}/{network_name}_Average_Training_Loss_test{i}.png', dpi=50)
209
+ plt.clf()
210
+ plt.close()
211
+
212
+ def configure_logging(log_path, model_name, data_name, network_name, select_criteria):
213
+ log_file_name = os.path.join(log_path, f"{data_name}_{network_name}_{model_name}_corr_{select_criteria}_kfold.log")
214
+ logging.basicConfig(filename=log_file_name, filemode='w', level=logging.DEBUG, format='%(levelname)s - %(message)s')
215
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
216
+ logging.info(f"Evaluating algorithm {network_name} with {model_name} on dataset {data_name}")
217
+ logging.info(f"torch cuda: {torch.cuda.is_available()}")
218
+
219
+ def load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features):
220
+ if data_name == 'cross_dataset':
221
+ data_name1 = 'youtube_ugc_all'
222
+ data_name2 = 'cvd_2014_all'
223
+ csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name1}_MOS_train.csv')
224
+ csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name2}_MOS_test.csv')
225
+ mat_train_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name1}_{network_name}_train_features.mat')
226
+ mat_test_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name2}_{network_name}_test_features.mat')
227
+ X_train, y_train = load_data(csv_train_file, mat_train_file, None, data_name1, 'train')
228
+ X_test, y_test = load_data(csv_test_file, mat_test_file, None, data_name2, 'test')
229
+
230
+ elif data_name == 'lsvq_train':
231
+ csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
232
+ csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
233
+ X_train, y_train = load_data(csv_train_file, None, train_features, data_name, 'train')
234
+ X_test, y_test = load_data(csv_test_file, None, test_features, data_name, 'test')
235
+
236
+ else:
237
+ csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
238
+ csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
239
+ mat_train_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name}_{network_name}_train_features.mat')
240
+ mat_test_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name}_{network_name}_test_features.mat')
241
+ X_train, y_train = load_data(csv_train_file, mat_train_file, None, data_name, 'train')
242
+ X_test, y_test = load_data(csv_test_file, mat_test_file, None, data_name, 'test')
243
+
244
+ # standard min-max normalization of training features
245
+ X_train, y_train, _, _ = preprocess_data(X_train, y_train)
246
+ X_test, y_test, _, _ = preprocess_data(X_test, y_test)
247
+
248
+ return X_train, y_train, X_test, y_test
249
+
250
+ def train_one_epoch(model, train_loader, criterion, optimizer, device):
251
+ """Train the model for one epoch"""
252
+ model.train()
253
+ train_loss = 0.0
254
+ for inputs, targets in train_loader:
255
+ inputs, targets = inputs.to(device), targets.to(device)
256
+
257
+ optimizer.zero_grad()
258
+ outputs = model(inputs)
259
+ loss = criterion(outputs, targets.view(-1, 1))
260
+ loss.backward()
261
+ optimizer.step()
262
+ train_loss += loss.item() * inputs.size(0)
263
+ train_loss /= len(train_loader.dataset)
264
+ return train_loss
265
+
266
+ def evaluate(model, val_loader, criterion, device):
267
+ """Evaluate model performance on validation sets"""
268
+ model.eval()
269
+ val_loss = 0.0
270
+ y_val_pred = []
271
+ with torch.no_grad():
272
+ for inputs, targets in val_loader:
273
+ inputs, targets = inputs.to(device), targets.to(device)
274
+
275
+ outputs = model(inputs)
276
+ y_val_pred.extend(outputs.view(-1).tolist())
277
+ loss = criterion(outputs, targets.view(-1, 1))
278
+ val_loss += loss.item() * inputs.size(0)
279
+ val_loss /= len(val_loader.dataset)
280
+ return val_loss, np.array(y_val_pred)
281
+
282
+ def update_best_model(select_criteria, best_metric, current_val, model):
283
+ is_better = False
284
+ if select_criteria == 'byrmse' and current_val < best_metric:
285
+ is_better = True
286
+ elif select_criteria == 'bykrcc' and current_val > best_metric:
287
+ is_better = True
288
+
289
+ if is_better:
290
+ return current_val, copy.deepcopy(model), is_better
291
+ return best_metric, model, is_better
292
+
293
+ def train_and_evaluate(X_train, y_train, config):
294
+ # parameters
295
+ n_repeats = config['n_repeats']
296
+ n_splits = config['n_splits']
297
+ batch_size = config['batch_size']
298
+ epochs = config['epochs']
299
+ hidden_features = config['hidden_features']
300
+ drop_rate = config['drop_rate']
301
+ loss_type = config['loss_type']
302
+ optimizer_type = config['optimizer_type']
303
+ select_criteria = config['select_criteria']
304
+ initial_lr = config['initial_lr']
305
+ weight_decay = config['weight_decay']
306
+ patience = config['patience']
307
+ l1_w = config['l1_w']
308
+ rank_w = config['rank_w']
309
+ use_swa = config.get('use_swa', False)
310
+ logging.info(f'Parameters - Number of repeats for 80-20 hold out test: {n_repeats}, Number of splits for kfold: {n_splits}, Batch size: {batch_size}, Number of epochs: {epochs}')
311
+ logging.info(f'Network Parameters - hidden_features: {hidden_features}, drop_rate: {drop_rate}, patience: {patience}')
312
+ logging.info(f'Optimizer Parameters - loss_type: {loss_type}, optimizer_type: {optimizer_type}, initial_lr: {initial_lr}, weight_decay: {weight_decay}, use_swa: {use_swa}')
313
+ logging.info(f'MAEAndRankLoss - l1_w: {l1_w}, rank_w: {rank_w}')
314
+
315
+ kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
316
+ best_model = None
317
+ best_metric = float('inf') if select_criteria == 'byrmse' else float('-inf')
318
+
319
+ # loss for every fold
320
+ all_train_losses = []
321
+ all_val_losses = []
322
+ for fold, (train_idx, val_idx) in enumerate(kf.split(X_train)):
323
+ print(f"Fold {fold + 1}/{n_splits}")
324
+
325
+ X_train_fold, X_val_fold = X_train[train_idx], X_train[val_idx]
326
+ y_train_fold, y_val_fold = y_train[train_idx], y_train[val_idx]
327
+
328
+ # initialisation of model, loss function, optimiser
329
+ model = Mlp(input_features=X_train_fold.shape[1], hidden_features=hidden_features, drop_rate=drop_rate)
330
+ model = model.to(device) # to gpu
331
+
332
+ if loss_type == 'MAERankLoss':
333
+ criterion = MAEAndRankLoss()
334
+ criterion.l1_w = l1_w
335
+ criterion.rank_w = rank_w
336
+ else:
337
+ criterion = nn.MSELoss()
338
+
339
+ if optimizer_type == 'sgd':
340
+ optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=weight_decay)
341
+ scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)  # initial eta_min=1e-5
342
+ else:
343
+ optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay) # L2 Regularisation initial: 0.01, 1e-5
344
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95) # step_size=10, gamma=0.1: every 10 epochs lr*0.1
345
+ if use_swa:
346
+ swa_model = AveragedModel(model).to(device)
347
+ swa_scheduler = SWALR(optimizer, swa_lr=initial_lr, anneal_strategy='cos')
348
+
349
+ # dataset loader
350
+ train_dataset = TensorDataset(torch.FloatTensor(X_train_fold), torch.FloatTensor(y_train_fold))
351
+ val_dataset = TensorDataset(torch.FloatTensor(X_val_fold), torch.FloatTensor(y_val_fold))
352
+ train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
353
+ val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
354
+
355
+ train_losses, val_losses = [], []
356
+
357
+ # early stopping parameters
358
+ best_val_loss = float('inf')
359
+ epochs_no_improve = 0
360
+ early_stop_active = False
361
+ swa_start = int(epochs * 0.7) if use_swa else epochs # SWA starts after 70% of total epochs, only set SWA start if SWA is used
362
+
363
+ for epoch in range(epochs):
364
+ train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
365
+ train_losses.append(train_loss)
366
+ scheduler.step() # update learning rate
367
+ if use_swa and epoch >= swa_start:
368
+ swa_model.update_parameters(model)
369
+ swa_scheduler.step()
370
+ early_stop_active = True
371
+ print(f"Current learning rate with SWA: {swa_scheduler.get_last_lr()}")
372
+
373
+ lr = optimizer.param_groups[0]['lr']
374
+ print('Epoch %d: Learning rate: %f' % (epoch + 1, lr))
375
+
376
+ # decide which model to evaluate: SWA model or regular model
377
+ current_model = swa_model if use_swa and epoch >= swa_start else model
378
+ current_model.eval()
379
+ val_loss, y_val_pred = evaluate(current_model, val_loader, criterion, device)
380
+ val_losses.append(val_loss)
381
+ print(f"Epoch {epoch + 1}, Fold {fold + 1}, Training Loss: {train_loss}, Validation Loss: {val_loss}")
382
+
383
+ y_val_pred = np.array(list(y_val_pred), dtype=float)
384
+ _, _, rmse_val, _, krcc_val = compute_correlation_metrics(y_val_fold, y_val_pred)
385
+ current_metric = rmse_val if select_criteria == 'byrmse' else krcc_val
386
+ best_metric, best_model, is_better = update_best_model(select_criteria, best_metric, current_metric, current_model)
387
+ if is_better:
388
+ logging.info(f"Epoch {epoch + 1}, Fold {fold + 1}:")
389
+ y_val_pred_logistic_tmp, plcc_valid_tmp, rmse_valid_tmp, srcc_valid_tmp, krcc_valid_tmp = compute_correlation_metrics(y_val_fold, y_val_pred)
390
+ logging.info(f'Validation set - Evaluation Results - SRCC: {srcc_valid_tmp}, KRCC: {krcc_valid_tmp}, PLCC: {plcc_valid_tmp}, RMSE: {rmse_valid_tmp}')
391
+
392
+ X_train_fold_tensor = torch.FloatTensor(X_train_fold).to(device)
393
+ y_tra_pred_tmp = best_model(X_train_fold_tensor).detach().cpu().numpy().squeeze()
394
+ y_tra_pred_tmp = np.array(list(y_tra_pred_tmp), dtype=float)
395
+ y_tra_pred_logistic_tmp, plcc_train_tmp, rmse_train_tmp, srcc_train_tmp, krcc_train_tmp = compute_correlation_metrics(y_train_fold, y_tra_pred_tmp)
396
+ logging.info(f'Train set - Evaluation Results - SRCC: {srcc_train_tmp}, KRCC: {krcc_train_tmp}, PLCC: {plcc_train_tmp}, RMSE: {rmse_train_tmp}')
397
+
398
+ # check for loss improvement
399
+ if early_stop_active:
400
+ if val_loss < best_val_loss:
401
+ best_val_loss = val_loss
402
+ # save the best model if validation loss improves
403
+ best_model = copy.deepcopy(model)
404
+ epochs_no_improve = 0
405
+ else:
406
+ epochs_no_improve += 1
407
+ if epochs_no_improve >= patience:
408
+ # epochs to wait for improvement before stopping
409
+ print(f"Early stopping triggered after {epoch + 1} epochs.")
410
+ break
411
+
412
+ # saving SWA models and updating BN statistics
413
+ if use_swa:
414
+ train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_to_device(x, device))
415
+ best_model = best_model.to(device)
416
+ best_model.eval()
417
+ torch.optim.swa_utils.update_bn(train_loader, best_model)
418
+ # swa_model_path = os.path.join('../model/', f'model_swa_fold{fold}.pth')
419
+ # torch.save(swa_model.state_dict(), swa_model_path)
420
+ # logging.info(f'SWA model saved at {swa_model_path}')
421
+
422
+ all_train_losses.append(train_losses)
423
+ all_val_losses.append(val_losses)
424
+ max_length = max(len(x) for x in all_train_losses)
425
+ all_train_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_train_losses]
426
+ max_length = max(len(x) for x in all_val_losses)
427
+ all_val_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_val_losses]
428
+
429
+ return best_model, all_train_losses, all_val_losses
430
+
431
+ def collate_to_device(batch, device):
432
+ data, targets = zip(*batch)
433
+ return torch.stack(data).to(device), torch.stack(targets).to(device)
434
+
435
+ def model_test(best_model, X, y, device):
436
+ test_dataset = TensorDataset(torch.FloatTensor(X), torch.FloatTensor(y))
437
+ test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
438
+
439
+ best_model.eval()
440
+ y_pred = []
441
+ with torch.no_grad():
442
+ for inputs, _ in test_loader:
443
+ inputs = inputs.to(device)
444
+
445
+ outputs = best_model(inputs)
446
+ y_pred.extend(outputs.view(-1).tolist())
447
+
448
+ return y_pred
449
+
450
+ def main(config):
451
+ model_name = config['model_name']
452
+ data_name = config['data_name']
453
+ network_name = config['network_name']
454
+
455
+ metadata_path = config['metadata_path']
456
+ feature_path = config['feature_path']
457
+ log_path = config['log_path']
458
+ save_path = config['save_path']
459
+ score_path = config['score_path']
460
+ result_path = config['result_path']
461
+
462
+ # parameters
463
+ select_criteria = config['select_criteria']
464
+ n_repeats = config['n_repeats']
465
+
466
+ # logging and result
467
+ os.makedirs(log_path, exist_ok=True)
468
+ os.makedirs(save_path, exist_ok=True)
469
+ os.makedirs(score_path, exist_ok=True)
470
+ os.makedirs(result_path, exist_ok=True)
471
+ result_file = f'{result_path}{data_name}_{network_name}_{select_criteria}.mat'
472
+ pred_score_filename = os.path.join(score_path, f"{data_name}_{network_name}_{select_criteria}.csv")
473
+ file_path = os.path.join(save_path, f"{data_name}_{network_name}_{select_criteria}_trained_median_model_param.pth")
474
+ configure_logging(log_path, model_name, data_name, network_name, select_criteria)
475
+
476
+ '''======================== Main Body ==========================='''
477
+ PLCC_all_repeats_test = []
478
+ SRCC_all_repeats_test = []
479
+ KRCC_all_repeats_test = []
480
+ RMSE_all_repeats_test = []
481
+ PLCC_all_repeats_train = []
482
+ SRCC_all_repeats_train = []
483
+ KRCC_all_repeats_train = []
484
+ RMSE_all_repeats_train = []
485
+ all_repeats_test_vids = []
486
+ all_repeats_df_test_pred = []
487
+ best_model_list = []
488
+
489
+ for i in range(1, n_repeats + 1):
490
+ print(f"{i}th repeated 80-20 hold out test")
491
+ logging.info(f"{i}th repeated 80-20 hold out test")
492
+ t0 = time.time()
493
+
494
+ # train test split
495
+ test_size = 0.2
496
+ random_state = math.ceil(8.8 * i)
497
+ # NR: original
498
+ if data_name == 'lsvq_train':
499
+ test_data_name = 'lsvq_test' #lsvq_test, lsvq_test_1080p
500
+ train_features, test_features, test_vids = split_train_test.process_lsvq(data_name, test_data_name, metadata_path, feature_path, network_name)
501
+ elif data_name == 'cross_dataset':
502
+ train_data_name = 'youtube_ugc_all'
503
+ test_data_name = 'cvd_2014_all'
504
+ _, _, test_vids = split_train_test.process_cross_dataset(train_data_name, test_data_name, metadata_path, feature_path, network_name)
505
+ else:
506
+ _, _, test_vids = split_train_test.process_other(data_name, test_size, random_state, metadata_path, feature_path, network_name)
507
+
508
+ '''======================== read files =============================== '''
509
+ if data_name == 'lsvq_train':
510
+ X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features)
511
+ else:
512
+ X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, None, None)
513
+
514
+ '''======================== regression model =============================== '''
515
+ best_model, all_train_losses, all_val_losses = train_and_evaluate(X_train, y_train, config)
516
+
517
+ # average loss plots
518
+ avg_train_losses = np.mean(all_train_losses, axis=0)
519
+ avg_val_losses = np.mean(all_val_losses, axis=0)
520
+ test_vids = test_vids.tolist()
521
+ plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, len(test_vids), i)
522
+
523
+ # predict best model on the train dataset
524
+ y_train_pred = model_test(best_model, X_train, y_train, device)
525
+ y_train_pred = np.array(list(y_train_pred), dtype=float)
526
+ y_train_pred_logistic, plcc_train, rmse_train, srcc_train, krcc_train = compute_correlation_metrics(y_train, y_train_pred)
527
+
528
+ # test best model on the test dataset
529
+ y_test_pred = model_test(best_model, X_test, y_test, device)
530
+ y_test_pred = np.array(list(y_test_pred), dtype=float)
531
+ y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = compute_correlation_metrics(y_test, y_test_pred)
532
+
533
+ # save the predict score results
534
+ test_pred_score = {'MOS': y_test, 'y_test_pred': y_test_pred, 'y_test_pred_logistic': y_test_pred_logistic}
535
+ df_test_pred = pd.DataFrame(test_pred_score)
536
+
537
+ # logging logistic predicted scores
538
+ logging.info("============================================================================================================")
539
+ SRCC_all_repeats_test.append(srcc_test)
540
+ KRCC_all_repeats_test.append(krcc_test)
541
+ PLCC_all_repeats_test.append(plcc_test)
542
+ RMSE_all_repeats_test.append(rmse_test)
543
+ SRCC_all_repeats_train.append(srcc_train)
544
+ KRCC_all_repeats_train.append(krcc_train)
545
+ PLCC_all_repeats_train.append(plcc_train)
546
+ RMSE_all_repeats_train.append(rmse_train)
547
+ all_repeats_test_vids.append(test_vids)
548
+ all_repeats_df_test_pred.append(df_test_pred)
549
+ best_model_list.append(copy.deepcopy(best_model))
550
+
551
+ # logging.info results for each iteration
552
+ logging.info('Best results in Mlp model within one split')
553
+ logging.info(f'MODEL: {best_model}')
554
+ logging.info('======================================================')
555
+ logging.info(f'Train set - Evaluation Results')
556
+ logging.info(f'SRCC_train: {srcc_train}')
557
+ logging.info(f'KRCC_train: {krcc_train}')
558
+ logging.info(f'PLCC_train: {plcc_train}')
559
+ logging.info(f'RMSE_train: {rmse_train}')
560
+ logging.info('======================================================')
561
+ logging.info(f'Test set - Evaluation Results')
562
+ logging.info(f'SRCC_test: {srcc_test}')
563
+ logging.info(f'KRCC_test: {krcc_test}')
564
+ logging.info(f'PLCC_test: {plcc_test}')
565
+ logging.info(f'RMSE_test: {rmse_test}')
566
+ logging.info('======================================================')
567
+ logging.info(' -- {} seconds elapsed...\n\n'.format(time.time() - t0))
568
+
569
+ logging.info('')
570
+ SRCC_all_repeats_test = np.nan_to_num(SRCC_all_repeats_test)
571
+ KRCC_all_repeats_test = np.nan_to_num(KRCC_all_repeats_test)
572
+ PLCC_all_repeats_test = np.nan_to_num(PLCC_all_repeats_test)
573
+ RMSE_all_repeats_test = np.nan_to_num(RMSE_all_repeats_test)
574
+ SRCC_all_repeats_train = np.nan_to_num(SRCC_all_repeats_train)
575
+ KRCC_all_repeats_train = np.nan_to_num(KRCC_all_repeats_train)
576
+ PLCC_all_repeats_train = np.nan_to_num(PLCC_all_repeats_train)
577
+ RMSE_all_repeats_train = np.nan_to_num(RMSE_all_repeats_train)
578
+ logging.info('======================================================')
579
+ logging.info('Average training results among all repeated 80-20 holdouts:')
580
+ logging.info('SRCC: %f (std: %f)', np.median(SRCC_all_repeats_train), np.std(SRCC_all_repeats_train))
581
+ logging.info('KRCC: %f (std: %f)', np.median(KRCC_all_repeats_train), np.std(KRCC_all_repeats_train))
582
+ logging.info('PLCC: %f (std: %f)', np.median(PLCC_all_repeats_train), np.std(PLCC_all_repeats_train))
583
+ logging.info('RMSE: %f (std: %f)', np.median(RMSE_all_repeats_train), np.std(RMSE_all_repeats_train))
584
+ logging.info('======================================================')
585
+ logging.info('Average testing results among all repeated 80-20 holdouts:')
586
+ logging.info('SRCC: %f (std: %f)', np.median(SRCC_all_repeats_test), np.std(SRCC_all_repeats_test))
587
+ logging.info('KRCC: %f (std: %f)', np.median(KRCC_all_repeats_test), np.std(KRCC_all_repeats_test))
588
+ logging.info('PLCC: %f (std: %f)', np.median(PLCC_all_repeats_test), np.std(PLCC_all_repeats_test))
589
+ logging.info('RMSE: %f (std: %f)', np.median(RMSE_all_repeats_test), np.std(RMSE_all_repeats_test))
590
+ logging.info('======================================================')
591
+ logging.info('\n')
592
+
593
+ # find the median model and the index of the median
594
+ print('======================================================')
595
+ if select_criteria == 'byrmse':
596
+ median_metrics = np.median(RMSE_all_repeats_test)
597
+ indices = np.where(RMSE_all_repeats_test == median_metrics)[0]
598
+ select_criteria = select_criteria.replace('by', '').upper()
599
+ print(RMSE_all_repeats_test)
600
+ logging.info(f'all {select_criteria}: {RMSE_all_repeats_test}')
601
+ elif select_criteria == 'bykrcc':
602
+ median_metrics = np.median(KRCC_all_repeats_test)
603
+ indices = np.where(KRCC_all_repeats_test == median_metrics)[0]
604
+ select_criteria = select_criteria.replace('by', '').upper()
605
+ print(KRCC_all_repeats_test)
606
+ logging.info(f'all {select_criteria}: {KRCC_all_repeats_test}')
607
+
608
+ median_test_vids = [all_repeats_test_vids[i] for i in indices]
609
+ test_vids = [arr.tolist() for arr in median_test_vids] if len(median_test_vids) > 1 else (median_test_vids[0] if median_test_vids else [])
610
+
611
+ # select the model with the first index where the median is located
612
+ # Note: If there are multiple iterations with the same median RMSE, the first index is selected here
613
+ median_model = None
614
+ if len(indices) > 0:
615
+ median_index = indices[0] # select the first index
616
+ median_model = best_model_list[median_index]
617
+ median_model_df_test_pred = all_repeats_df_test_pred[median_index]
618
+
619
+ median_model_df_test_pred.to_csv(pred_score_filename, index=False)
620
+ plot_results(y_test, y_test_pred_logistic, median_model_df_test_pred, model_name, data_name, network_name, select_criteria)
621
+
622
+ print(f'Median Metrics: {median_metrics}')
623
+ print(f'Indices: {indices}')
624
+ # print(f'Test Videos: {test_vids}')
625
+ print(f'Best model: {median_model}')
626
+
627
+ logging.info(f'median test {select_criteria}: {median_metrics}')
628
+ logging.info(f"Indices of median metrics: {indices}")
629
+ # logging.info(f'Best training and test dataset: {test_vids}')
630
+ logging.info(f'Best model predict score: {median_model_df_test_pred}')
631
+ logging.info(f'Best model: {median_model}')
632
+
633
+ # ================================================================================
634
+ # save mats
635
+ scipy.io.savemat(result_file, mdict={'SRCC_train': np.asarray(SRCC_all_repeats_train, dtype=float), \
636
+ 'KRCC_train': np.asarray(KRCC_all_repeats_train, dtype=float), \
637
+ 'PLCC_train': np.asarray(PLCC_all_repeats_train, dtype=float), \
638
+ 'RMSE_train': np.asarray(RMSE_all_repeats_train, dtype=float), \
639
+ 'SRCC_test': np.asarray(SRCC_all_repeats_test, dtype=float), \
640
+ 'KRCC_test': np.asarray(KRCC_all_repeats_test, dtype=float), \
641
+ 'PLCC_test': np.asarray(PLCC_all_repeats_test, dtype=float), \
642
+ 'RMSE_test': np.asarray(RMSE_all_repeats_test, dtype=float), \
643
+ f'Median_{select_criteria}': median_metrics, \
644
+ 'Test_Videos_list': all_repeats_test_vids, \
645
+ 'Test_videos_Median_model': test_vids, \
646
+ })
647
+
648
+ # save model
649
+ torch.save(median_model.state_dict(), file_path)
650
+ print(f"Model state_dict saved to {file_path}")
651
+
652
+
653
+ if __name__ == '__main__':
654
+ parser = argparse.ArgumentParser()
655
+ # input parameters
656
+ parser.add_argument('--model_name', type=str, default='Mlp')
657
+ parser.add_argument('--data_name', type=str, default='lsvq_train', help='konvid_1k, youtube_ugc, live_vqc, cvd_2014, lsvq_train, cross_dataset')
658
+ parser.add_argument('--network_name', type=str, default='relaxvqa', help='relaxvqa, {frag_name}_{network_name}_{layer_name}')
659
+
660
+ parser.add_argument('--metadata_path', type=str, default='../metadata/')
661
+ parser.add_argument('--feature_path', type=str, default='../features/')
662
+ parser.add_argument('--log_path', type=str, default='../log/')
663
+ parser.add_argument('--save_path', type=str, default='../model/')
664
+ parser.add_argument('--score_path', type=str, default='../log/predict_score/')
665
+ parser.add_argument('--result_path', type=str, default='../log/result/')
666
+ # training parameters
667
+ parser.add_argument('--select_criteria', type=str, default='byrmse', help='byrmse, bykrcc')
668
+ parser.add_argument('--n_repeats', type=int, default=21, help='Number of repeats for 80-20 hold out test')
669
+ parser.add_argument('--n_splits', type=int, default=10, help='Number of splits for k-fold validation')
670
+ parser.add_argument('--batch_size', type=int, default=256, help='Batch size for training')
671
+ parser.add_argument('--epochs', type=int, default=20, help='Epochs for training') # 120(small), 20(big)
672
+ parser.add_argument('--hidden_features', type=int, default=256, help='Hidden features')
673
+ parser.add_argument('--drop_rate', type=float, default=0.1, help='Dropout rate.')
674
+ # misc
675
+ parser.add_argument('--loss_type', type=str, default='MAERankLoss', help='MSEloss or MAERankLoss')
676
+ parser.add_argument('--optimizer_type', type=str, default='sgd', help='adam or sgd')
677
+ parser.add_argument('--initial_lr', type=float, default=1e-1, help='Initial learning rate: 1e-2')
678
+ parser.add_argument('--weight_decay', type=float, default=0.005, help='Weight decay (L2 loss): 1e-4')
679
+ parser.add_argument('--patience', type=int, default=5, help='Early stopping patience.')
680
+ parser.add_argument('--use_swa', type=bool, default=True, help='Use Stochastic Weight Averaging')
681
+ parser.add_argument('--l1_w', type=float, default=0.6, help='MAE loss weight')
682
+ parser.add_argument('--rank_w', type=float, default=1.0, help='Rank loss weight')
683
+
684
+ args = parser.parse_args()
685
+ config = vars(args) # args to dict
686
+ print(config)
687
+
688
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
689
+ print(device)
690
+ if device.type == "cuda":
691
+ torch.cuda.set_device(0)
692
+
693
+ main(config)
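
A small usage sketch for the MAEAndRankLoss defined above, on toy score tensors, showing how it blends an L1 term with a pairwise rank-consistency term (importing model_regression assumes data_processing from the full repository is on the path):

import torch
from model_regression import MAEAndRankLoss

criterion = MAEAndRankLoss(l1_w=0.6, rank_w=1.0)
y_pred = torch.tensor([[2.8], [4.1], [3.0]])   # predicted quality scores
y_true = torch.tensor([[3.0], [4.5], [2.5]])   # MOS labels
loss = criterion(y_pred, y_true)               # 0.6 * MAE + 1.0 * averaged pairwise rank penalty
print(loss.item())
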
relax_vqa.py ADDED
@@ -0,0 +1,159 @@
1
+ import torch
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ from extractor import visualise_resnet, visualise_resnet_layer, visualise_vit_layer
6
+
7
+
8
+ def get_deep_feature(network_name, video_name, frame, frame_number, model, device, layer_name):
9
+ if network_name == 'resnet50':
10
+ if layer_name == 'layerstack':
11
+ all_layers = ['resnet50.conv1',
12
+ 'resnet50.layer1[0]', 'resnet50.layer1[1]', 'resnet50.layer1[2]',
13
+ 'resnet50.layer2[0]', 'resnet50.layer2[1]', 'resnet50.layer2[2]', 'resnet50.layer2[3]',
14
+ 'resnet50.layer3[0]', 'resnet50.layer3[1]', 'resnet50.layer3[2]', 'resnet50.layer3[3]',
15
+ 'resnet50.layer4[0]', 'resnet50.layer4[1]', 'resnet50.layer4[2]']
16
+ resnet50 = model
17
+ activations_dict, _, total_flops, total_params = visualise_resnet.process_video_frame(video_name, frame, frame_number, all_layers, resnet50, device)
18
+
19
+ elif layer_name == 'pool':
20
+ visual_layer = 'resnet50.avgpool' # before avg_pool
21
+ resnet50 = model
22
+ activations_dict, _, total_flops, total_params = visualise_resnet_layer.process_video_frame(video_name, frame, frame_number, visual_layer, resnet50, device)
23
+
24
+ elif network_name == 'vit':
25
+ patch_size = 16
26
+ activations_dict, _, total_flops, total_params = visualise_vit_layer.process_video_frame(video_name, frame, frame_number, model, patch_size, device)
27
+
28
+ return activations_dict, total_flops, total_params
29
+
30
+
31
+ def process_video_feature(video_feature, network_name, layer_name):
32
+ # initialize an empty list to store processed frames
33
+ averaged_frames = []
34
+ # iterate through each frame in the video_feature
35
+ for frame in video_feature:
36
+ frame_features = []
37
+
38
+ if network_name == 'vit':
39
+ # global mean and std
40
+ global_mean = torch.mean(frame, dim=0)
41
+ global_max = torch.max(frame, dim=0)[0]
42
+ global_std = torch.std(frame, dim=0)
43
+ # concatenate all pooling
44
+ combined_features = torch.hstack([global_mean, global_max, global_std])
45
+ frame_features.append(combined_features)
46
+
47
+ elif network_name == 'resnet50':
48
+ if layer_name == 'layerstack':
49
+ # iterate through each layer in the current frame
50
+ for layer_array in frame.values():
51
+ # calculate the mean along the specified axes (1 and 2) for each layer
52
+ layer_mean = torch.mean(layer_array, dim=(1, 2))
53
+ # append the calculated mean to the list for the current frame
54
+ frame_features.append(layer_mean)
55
+ elif layer_name == 'pool':
56
+ frame = torch.squeeze(torch.tensor(frame))
57
+ # global mean and std
58
+ global_mean = torch.mean(frame, dim=0)
59
+ global_max = torch.max(frame, dim=0)[0]
60
+ global_std = torch.std(frame, dim=0)
61
+ # concatenate all pooling
62
+ combined_features = torch.hstack([frame, global_mean, global_max, global_std])
63
+ frame_features.append(combined_features)
64
+
65
+ # concatenate the layer means horizontally to form the processed frame
66
+ processed_frame = torch.hstack(frame_features)
67
+ averaged_frames.append(processed_frame)
68
+
69
+ averaged_frames = torch.stack(averaged_frames)
70
+ return averaged_frames
71
+
72
+
73
+ def flow_to_rgb(flow):
74
+ mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
75
+ mag = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
76
+ # convert angle to hue
77
+ hue = ang * 180 / np.pi / 2
78
+
79
+ # create HSV
80
+ hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
81
+ hsv[..., 0] = hue
82
+ hsv[..., 1] = 255
83
+ hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
84
+ # convert HSV to a 3-channel image (cv2 returns BGR channel order)
85
+ rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
86
+ return rgb
87
+
88
+ def get_patch_diff(residual_frame, patch_size):
89
+ h, w = residual_frame.shape[2:] # Assuming (1, C, H, W) shape
90
+ h_adj = (h // patch_size) * patch_size
91
+ w_adj = (w // patch_size) * patch_size
92
+ residual_frame_adj = residual_frame[:, :, :h_adj, :w_adj]
93
+ # calculate absolute patch difference
94
+ diff = torch.zeros((h_adj // patch_size, w_adj // patch_size), device=residual_frame.device)
95
+ for i in range(0, h_adj, patch_size):
96
+ for j in range(0, w_adj, patch_size):
97
+ patch = residual_frame_adj[:, :, i:i + patch_size, j:j + patch_size]
98
+ # absolute sum
99
+ diff[i // patch_size, j // patch_size] = torch.sum(torch.abs(patch))
100
+ return diff
101
+
102
+ def extract_important_patches(residual_frame, diff, patch_size=16, target_size=224, top_n=196):
103
+ # find top n patches indices
104
+ patch_idx = torch.argsort(-diff.view(-1))
105
+ top_patches = [(idx // diff.shape[1], idx % diff.shape[1]) for idx in patch_idx[:top_n]]
106
+ sorted_idx = sorted(top_patches, key=lambda x: (x[0], x[1]))
107
+
108
+ imp_patches_img = torch.zeros((residual_frame.shape[1], target_size, target_size), dtype=residual_frame.dtype, device=residual_frame.device)
109
+ patches_per_row = target_size // patch_size # 14
110
+ # order the patch in the original location relation
111
+ positions = []
112
+ for idx, (y, x) in enumerate(sorted_idx):
113
+ patch = residual_frame[:, :, y * patch_size:(y + 1) * patch_size, x * patch_size:(x + 1) * patch_size]
114
+ # new patch location
115
+ row_idx = idx // patches_per_row
116
+ col_idx = idx % patches_per_row
117
+ start_y = row_idx * patch_size
118
+ start_x = col_idx * patch_size
119
+ imp_patches_img[:, start_y:start_y + patch_size, start_x:start_x + patch_size] = patch
120
+ positions.append((y.item(), x.item()))
121
+ return imp_patches_img, positions
122
+
123
+ def get_frame_patches(frame, positions, patch_size, target_size):
124
+ imp_patches_img = torch.zeros((frame.shape[1], target_size, target_size), dtype=frame.dtype, device=frame.device)
125
+ patches_per_row = target_size // patch_size
126
+
127
+ for idx, (y, x) in enumerate(positions):
128
+ start_y = y * patch_size
129
+ start_x = x * patch_size
130
+ end_y = start_y + patch_size
131
+ end_x = start_x + patch_size
132
+
133
+ patch = frame[:, :, start_y:end_y, start_x:end_x]
134
+ row_idx = idx // patches_per_row
135
+ col_idx = idx % patches_per_row
136
+ target_start_y = row_idx * patch_size
137
+ target_start_x = col_idx * patch_size
138
+
139
+ imp_patches_img[:, target_start_y:target_start_y + patch_size,
140
+ target_start_x:target_start_x + patch_size] = patch.squeeze(0)
141
+ return imp_patches_img
142
+
143
+ def process_patches(original_path, frag_name, residual, patch_size, target_size, top_n):
144
+ diff = get_patch_diff(residual, patch_size)
145
+ imp_patches, positions = extract_important_patches(residual, diff, patch_size, target_size, top_n)
146
+ if frag_name == 'frame_diff':
147
+ frag_path = original_path.replace('.png', '_residual_imp.png')
148
+ elif frag_name == 'optical_flow':
149
+ frag_path = original_path.replace('.png', '_residual_of_imp.png')
150
+ # cv2.imwrite(frag_path, imp_patches)
151
+ return frag_path, imp_patches, positions
152
+
153
+ def merge_fragments(diff_fragment, flow_fragment):
154
+ alpha = 0.5
155
+ merged_fragment = diff_fragment * alpha + flow_fragment * (1 - alpha)
156
+ return merged_fragment
157
+
158
+ def concatenate_features(frame_feature, residual_feature):
159
+ return torch.cat((frame_feature, residual_feature), dim=-1)
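
A short sketch of the fragment-selection path above: get_patch_diff scores every 16x16 patch of a residual frame by its absolute activity, and extract_important_patches repacks the top_n most active patches into a 224x224 fragment (a random tensor stands in for a real residual frame; importing relax_vqa assumes the extractor package from the full repository is on the path):

import torch
from relax_vqa import get_patch_diff, extract_important_patches

residual = torch.rand(1, 3, 480, 640)           # stand-in for |frame_next - frame|, shape (N, C, H, W)
diff = get_patch_diff(residual, patch_size=16)  # (30, 40) map of per-patch absolute sums
fragment, positions = extract_important_patches(residual, diff, patch_size=16, target_size=224, top_n=196)
print(fragment.shape, len(positions))           # torch.Size([3, 224, 224]) 196
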
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ gradio
2
+ torch
3
+ torchvision
4
+ torchaudio
5
+ opencv-python
6
+ joblib
7
+ scikit-learn
8
+ scipy
9
+ numpy
10
+ pandas
11
+ matplotlib
12
+ ipywidgets
13
+ thop
14
+ PyYAML
15
+ seaborn