Upload 5 files
- app.py +87 -0
- demo_test_gpu.py +249 -0
- model_regression.py +693 -0
- relax_vqa.py +159 -0
- requirements.txt +15 -0
app.py
ADDED
@@ -0,0 +1,87 @@
from spaces import GPU
import gradio as gr
import torch
import os
import time
from torchvision import models
from joblib import load
from extractor.visualise_vit_layer import VitGenerator
from relax_vqa import get_deep_feature, process_video_feature, process_patches, get_frame_patches, flow_to_rgb, merge_fragments, concatenate_features
from extractor.vf_extract import process_video_residual
from model_regression import Mlp, preprocess_data
from demo_test_gpu import evaluate_video_quality, load_model


@GPU
def run_relax_vqa(video_path, is_finetune, framerate, video_type):
    if not os.path.exists(video_path):
        return "❌ No video uploaded or the uploaded file has expired. Please upload again."

    config = {
        'is_finetune': is_finetune,
        'framerate': framerate,
        'video_type': video_type,
        'save_path': 'model/',
        'train_data_name': 'lsvq_train',
        'select_criteria': 'byrmse',
        'video_path': video_path,
        'video_name': os.path.splitext(os.path.basename(video_path))[0]
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resnet50 = models.resnet50(pretrained=True).to(device)
    vit = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=False)
    model_mlp = load_model(config, device)

    try:
        score, runtime = evaluate_video_quality(config, resnet50, vit, model_mlp, device)
        return f"Predicted Quality Score: {score:.4f} (in {runtime:.2f}s)"
    except Exception as e:
        return f"❌ Error: {str(e)}"
    finally:
        if "gradio" in video_path and os.path.exists(video_path):
            os.remove(video_path)


def toggle_dataset_visibility(is_finetune):
    return gr.update(visible=is_finetune)


with gr.Blocks() as demo:
    gr.Markdown("## 🎬 ReLaX-VQA Online Demo")
    gr.Markdown(
        "Upload a short video and get the predicted perceptual quality score using the ReLaX-VQA model. "
        "You can try our demo video from the "
        "<a href='https://huggingface.co/spaces/xinyiW915/ReLaX-VQA/blob/main/ugc_original_videos/5636101558_540p.mp4' target='_blank'>demo video</a> "
        "(fps = 24, dataset = konvid_1k).<br><br>"
        "⚙️ This demo is currently running on <strong>Hugging Face ZeroGPU Space</strong>: Dynamic resources (NVIDIA A100)."
    )

    with gr.Row():
        with gr.Column(scale=2):
            video_input = gr.Video(label="Upload a Video (e.g. mp4)")
            framerate_slider = gr.Slider(label="Source Video Framerate (fps)", minimum=1, maximum=60, step=1, value=24)
            is_finetune_checkbox = gr.Checkbox(label="Use Finetuning?", value=False)
            dataset_dropdown = gr.Dropdown(
                label="Source Video Dataset for Finetuning",
                choices=["konvid_1k", "youtube_ugc", "live_vqc", "cvd_2014"],
                value="konvid_1k",
                visible=False
            )
            run_button = gr.Button("Run Prediction")
        with gr.Column(scale=1):
            output_box = gr.Textbox(label="Predicted Quality Score", lines=5)

    is_finetune_checkbox.change(
        fn=toggle_dataset_visibility,
        inputs=is_finetune_checkbox,
        outputs=dataset_dropdown
    )

    run_button.click(
        fn=run_relax_vqa,
        inputs=[video_input, is_finetune_checkbox, framerate_slider, dataset_dropdown],
        outputs=output_box
    )

demo.launch()
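For a quick check outside the Gradio UI, the same pipeline can be driven directly through the helpers in demo_test_gpu.py (below); a minimal sketch, assuming the demo clip has been downloaded locally (the path here is illustrative only):

import torch
from torchvision import models
from extractor.visualise_vit_layer import VitGenerator
from demo_test_gpu import evaluate_video_quality, load_model

# mirrors the config dict built inside run_relax_vqa above
config = {
    'is_finetune': False,
    'framerate': 24,
    'video_type': 'konvid_1k',
    'save_path': 'model/',
    'train_data_name': 'lsvq_train',
    'select_criteria': 'byrmse',
    'video_path': 'ugc_original_videos/5636101558_540p.mp4',  # illustrative local path
    'video_name': '5636101558_540p',
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50 = models.resnet50(pretrained=True).to(device)
vit = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=False)
model_mlp = load_model(config, device)
score, runtime = evaluate_video_quality(config, resnet50, vit, model_mlp, device)
print(f"Predicted Quality Score: {score:.4f} (in {runtime:.2f}s)")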
demo_test_gpu.py
ADDED
@@ -0,0 +1,249 @@
import argparse
import time
import math
import os
import shutil
from joblib import load
import cv2
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from thop import profile
from torchvision import models, transforms

from extractor.visualise_vit_layer import VitGenerator
from relax_vqa import get_deep_feature, process_video_feature, process_patches, get_frame_patches, flow_to_rgb, merge_fragments, concatenate_features
from extractor.vf_extract import process_video_residual
from model_regression import Mlp, preprocess_data

def fix_state_dict(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:]  # strip the 'module.' prefix added by DataParallel
        elif k == 'n_averaged':
            continue  # skip the SWA bookkeeping entry
        else:
            name = k
        new_state_dict[name] = v
    return new_state_dict

# local override of model_regression.preprocess_data for inference-time use
def preprocess_data(X, y=None, imp=None, scaler=None):
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X, device='cuda' if torch.cuda.is_available() else 'cpu')
    X = torch.where(torch.isnan(X) | torch.isinf(X), torch.tensor(0.0, device=X.device), X)

    if imp is not None or scaler is not None:
        X_np = X.cpu().numpy()
        if imp is not None:
            X_np = imp.transform(X_np)
        if scaler is not None:
            X_np = scaler.transform(X_np)
        X = torch.from_numpy(X_np).to(X.device)

    if y is not None and y.size > 0:
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, device=X.device)
        y = y.reshape(-1).squeeze()
    else:
        y = None

    return X, y, imp, scaler

def load_model(config, device, input_features=35203):
    network_name = 'relaxvqa'
    # input_features = X_test_processed.shape[1]
    model = Mlp(input_features=input_features, out_features=1, drop_rate=0.2, act_layer=nn.GELU).to(device)
    if config['is_finetune']:
        model_path = os.path.join(config['save_path'], f"fine_tune_model/{config['video_type']}_{network_name}_{config['select_criteria']}_fine_tuned_model.pth")
    else:
        model_path = os.path.join(config['save_path'], f"{config['train_data_name']}_{network_name}_{config['select_criteria']}_trained_median_model_param_onLSVQ_TEST.pth")
    print("Loading model from:", model_path)
    state_dict = torch.load(model_path, map_location=device)
    fixed_state_dict = fix_state_dict(state_dict)
    try:
        model.load_state_dict(fixed_state_dict)
    except RuntimeError as e:
        print(e)
    return model

def evaluate_video_quality(config, resnet50, vit, model_mlp, device):
    is_finetune = config['is_finetune']
    save_path = config['save_path']
    video_type = config['video_type']
    video_name = config['video_name']
    framerate = config['framerate']
    sampled_fragment_path = os.path.join("../video_sampled_frame/sampled_frame/", "test_sampled_fragment")

    video_path = config.get("video_path")
    if video_path is None:
        if video_type == 'youtube_ugc':
            video_path = f'./ugc_original_videos/{video_name}.mkv'
        else:
            video_path = f'./ugc_original_videos/{video_name}.mp4'
    target_size = 224
    patch_size = 16
    top_n = int((target_size / patch_size) * (target_size / patch_size))

    # sampled video frames
    start_time = time.time()
    frames, frames_next = process_video_residual(video_type, video_name, framerate, video_path, sampled_fragment_path)

    # get ResNet50 layer-stack features and ViT pooling features
    all_frame_activations_resnet = []
    all_frame_activations_vit = []
    # get fragments ResNet50 features and ViT features
    all_frame_activations_sampled_resnet = []
    all_frame_activations_merged_resnet = []
    all_frame_activations_sampled_vit = []
    all_frame_activations_merged_vit = []

    batch_size = 64  # number of frames to process in parallel
    for i in range(0, len(frames_next), batch_size):
        batch_frames = frames[i:i + batch_size]
        batch_rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in batch_frames]
        batch_frames_next = frames_next[i:i + batch_size]
        batch_tensors = torch.stack([transforms.ToTensor()(frame) for frame in batch_frames]).to(device)
        batch_rgb_tensors = torch.stack([transforms.ToTensor()(frame_rgb) for frame_rgb in batch_rgb_frames]).to(device)
        batch_tensors_next = torch.stack([transforms.ToTensor()(frame_next) for frame_next in batch_frames_next]).to(device)

        # compute residuals
        residuals = torch.abs(batch_tensors_next - batch_tensors)

        # calculate optical flows
        batch_gray_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in batch_frames]
        batch_gray_frames_next = [cv2.cvtColor(frame_next, cv2.COLOR_BGR2GRAY) for frame_next in batch_frames_next]
        batch_gray_frames = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames]
        batch_gray_frames_next = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames_next]
        flows = [cv2.calcOpticalFlowFarneback(batch_gray_frames[j], batch_gray_frames_next[j], None, 0.5, 3, 15, 3, 5, 1.2, 0) for j in range(len(batch_gray_frames))]

        for j in range(batch_tensors.size(0)):
            '''sampled video frames'''
            frame_tensor = batch_tensors[j].unsqueeze(0)
            frame_rgb_tensor = batch_rgb_tensors[j].unsqueeze(0)
            # frame_next_tensor = batch_tensors_next[j].unsqueeze(0)
            frame_number = i + j + 1

            # ResNet50 layer-stack features
            activations_dict_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_rgb_tensor, frame_number, resnet50, device, 'layerstack')
            all_frame_activations_resnet.append(activations_dict_resnet)
            # ViT pooling features
            activations_dict_vit, _, _ = get_deep_feature('vit', video_name, frame_rgb_tensor, frame_number, vit, device, 'pool')
            all_frame_activations_vit.append(activations_dict_vit)

            '''residual video frames'''
            residual = residuals[j].unsqueeze(0)
            flow = flows[j]
            original_path = os.path.join(sampled_fragment_path, f'{video_name}_{frame_number}.png')

            # Frame Differencing
            residual_frag_path, diff_frag, positions = process_patches(original_path, 'frame_diff', residual, patch_size, target_size, top_n)
            # Frame fragment
            frame_patches = get_frame_patches(frame_tensor, positions, patch_size, target_size)
            # Optical Flow
            opticalflow_rgb = flow_to_rgb(flow)
            opticalflow_rgb_tensor = transforms.ToTensor()(opticalflow_rgb).unsqueeze(0).to(device)
            opticalflow_frag_path, flow_frag, _ = process_patches(original_path, 'optical_flow', opticalflow_rgb_tensor, patch_size, target_size, top_n)

            merged_frag = merge_fragments(diff_frag, flow_frag)

            # fragments ResNet50 features
            sampled_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_patches, frame_number, resnet50, device, 'layerstack')
            merged_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, merged_frag, frame_number, resnet50, device, 'pool')
            all_frame_activations_sampled_resnet.append(sampled_frag_activations_resnet)
            all_frame_activations_merged_resnet.append(merged_frag_activations_resnet)
            # fragments ViT features
            sampled_frag_activations_vit, _, _ = get_deep_feature('vit', video_name, frame_patches, frame_number, vit, device, 'pool')
            merged_frag_activations_vit, _, _ = get_deep_feature('vit', video_name, merged_frag, frame_number, vit, device, 'pool')
            all_frame_activations_sampled_vit.append(sampled_frag_activations_vit)
            all_frame_activations_merged_vit.append(merged_frag_activations_vit)

    print(f'video frame number: {len(all_frame_activations_resnet)}')
    averaged_frames_resnet = process_video_feature(all_frame_activations_resnet, 'resnet50', 'layerstack')
    averaged_frames_vit = process_video_feature(all_frame_activations_vit, 'vit', 'pool')
    # print("ResNet50 layer-stacking feature shape:", averaged_frames_resnet.shape)
    # print("ViT pooling feature shape:", averaged_frames_vit.shape)
    averaged_frames_sampled_resnet = process_video_feature(all_frame_activations_sampled_resnet, 'resnet50', 'layerstack')
    averaged_frames_merged_resnet = process_video_feature(all_frame_activations_merged_resnet, 'resnet50', 'pool')
    averaged_combined_feature_resnet = concatenate_features(averaged_frames_sampled_resnet, averaged_frames_merged_resnet)
    # print("Sampled fragments ResNet50 features shape:", averaged_frames_sampled_resnet.shape)
    # print("Merged fragments ResNet50 features shape:", averaged_frames_merged_resnet.shape)
    averaged_frames_sampled_vit = process_video_feature(all_frame_activations_sampled_vit, 'vit', 'pool')
    averaged_frames_merged_vit = process_video_feature(all_frame_activations_merged_vit, 'vit', 'pool')
    averaged_combined_feature_vit = concatenate_features(averaged_frames_sampled_vit, averaged_frames_merged_vit)
    # print("Sampled fragments ViT features shape:", averaged_frames_sampled_vit.shape)
    # print("Merged fragments ViT features shape:", averaged_frames_merged_vit.shape)

    # remove tmp folders
    shutil.rmtree(sampled_fragment_path)

    # concatenate features
    combined_features = torch.cat([torch.mean(averaged_frames_resnet, dim=0), torch.mean(averaged_frames_vit, dim=0),
                                   torch.mean(averaged_combined_feature_resnet, dim=0), torch.mean(averaged_combined_feature_vit, dim=0)], dim=0).view(1, -1)
    imputer = load(f'{save_path}/scaler/{video_type}_imputer.pkl')
    scaler = load(f'{save_path}/scaler/{video_type}_scaler.pkl')
    X_test_processed, _, _, _ = preprocess_data(combined_features, None, imp=imputer, scaler=scaler)
    feature_tensor = X_test_processed

    # evaluation for test video
    model_mlp.eval()

    with torch.no_grad():
        with torch.cuda.amp.autocast():
            prediction = model_mlp(feature_tensor)
    predicted_score = prediction.item()
    # print(f"Raw Predicted Quality Score: {predicted_score}")
    run_time = time.time() - start_time

    if not is_finetune:
        if video_type in ['konvid_1k', 'youtube_ugc']:
            scaled_prediction = ((predicted_score - 1) / (99 / 4)) + 1.0
            # print(f"Scaled Predicted Quality Score (1-5): {scaled_prediction}")
            return scaled_prediction, run_time
        else:
            scaled_prediction = predicted_score
            return scaled_prediction, run_time
    else:
        return predicted_score, run_time

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('-device', type=str, default='gpu', help='cpu or gpu')
    parser.add_argument('-model_name', type=str, default='Mlp', help='Name of the regression model')
    parser.add_argument('-select_criteria', type=str, default='byrmse', help='Selection criteria')
    parser.add_argument('-train_data_name', type=str, default='lsvq_train', help='Name of the training data')
    parser.add_argument('-is_finetune', type=bool, default=False, help='With or without finetune')
    parser.add_argument('-save_path', type=str, default='model/', help='Path to save models')
    parser.add_argument('-video_type', type=str, default='konvid_1k', help='Type of video')
    parser.add_argument('-video_name', type=str, default='5636101558_540p', help='Name of the video')
    parser.add_argument('-framerate', type=float, default=24, help='Frame rate of the video')

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_arguments()
    config = vars(args)
    if config['device'] == "gpu":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")

    # load models to device
    resnet50 = models.resnet50(pretrained=True).to(device)
    vit = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=True)
    model_mlp = load_model(config, device)

    total_time = 0
    num_runs = 1
    for i in range(num_runs):
        quality_prediction, run_time = evaluate_video_quality(config, resnet50, vit, model_mlp, device)
        print(f"Run {i + 1} - Time taken: {run_time:.4f} seconds")

        total_time += run_time
    average_time = total_time / num_runs

    print(f"Average running time over {num_runs} runs: {average_time:.4f} seconds")
    print("Predicted Quality Score:", quality_prediction)
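For konvid_1k and youtube_ugc without finetuning, evaluate_video_quality rescales the raw 1-100 prediction linearly onto the 1-5 MOS range; a minimal worked check of that formula (standalone, nothing beyond the arithmetic above):

def rescale(predicted_score):
    # same linear map as above: 1 -> 1.0 and 100 -> 5.0
    return ((predicted_score - 1) / (99 / 4)) + 1.0

assert abs(rescale(1) - 1.0) < 1e-9
assert abs(rescale(100) - 5.0) < 1e-9
print(rescale(62.5))  # ~3.48 on the 1-5 scale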
model_regression.py
ADDED
@@ -0,0 +1,693 @@
| 1 |
+
import logging
|
| 2 |
+
import time
|
| 3 |
+
import os
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import math
|
| 7 |
+
import scipy.io
|
| 8 |
+
import scipy.stats
|
| 9 |
+
from sklearn.impute import SimpleImputer
|
| 10 |
+
from sklearn.preprocessing import StandardScaler, MinMaxScaler
|
| 11 |
+
from sklearn.metrics import mean_squared_error
|
| 12 |
+
from scipy.optimize import curve_fit
|
| 13 |
+
import joblib
|
| 14 |
+
|
| 15 |
+
import seaborn as sns
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
import copy
|
| 18 |
+
import argparse
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import torch.optim as optim
|
| 24 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 25 |
+
from torch.optim.swa_utils import AveragedModel, SWALR
|
| 26 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 27 |
+
from sklearn.model_selection import KFold
|
| 28 |
+
from sklearn.model_selection import train_test_split
|
| 29 |
+
|
| 30 |
+
from data_processing import split_train_test
|
| 31 |
+
|
| 32 |
+
# ignore all warnings
|
| 33 |
+
import warnings
|
| 34 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Mlp(nn.Module):
|
| 38 |
+
def __init__(self, input_features, hidden_features=256, out_features=1, drop_rate=0.2, act_layer=nn.GELU):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.fc1 = nn.Linear(input_features, hidden_features)
|
| 41 |
+
self.bn1 = nn.BatchNorm1d(hidden_features)
|
| 42 |
+
self.act1 = act_layer()
|
| 43 |
+
self.drop1 = nn.Dropout(drop_rate)
|
| 44 |
+
self.fc2 = nn.Linear(hidden_features, hidden_features // 2)
|
| 45 |
+
self.act2 = act_layer()
|
| 46 |
+
self.drop2 = nn.Dropout(drop_rate)
|
| 47 |
+
self.fc3 = nn.Linear(hidden_features // 2, out_features)
|
| 48 |
+
|
| 49 |
+
def forward(self, input_feature):
|
| 50 |
+
x = self.fc1(input_feature)
|
| 51 |
+
x = self.bn1(x)
|
| 52 |
+
x = self.act1(x)
|
| 53 |
+
x = self.drop1(x)
|
| 54 |
+
x = self.fc2(x)
|
| 55 |
+
x = self.act2(x)
|
| 56 |
+
x = self.drop2(x)
|
| 57 |
+
output = self.fc3(x)
|
| 58 |
+
return output
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class MAEAndRankLoss(nn.Module):
|
| 62 |
+
def __init__(self, l1_w=1.0, rank_w=1.0, margin=0.0, use_margin=False):
|
| 63 |
+
super(MAEAndRankLoss, self).__init__()
|
| 64 |
+
self.l1_w = l1_w
|
| 65 |
+
self.rank_w = rank_w
|
| 66 |
+
self.margin = margin
|
| 67 |
+
self.use_margin = use_margin
|
| 68 |
+
|
| 69 |
+
def forward(self, y_pred, y_true):
|
| 70 |
+
# L1 loss/MAE loss
|
| 71 |
+
l_mae = F.l1_loss(y_pred, y_true, reduction='mean') * self.l1_w
|
| 72 |
+
# Rank loss
|
| 73 |
+
n = y_pred.size(0)
|
| 74 |
+
pred_diff = y_pred.unsqueeze(1) - y_pred.unsqueeze(0)
|
| 75 |
+
true_diff = y_true.unsqueeze(1) - y_true.unsqueeze(0)
|
| 76 |
+
|
| 77 |
+
# e(ytrue_i, ytrue_j)
|
| 78 |
+
masks = torch.sign(true_diff)
|
| 79 |
+
|
| 80 |
+
if self.use_margin and self.margin > 0:
|
| 81 |
+
true_diff = true_diff.abs() - self.margin
|
| 82 |
+
true_diff = F.relu(true_diff)
|
| 83 |
+
masks = true_diff.sign()
|
| 84 |
+
|
| 85 |
+
l_rank = F.relu(true_diff - masks * pred_diff)
|
| 86 |
+
l_rank = l_rank.sum() / (n * (n - 1))
|
| 87 |
+
|
| 88 |
+
loss = l_mae + l_rank * self.rank_w
|
| 89 |
+
return loss
|
| 90 |
+
|
| 91 |
+
def load_data(csv_file, mat_file, features, data_name, set_name):
|
| 92 |
+
try:
|
| 93 |
+
df = pd.read_csv(csv_file, skiprows=[], header=None)
|
| 94 |
+
except Exception as e:
|
| 95 |
+
logging.error(f'Read CSV file error: {e}')
|
| 96 |
+
raise
|
| 97 |
+
|
| 98 |
+
try:
|
| 99 |
+
if data_name == 'lsvq_train':
|
| 100 |
+
X_mat = features
|
| 101 |
+
else:
|
| 102 |
+
X_mat = scipy.io.loadmat(mat_file)
|
| 103 |
+
except Exception as e:
|
| 104 |
+
logging.error(f'Read MAT file error: {e}')
|
| 105 |
+
raise
|
| 106 |
+
|
| 107 |
+
y_data = df.values[1:, 2]
|
| 108 |
+
y = np.array(list(y_data), dtype=float)
|
| 109 |
+
|
| 110 |
+
if data_name == 'cross_dataset': # or data_name == 'lsvq_train':
|
| 111 |
+
y[y > 5] = 5
|
| 112 |
+
if set_name == 'test':
|
| 113 |
+
print(f"Modified y_true: {y}")
|
| 114 |
+
if data_name == 'lsvq_train':
|
| 115 |
+
X = np.asarray(X_mat, dtype=float)
|
| 116 |
+
else:
|
| 117 |
+
data_name = f'{data_name}_{set_name}_features'
|
| 118 |
+
X = np.asarray(X_mat[data_name], dtype=float)
|
| 119 |
+
|
| 120 |
+
return X, y
|
| 121 |
+
|
| 122 |
+
def preprocess_data(X, y):
|
| 123 |
+
X[np.isnan(X)] = 0
|
| 124 |
+
X[np.isinf(X)] = 0
|
| 125 |
+
imp = SimpleImputer(missing_values=np.nan, strategy='mean').fit(X)
|
| 126 |
+
X = imp.transform(X)
|
| 127 |
+
|
| 128 |
+
# scaler = StandardScaler()
|
| 129 |
+
scaler = MinMaxScaler().fit(X)
|
| 130 |
+
X = scaler.transform(X)
|
| 131 |
+
logging.info(f'Scaler: {scaler}')
|
| 132 |
+
|
| 133 |
+
y = y.reshape(-1, 1).squeeze()
|
| 134 |
+
return X, y, imp, scaler
|
| 135 |
+
|
| 136 |
+
# define 4-parameter logistic regression
|
| 137 |
+
def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
|
| 138 |
+
logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4))))
|
| 139 |
+
yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart)
|
| 140 |
+
return yhat
|
| 141 |
+
|
| 142 |
+
def fit_logistic_regression(y_pred, y_true):
|
| 143 |
+
beta = [np.max(y_true), np.min(y_true), np.mean(y_pred), 0.5]
|
| 144 |
+
popt, _ = curve_fit(logistic_func, y_pred, y_true, p0=beta, maxfev=100000000)
|
| 145 |
+
y_pred_logistic = logistic_func(y_pred, *popt)
|
| 146 |
+
return y_pred_logistic, beta, popt
|
| 147 |
+
|
| 148 |
+
def compute_correlation_metrics(y_true, y_pred):
|
| 149 |
+
y_pred_logistic, beta, popt = fit_logistic_regression(y_pred, y_true)
|
| 150 |
+
|
| 151 |
+
plcc = scipy.stats.pearsonr(y_true, y_pred_logistic)[0]
|
| 152 |
+
rmse = np.sqrt(mean_squared_error(y_true, y_pred_logistic))
|
| 153 |
+
srcc = scipy.stats.spearmanr(y_true, y_pred)[0]
|
| 154 |
+
|
| 155 |
+
try:
|
| 156 |
+
krcc = scipy.stats.kendalltau(y_true, y_pred)[0]
|
| 157 |
+
except Exception as e:
|
| 158 |
+
logging.error(f'krcc calculation: {e}')
|
| 159 |
+
krcc = scipy.stats.kendalltau(y_true, y_pred, method='asymptotic')[0]
|
| 160 |
+
return y_pred_logistic, plcc, rmse, srcc, krcc
|
| 161 |
+
|
| 162 |
+
def plot_results(y_test, y_test_pred_logistic, df_pred_score, model_name, data_name, network_name, select_criteria):
|
| 163 |
+
# nonlinear logistic fitted curve / logistic regression
|
| 164 |
+
mos1 = y_test
|
| 165 |
+
y1 = y_test_pred_logistic
|
| 166 |
+
|
| 167 |
+
try:
|
| 168 |
+
beta = [np.max(mos1), np.min(mos1), np.mean(y1), 0.5]
|
| 169 |
+
popt, pcov = curve_fit(logistic_func, y1, mos1, p0=beta, maxfev=100000000)
|
| 170 |
+
sigma = np.sqrt(np.diag(pcov))
|
| 171 |
+
except:
|
| 172 |
+
raise Exception('Fitting logistic function time-out!!')
|
| 173 |
+
x_values1 = np.linspace(np.min(y1), np.max(y1), len(y1))
|
| 174 |
+
plt.plot(x_values1, logistic_func(x_values1, *popt), '-', color='#c72e29', label='Fitted f(x)')
|
| 175 |
+
|
| 176 |
+
fig1 = sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df_pred_score, markers='o', color='steelblue', label=network_name)
|
| 177 |
+
plt.legend(loc='upper left')
|
| 178 |
+
if data_name == 'live_vqc' or data_name == 'live_qualcomm' or data_name == 'cvd_2014' or data_name == 'lsvq_train':
|
| 179 |
+
plt.ylim(0, 100)
|
| 180 |
+
plt.xlim(0, 100)
|
| 181 |
+
else:
|
| 182 |
+
plt.ylim(1, 5)
|
| 183 |
+
plt.xlim(1, 5)
|
| 184 |
+
plt.title(f"Algorithm {network_name} with {model_name} on dataset {data_name}", fontsize=10)
|
| 185 |
+
plt.xlabel('Predicted Score')
|
| 186 |
+
plt.ylabel('MOS')
|
| 187 |
+
reg_fig1 = fig1.get_figure()
|
| 188 |
+
|
| 189 |
+
fig_path = f'../figs/{data_name}/'
|
| 190 |
+
os.makedirs(fig_path, exist_ok=True)
|
| 191 |
+
reg_fig1.savefig(fig_path + f"{network_name}_{model_name}_{data_name}_by{select_criteria}_kfold.png", dpi=300)
|
| 192 |
+
plt.clf()
|
| 193 |
+
plt.close()
|
| 194 |
+
|
| 195 |
+
def plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, test_vids, i):
|
| 196 |
+
plt.figure(figsize=(10, 6))
|
| 197 |
+
|
| 198 |
+
plt.plot(avg_train_losses, label='Average Training Loss')
|
| 199 |
+
plt.plot(avg_val_losses, label='Average Validation Loss')
|
| 200 |
+
|
| 201 |
+
plt.xlabel('Epoch')
|
| 202 |
+
plt.ylabel('Loss')
|
| 203 |
+
plt.title(f'Average Training and Validation Loss Across Folds - {network_name} with {model_name} (test_vids: {test_vids})', fontsize=10)
|
| 204 |
+
|
| 205 |
+
plt.legend()
|
| 206 |
+
fig_par_path = f'../log/result/{data_name}/'
|
| 207 |
+
os.makedirs(fig_par_path, exist_ok=True)
|
| 208 |
+
plt.savefig(f'{fig_par_path}/{network_name}_Average_Training_Loss_test{i}.png', dpi=50)
|
| 209 |
+
plt.clf()
|
| 210 |
+
plt.close()
|
| 211 |
+
|
| 212 |
+
def configure_logging(log_path, model_name, data_name, network_name, select_criteria):
|
| 213 |
+
log_file_name = os.path.join(log_path, f"{data_name}_{network_name}_{model_name}_corr_{select_criteria}_kfold.log")
|
| 214 |
+
logging.basicConfig(filename=log_file_name, filemode='w', level=logging.DEBUG, format='%(levelname)s - %(message)s')
|
| 215 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 216 |
+
logging.info(f"Evaluating algorithm {network_name} with {model_name} on dataset {data_name}")
|
| 217 |
+
logging.info(f"torch cuda: {torch.cuda.is_available()}")
|
| 218 |
+
|
| 219 |
+
def load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features):
|
| 220 |
+
if data_name == 'cross_dataset':
|
| 221 |
+
data_name1 = 'youtube_ugc_all'
|
| 222 |
+
data_name2 = 'cvd_2014_all'
|
| 223 |
+
csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name1}_MOS_train.csv')
|
| 224 |
+
csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name2}_MOS_test.csv')
|
| 225 |
+
mat_train_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name1}_{network_name}_train_features.mat')
|
| 226 |
+
mat_test_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name2}_{network_name}_test_features.mat')
|
| 227 |
+
X_train, y_train = load_data(csv_train_file, mat_train_file, None, data_name1, 'train')
|
| 228 |
+
X_test, y_test = load_data(csv_test_file, mat_test_file, None, data_name2, 'test')
|
| 229 |
+
|
| 230 |
+
elif data_name == 'lsvq_train':
|
| 231 |
+
csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
|
| 232 |
+
csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
|
| 233 |
+
X_train, y_train = load_data(csv_train_file, None, train_features, data_name, 'train')
|
| 234 |
+
X_test, y_test = load_data(csv_test_file, None, test_features, data_name, 'test')
|
| 235 |
+
|
| 236 |
+
else:
|
| 237 |
+
csv_train_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
|
| 238 |
+
csv_test_file = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
|
| 239 |
+
mat_train_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name}_{network_name}_train_features.mat')
|
| 240 |
+
mat_test_file = os.path.join(f'{feature_path}split_train_test/', f'{data_name}_{network_name}_test_features.mat')
|
| 241 |
+
X_train, y_train = load_data(csv_train_file, mat_train_file, None, data_name, 'train')
|
| 242 |
+
X_test, y_test = load_data(csv_test_file, mat_test_file, None, data_name, 'test')
|
| 243 |
+
|
| 244 |
+
# standard min-max normalization of traning features
|
| 245 |
+
X_train, y_train, _, _ = preprocess_data(X_train, y_train)
|
| 246 |
+
X_test, y_test, _, _ = preprocess_data(X_test, y_test)
|
| 247 |
+
|
| 248 |
+
return X_train, y_train, X_test, y_test
|
| 249 |
+
|
| 250 |
+
def train_one_epoch(model, train_loader, criterion, optimizer, device):
|
| 251 |
+
"""Train the model for one epoch"""
|
| 252 |
+
model.train()
|
| 253 |
+
train_loss = 0.0
|
| 254 |
+
for inputs, targets in train_loader:
|
| 255 |
+
inputs, targets = inputs.to(device), targets.to(device)
|
| 256 |
+
|
| 257 |
+
optimizer.zero_grad()
|
| 258 |
+
outputs = model(inputs)
|
| 259 |
+
loss = criterion(outputs, targets.view(-1, 1))
|
| 260 |
+
loss.backward()
|
| 261 |
+
optimizer.step()
|
| 262 |
+
train_loss += loss.item() * inputs.size(0)
|
| 263 |
+
train_loss /= len(train_loader.dataset)
|
| 264 |
+
return train_loss
|
| 265 |
+
|
| 266 |
+
def evaluate(model, val_loader, criterion, device):
|
| 267 |
+
"""Evaluate model performance on validation sets"""
|
| 268 |
+
model.eval()
|
| 269 |
+
val_loss = 0.0
|
| 270 |
+
y_val_pred = []
|
| 271 |
+
with torch.no_grad():
|
| 272 |
+
for inputs, targets in val_loader:
|
| 273 |
+
inputs, targets = inputs.to(device), targets.to(device)
|
| 274 |
+
|
| 275 |
+
outputs = model(inputs)
|
| 276 |
+
y_val_pred.extend(outputs.view(-1).tolist())
|
| 277 |
+
loss = criterion(outputs, targets.view(-1, 1))
|
| 278 |
+
val_loss += loss.item() * inputs.size(0)
|
| 279 |
+
val_loss /= len(val_loader.dataset)
|
| 280 |
+
return val_loss, np.array(y_val_pred)
|
| 281 |
+
|
| 282 |
+
def update_best_model(select_criteria, best_metric, current_val, model):
|
| 283 |
+
is_better = False
|
| 284 |
+
if select_criteria == 'byrmse' and current_val < best_metric:
|
| 285 |
+
is_better = True
|
| 286 |
+
elif select_criteria == 'bykrcc' and current_val > best_metric:
|
| 287 |
+
is_better = True
|
| 288 |
+
|
| 289 |
+
if is_better:
|
| 290 |
+
return current_val, copy.deepcopy(model), is_better
|
| 291 |
+
return best_metric, model, is_better
|
| 292 |
+
|
| 293 |
+
def train_and_evaluate(X_train, y_train, config):
|
| 294 |
+
# parameters
|
| 295 |
+
n_repeats = config['n_repeats']
|
| 296 |
+
n_splits = config['n_splits']
|
| 297 |
+
batch_size = config['batch_size']
|
| 298 |
+
epochs = config['epochs']
|
| 299 |
+
hidden_features = config['hidden_features']
|
| 300 |
+
drop_rate = config['drop_rate']
|
| 301 |
+
loss_type = config['loss_type']
|
| 302 |
+
optimizer_type = config['optimizer_type']
|
| 303 |
+
select_criteria = config['select_criteria']
|
| 304 |
+
initial_lr = config['initial_lr']
|
| 305 |
+
weight_decay = config['weight_decay']
|
| 306 |
+
patience = config['patience']
|
| 307 |
+
l1_w = config['l1_w']
|
| 308 |
+
rank_w = config['rank_w']
|
| 309 |
+
use_swa = config.get('use_swa', False)
|
| 310 |
+
logging.info(f'Parameters - Number of repeats for 80-20 hold out test: {n_repeats}, Number of splits for kfold: {n_splits}, Batch size: {batch_size}, Number of epochs: {epochs}')
|
| 311 |
+
logging.info(f'Network Parameters - hidden_features: {hidden_features}, drop_rate: {drop_rate}, patience: {patience}')
|
| 312 |
+
logging.info(f'Optimizer Parameters - loss_type: {loss_type}, optimizer_type: {optimizer_type}, initial_lr: {initial_lr}, weight_decay: {weight_decay}, use_swa: {use_swa}')
|
| 313 |
+
logging.info(f'MAEAndRankLoss - l1_w: {l1_w}, rank_w: {rank_w}')
|
| 314 |
+
|
| 315 |
+
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
|
| 316 |
+
best_model = None
|
| 317 |
+
best_metric = float('inf') if select_criteria == 'byrmse' else float('-inf')
|
| 318 |
+
|
| 319 |
+
# loss for every fold
|
| 320 |
+
all_train_losses = []
|
| 321 |
+
all_val_losses = []
|
| 322 |
+
for fold, (train_idx, val_idx) in enumerate(kf.split(X_train)):
|
| 323 |
+
print(f"Fold {fold + 1}/{n_splits}")
|
| 324 |
+
|
| 325 |
+
X_train_fold, X_val_fold = X_train[train_idx], X_train[val_idx]
|
| 326 |
+
y_train_fold, y_val_fold = y_train[train_idx], y_train[val_idx]
|
| 327 |
+
|
| 328 |
+
# initialisation of model, loss function, optimiser
|
| 329 |
+
model = Mlp(input_features=X_train_fold.shape[1], hidden_features=hidden_features, drop_rate=drop_rate)
|
| 330 |
+
model = model.to(device) # to gpu
|
| 331 |
+
|
| 332 |
+
if loss_type == 'MAERankLoss':
|
| 333 |
+
criterion = MAEAndRankLoss()
|
| 334 |
+
criterion.l1_w = l1_w
|
| 335 |
+
criterion.rank_w = rank_w
|
| 336 |
+
else:
|
| 337 |
+
nn.MSELoss()
|
| 338 |
+
|
| 339 |
+
if optimizer_type == 'sgd':
|
| 340 |
+
optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=weight_decay)
|
| 341 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)# initial eta_nim=1e-5
|
| 342 |
+
else:
|
| 343 |
+
optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay) # L2 Regularisation initial: 0.01, 1e-5
|
| 344 |
+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95) # step_size=10, gamma=0.1: every 10 epochs lr*0.1
|
| 345 |
+
if use_swa:
|
| 346 |
+
swa_model = AveragedModel(model).to(device)
|
| 347 |
+
swa_scheduler = SWALR(optimizer, swa_lr=initial_lr, anneal_strategy='cos')
|
| 348 |
+
|
| 349 |
+
# dataset loader
|
| 350 |
+
train_dataset = TensorDataset(torch.FloatTensor(X_train_fold), torch.FloatTensor(y_train_fold))
|
| 351 |
+
val_dataset = TensorDataset(torch.FloatTensor(X_val_fold), torch.FloatTensor(y_val_fold))
|
| 352 |
+
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
|
| 353 |
+
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
|
| 354 |
+
|
| 355 |
+
train_losses, val_losses = [], []
|
| 356 |
+
|
| 357 |
+
# early stopping parameters
|
| 358 |
+
best_val_loss = float('inf')
|
| 359 |
+
epochs_no_improve = 0
|
| 360 |
+
early_stop_active = False
|
| 361 |
+
swa_start = int(epochs * 0.7) if use_swa else epochs # SWA starts after 70% of total epochs, only set SWA start if SWA is used
|
| 362 |
+
|
| 363 |
+
for epoch in range(epochs):
|
| 364 |
+
train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
|
| 365 |
+
train_losses.append(train_loss)
|
| 366 |
+
scheduler.step() # update learning rate
|
| 367 |
+
if use_swa and epoch >= swa_start:
|
| 368 |
+
swa_model.update_parameters(model)
|
| 369 |
+
swa_scheduler.step()
|
| 370 |
+
early_stop_active = True
|
| 371 |
+
print(f"Current learning rate with SWA: {swa_scheduler.get_last_lr()}")
|
| 372 |
+
|
| 373 |
+
lr = optimizer.param_groups[0]['lr']
|
| 374 |
+
print('Epoch %d: Learning rate: %f' % (epoch + 1, lr))
|
| 375 |
+
|
| 376 |
+
# decide which model to evaluate: SWA model or regular model
|
| 377 |
+
current_model = swa_model if use_swa and epoch >= swa_start else model
|
| 378 |
+
current_model.eval()
|
| 379 |
+
val_loss, y_val_pred = evaluate(current_model, val_loader, criterion, device)
|
| 380 |
+
val_losses.append(val_loss)
|
| 381 |
+
print(f"Epoch {epoch + 1}, Fold {fold + 1}, Training Loss: {train_loss}, Validation Loss: {val_loss}")
|
| 382 |
+
|
| 383 |
+
y_val_pred = np.array(list(y_val_pred), dtype=float)
|
| 384 |
+
_, _, rmse_val, _, krcc_val = compute_correlation_metrics(y_val_fold, y_val_pred)
|
| 385 |
+
current_metric = rmse_val if select_criteria == 'byrmse' else krcc_val
|
| 386 |
+
best_metric, best_model, is_better = update_best_model(select_criteria, best_metric, current_metric, current_model)
|
| 387 |
+
if is_better:
|
| 388 |
+
logging.info(f"Epoch {epoch + 1}, Fold {fold + 1}:")
|
| 389 |
+
y_val_pred_logistic_tmp, plcc_valid_tmp, rmse_valid_tmp, srcc_valid_tmp, krcc_valid_tmp = compute_correlation_metrics(y_val_fold, y_val_pred)
|
| 390 |
+
logging.info(f'Validation set - Evaluation Results - SRCC: {srcc_valid_tmp}, KRCC: {krcc_valid_tmp}, PLCC: {plcc_valid_tmp}, RMSE: {rmse_valid_tmp}')
|
| 391 |
+
|
| 392 |
+
X_train_fold_tensor = torch.FloatTensor(X_train_fold).to(device)
|
| 393 |
+
y_tra_pred_tmp = best_model(X_train_fold_tensor).detach().cpu().numpy().squeeze()
|
| 394 |
+
y_tra_pred_tmp = np.array(list(y_tra_pred_tmp), dtype=float)
|
| 395 |
+
y_tra_pred_logistic_tmp, plcc_train_tmp, rmse_train_tmp, srcc_train_tmp, krcc_train_tmp = compute_correlation_metrics(y_train_fold, y_tra_pred_tmp)
|
| 396 |
+
logging.info(f'Train set - Evaluation Results - SRCC: {srcc_train_tmp}, KRCC: {krcc_train_tmp}, PLCC: {plcc_train_tmp}, RMSE: {rmse_train_tmp}')
|
| 397 |
+
|
| 398 |
+
# check for loss improvement
|
| 399 |
+
if early_stop_active:
|
| 400 |
+
if val_loss < best_val_loss:
|
| 401 |
+
best_val_loss = val_loss
|
| 402 |
+
# save the best model if validation loss improves
|
| 403 |
+
best_model = copy.deepcopy(model)
|
| 404 |
+
epochs_no_improve = 0
|
| 405 |
+
else:
|
| 406 |
+
epochs_no_improve += 1
|
| 407 |
+
if epochs_no_improve >= patience:
|
| 408 |
+
# epochs to wait for improvement before stopping
|
| 409 |
+
print(f"Early stopping triggered after {epoch + 1} epochs.")
|
| 410 |
+
break
|
| 411 |
+
|
| 412 |
+
# saving SWA models and updating BN statistics
|
| 413 |
+
if use_swa:
|
| 414 |
+
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_to_device(x, device))
|
| 415 |
+
best_model = best_model.to(device)
|
| 416 |
+
best_model.eval()
|
| 417 |
+
torch.optim.swa_utils.update_bn(train_loader, best_model)
|
| 418 |
+
# swa_model_path = os.path.join('save_swa_path='../model/', f'model_swa_fold{fold}.pth')
|
| 419 |
+
# torch.save(swa_model.state_dict(), swa_model_path)
|
| 420 |
+
# logging.info(f'SWA model saved at {swa_model_path}')
|
| 421 |
+
|
| 422 |
+
all_train_losses.append(train_losses)
|
| 423 |
+
all_val_losses.append(val_losses)
|
| 424 |
+
max_length = max(len(x) for x in all_train_losses)
|
| 425 |
+
all_train_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_train_losses]
|
| 426 |
+
max_length = max(len(x) for x in all_val_losses)
|
| 427 |
+
all_val_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_val_losses]
|
| 428 |
+
|
| 429 |
+
return best_model, all_train_losses, all_val_losses
|
| 430 |
+
|
| 431 |
+
def collate_to_device(batch, device):
|
| 432 |
+
data, targets = zip(*batch)
|
| 433 |
+
return torch.stack(data).to(device), torch.stack(targets).to(device)
|
| 434 |
+
|
| 435 |
+
def model_test(best_model, X, y, device):
|
| 436 |
+
test_dataset = TensorDataset(torch.FloatTensor(X), torch.FloatTensor(y))
|
| 437 |
+
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
|
| 438 |
+
|
| 439 |
+
best_model.eval()
|
| 440 |
+
y_pred = []
|
| 441 |
+
with torch.no_grad():
|
| 442 |
+
for inputs, _ in test_loader:
|
| 443 |
+
inputs = inputs.to(device)
|
| 444 |
+
|
| 445 |
+
outputs = best_model(inputs)
|
| 446 |
+
y_pred.extend(outputs.view(-1).tolist())
|
| 447 |
+
|
| 448 |
+
return y_pred
|
| 449 |
+
|
| 450 |
+
def main(config):
|
| 451 |
+
model_name = config['model_name']
|
| 452 |
+
data_name = config['data_name']
|
| 453 |
+
network_name = config['network_name']
|
| 454 |
+
|
| 455 |
+
metadata_path = config['metadata_path']
|
| 456 |
+
feature_path = config['feature_path']
|
| 457 |
+
log_path = config['log_path']
|
| 458 |
+
save_path = config['save_path']
|
| 459 |
+
score_path = config['score_path']
|
| 460 |
+
result_path = config['result_path']
|
| 461 |
+
|
| 462 |
+
# parameters
|
| 463 |
+
select_criteria = config['select_criteria']
|
| 464 |
+
n_repeats = config['n_repeats']
|
| 465 |
+
|
| 466 |
+
# logging and result
|
| 467 |
+
os.makedirs(log_path, exist_ok=True)
|
| 468 |
+
os.makedirs(save_path, exist_ok=True)
|
| 469 |
+
os.makedirs(score_path, exist_ok=True)
|
| 470 |
+
os.makedirs(result_path, exist_ok=True)
|
| 471 |
+
result_file = f'{result_path}{data_name}_{network_name}_{select_criteria}.mat'
|
| 472 |
+
pred_score_filename = os.path.join(score_path, f"{data_name}_{network_name}_{select_criteria}.csv")
|
| 473 |
+
file_path = os.path.join(save_path, f"{data_name}_{network_name}_{select_criteria}_trained_median_model_param.pth")
|
| 474 |
+
configure_logging(log_path, model_name, data_name, network_name, select_criteria)
|
| 475 |
+
|
| 476 |
+
'''======================== Main Body ==========================='''
|
| 477 |
+
PLCC_all_repeats_test = []
|
| 478 |
+
SRCC_all_repeats_test = []
|
| 479 |
+
KRCC_all_repeats_test = []
|
| 480 |
+
RMSE_all_repeats_test = []
|
| 481 |
+
PLCC_all_repeats_train = []
|
| 482 |
+
SRCC_all_repeats_train = []
|
| 483 |
+
KRCC_all_repeats_train = []
|
| 484 |
+
RMSE_all_repeats_train = []
|
| 485 |
+
all_repeats_test_vids = []
|
| 486 |
+
all_repeats_df_test_pred = []
|
| 487 |
+
best_model_list = []
|
| 488 |
+
|
| 489 |
+
for i in range(1, n_repeats + 1):
|
| 490 |
+
print(f"{i}th repeated 80-20 hold out test")
|
| 491 |
+
logging.info(f"{i}th repeated 80-20 hold out test")
|
| 492 |
+
t0 = time.time()
|
| 493 |
+
|
| 494 |
+
# train test split
|
| 495 |
+
test_size = 0.2
|
| 496 |
+
random_state = math.ceil(8.8 * i)
|
| 497 |
+
# NR: original
|
| 498 |
+
if data_name == 'lsvq_train':
|
| 499 |
+
test_data_name = 'lsvq_test' #lsvq_test, lsvq_test_1080p
|
| 500 |
+
train_features, test_features, test_vids = split_train_test.process_lsvq(data_name, test_data_name, metadata_path, feature_path, network_name)
|
| 501 |
+
elif data_name == 'cross_dataset':
|
| 502 |
+
train_data_name = 'youtube_ugc_all'
|
| 503 |
+
test_data_name = 'cvd_2014_all'
|
| 504 |
+
_, _, test_vids = split_train_test.process_cross_dataset(train_data_name, test_data_name, metadata_path, feature_path, network_name)
|
| 505 |
+
else:
|
| 506 |
+
_, _, test_vids = split_train_test.process_other(data_name, test_size, random_state, metadata_path, feature_path, network_name)
|
| 507 |
+
|
| 508 |
+
'''======================== read files =============================== '''
|
| 509 |
+
if data_name == 'lsvq_train':
|
| 510 |
+
X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features)
|
| 511 |
+
else:
|
| 512 |
+
X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, None, None)
|
| 513 |
+
|
| 514 |
+
'''======================== regression model =============================== '''
|
| 515 |
+
best_model, all_train_losses, all_val_losses = train_and_evaluate(X_train, y_train, config)
|
| 516 |
+
|
| 517 |
+
# average loss plots
|
| 518 |
+
avg_train_losses = np.mean(all_train_losses, axis=0)
|
| 519 |
+
avg_val_losses = np.mean(all_val_losses, axis=0)
|
| 520 |
+
test_vids = test_vids.tolist()
|
| 521 |
+
plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, len(test_vids), i)
|
| 522 |
+
|
| 523 |
+
# predict best model on the train dataset
|
| 524 |
+
y_train_pred = model_test(best_model, X_train, y_train, device)
|
| 525 |
+
y_train_pred = np.array(list(y_train_pred), dtype=float)
|
| 526 |
+
y_train_pred_logistic, plcc_train, rmse_train, srcc_train, krcc_train = compute_correlation_metrics(y_train, y_train_pred)
|
| 527 |
+
|
| 528 |
+
# test best model on the test dataset
|
| 529 |
+
y_test_pred = model_test(best_model, X_test, y_test, device)
|
| 530 |
+
y_test_pred = np.array(list(y_test_pred), dtype=float)
|
| 531 |
+
y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = compute_correlation_metrics(y_test, y_test_pred)
|
| 532 |
+
|
| 533 |
+
# save the predict score results
|
| 534 |
+
test_pred_score = {'MOS': y_test, 'y_test_pred': y_test_pred, 'y_test_pred_logistic': y_test_pred_logistic}
|
| 535 |
+
df_test_pred = pd.DataFrame(test_pred_score)
|
| 536 |
+
|
| 537 |
+
# logging logistic predicted scores
|
| 538 |
+
logging.info("============================================================================================================")
|
| 539 |
+
SRCC_all_repeats_test.append(srcc_test)
|
| 540 |
+
KRCC_all_repeats_test.append(krcc_test)
|
| 541 |
+
PLCC_all_repeats_test.append(plcc_test)
|
| 542 |
+
RMSE_all_repeats_test.append(rmse_test)
|
| 543 |
+
SRCC_all_repeats_train.append(srcc_train)
|
| 544 |
+
KRCC_all_repeats_train.append(krcc_train)
|
| 545 |
+
PLCC_all_repeats_train.append(plcc_train)
|
| 546 |
+
RMSE_all_repeats_train.append(rmse_train)
|
| 547 |
+
all_repeats_test_vids.append(test_vids)
|
| 548 |
+
all_repeats_df_test_pred.append(df_test_pred)
|
| 549 |
+
best_model_list.append(copy.deepcopy(best_model))
|
| 550 |
+
|
| 551 |
+
# logging.info results for each iteration
|
| 552 |
+
logging.info('Best results in Mlp model within one split')
|
| 553 |
+
logging.info(f'MODEL: {best_model}')
|
| 554 |
+
logging.info('======================================================')
|
| 555 |
+
logging.info(f'Train set - Evaluation Results')
|
| 556 |
+
logging.info(f'SRCC_train: {srcc_train}')
|
| 557 |
+
logging.info(f'KRCC_train: {krcc_train}')
|
| 558 |
+
logging.info(f'PLCC_train: {plcc_train}')
|
| 559 |
+
logging.info(f'RMSE_train: {rmse_train}')
|
| 560 |
+
logging.info('======================================================')
|
| 561 |
+
logging.info(f'Test set - Evaluation Results')
|
| 562 |
+
logging.info(f'SRCC_test: {srcc_test}')
|
| 563 |
+
logging.info(f'KRCC_test: {krcc_test}')
|
| 564 |
+
logging.info(f'PLCC_test: {plcc_test}')
|
| 565 |
+
logging.info(f'RMSE_test: {rmse_test}')
|
| 566 |
+
logging.info('======================================================')
|
| 567 |
+
logging.info(' -- {} seconds elapsed...\n\n'.format(time.time() - t0))
|
| 568 |
+
|
| 569 |
+
logging.info('')
|
| 570 |
+
SRCC_all_repeats_test = np.nan_to_num(SRCC_all_repeats_test)
|
| 571 |
+
KRCC_all_repeats_test = np.nan_to_num(KRCC_all_repeats_test)
|
| 572 |
+
PLCC_all_repeats_test = np.nan_to_num(PLCC_all_repeats_test)
|
| 573 |
+
RMSE_all_repeats_test = np.nan_to_num(RMSE_all_repeats_test)
|
| 574 |
+
SRCC_all_repeats_train = np.nan_to_num(SRCC_all_repeats_train)
|
| 575 |
+
KRCC_all_repeats_train = np.nan_to_num(KRCC_all_repeats_train)
|
| 576 |
+
PLCC_all_repeats_train = np.nan_to_num(PLCC_all_repeats_train)
|
| 577 |
+
RMSE_all_repeats_train = np.nan_to_num(RMSE_all_repeats_train)
|
| 578 |
+
logging.info('======================================================')
|
| 579 |
+
logging.info('Average training results among all repeated 80-20 holdouts:')
|
| 580 |
+
logging.info('SRCC: %f (std: %f)', np.median(SRCC_all_repeats_train), np.std(SRCC_all_repeats_train))
|
| 581 |
+
logging.info('KRCC: %f (std: %f)', np.median(KRCC_all_repeats_train), np.std(KRCC_all_repeats_train))
|
| 582 |
+
logging.info('PLCC: %f (std: %f)', np.median(PLCC_all_repeats_train), np.std(PLCC_all_repeats_train))
|
| 583 |
+
logging.info('RMSE: %f (std: %f)', np.median(RMSE_all_repeats_train), np.std(RMSE_all_repeats_train))
|
| 584 |
+
logging.info('======================================================')
|
| 585 |
+
logging.info('Average testing results among all repeated 80-20 holdouts:')
|
| 586 |
+
logging.info('SRCC: %f (std: %f)', np.median(SRCC_all_repeats_test), np.std(SRCC_all_repeats_test))
|
| 587 |
+
logging.info('KRCC: %f (std: %f)', np.median(KRCC_all_repeats_test), np.std(KRCC_all_repeats_test))
|
| 588 |
+
logging.info('PLCC: %f (std: %f)', np.median(PLCC_all_repeats_test), np.std(PLCC_all_repeats_test))
|
| 589 |
+
logging.info('RMSE: %f (std: %f)', np.median(RMSE_all_repeats_test), np.std(RMSE_all_repeats_test))
|
| 590 |
+
logging.info('======================================================')
|
| 591 |
+
logging.info('\n')
|
| 592 |
+
|
| 593 |
+
# find the median model and the index of the median
|
| 594 |
+
print('======================================================')
|
| 595 |
+
if select_criteria == 'byrmse':
|
| 596 |
+
median_metrics = np.median(RMSE_all_repeats_test)
|
| 597 |
+
indices = np.where(RMSE_all_repeats_test == median_metrics)[0]
|
| 598 |
+
select_criteria = select_criteria.replace('by', '').upper()
|
| 599 |
+
print(RMSE_all_repeats_test)
|
| 600 |
+
logging.info(f'all {select_criteria}: {RMSE_all_repeats_test}')
|
| 601 |
+
elif select_criteria == 'bykrcc':
|
| 602 |
+
median_metrics = np.median(KRCC_all_repeats_test)
|
| 603 |
+
indices = np.where(KRCC_all_repeats_test == median_metrics)[0]
|
| 604 |
+
select_criteria = select_criteria.replace('by', '').upper()
|
| 605 |
+
print(KRCC_all_repeats_test)
|
| 606 |
+
logging.info(f'all {select_criteria}: {KRCC_all_repeats_test}')

    median_test_vids = [all_repeats_test_vids[i] for i in indices]
    test_vids = [arr.tolist() for arr in median_test_vids] if len(median_test_vids) > 1 else (median_test_vids[0] if median_test_vids else [])

    # select the model with the first index where the median is located
    # Note: If there are multiple iterations with the same median RMSE, the first index is selected here
    median_model = None
    if len(indices) > 0:
        median_index = indices[0]  # select the first index
        median_model = best_model_list[median_index]
        median_model_df_test_pred = all_repeats_df_test_pred[median_index]

        median_model_df_test_pred.to_csv(pred_score_filename, index=False)
        plot_results(y_test, y_test_pred_logistic, median_model_df_test_pred, model_name, data_name, network_name, select_criteria)

    print(f'Median Metrics: {median_metrics}')
    print(f'Indices: {indices}')
    # print(f'Test Videos: {test_vids}')
    print(f'Best model: {median_model}')

    logging.info(f'median test {select_criteria}: {median_metrics}')
    logging.info(f'Indices of median metrics: {indices}')
    # logging.info(f'Best training and test dataset: {test_vids}')
    logging.info(f'Best model predict score: {median_model_df_test_pred}')
    logging.info(f'Best model: {median_model}')

    # ================================================================================
    # save mats
    scipy.io.savemat(result_file, mdict={'SRCC_train': np.asarray(SRCC_all_repeats_train, dtype=float),
                                         'KRCC_train': np.asarray(KRCC_all_repeats_train, dtype=float),
                                         'PLCC_train': np.asarray(PLCC_all_repeats_train, dtype=float),
                                         'RMSE_train': np.asarray(RMSE_all_repeats_train, dtype=float),
                                         'SRCC_test': np.asarray(SRCC_all_repeats_test, dtype=float),
                                         'KRCC_test': np.asarray(KRCC_all_repeats_test, dtype=float),
                                         'PLCC_test': np.asarray(PLCC_all_repeats_test, dtype=float),
                                         'RMSE_test': np.asarray(RMSE_all_repeats_test, dtype=float),
                                         f'Median_{select_criteria}': median_metrics,
                                         'Test_Videos_list': all_repeats_test_vids,
                                         'Test_videos_Median_model': test_vids})

    # save model
    torch.save(median_model.state_dict(), file_path)
    print(f"Model state_dict saved to {file_path}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # input parameters
    parser.add_argument('--model_name', type=str, default='Mlp')
    parser.add_argument('--data_name', type=str, default='lsvq_train', help='konvid_1k, youtube_ugc, live_vqc, cvd_2014, lsvq_train, cross_dataset')
    parser.add_argument('--network_name', type=str, default='relaxvqa', help='relaxvqa, {frag_name}_{network_name}_{layer_name}')

    parser.add_argument('--metadata_path', type=str, default='../metadata/')
    parser.add_argument('--feature_path', type=str, default='../features/')
    parser.add_argument('--log_path', type=str, default='../log/')
    parser.add_argument('--save_path', type=str, default='../model/')
    parser.add_argument('--score_path', type=str, default='../log/predict_score/')
    parser.add_argument('--result_path', type=str, default='../log/result/')
    # training parameters
    parser.add_argument('--select_criteria', type=str, default='byrmse', help='byrmse, bykrcc')
    parser.add_argument('--n_repeats', type=int, default=21, help='Number of repeats for 80-20 hold-out test')
    parser.add_argument('--n_splits', type=int, default=10, help='Number of splits for k-fold validation')
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=20, help='Epochs for training')  # 120 (small datasets), 20 (large datasets)
    parser.add_argument('--hidden_features', type=int, default=256, help='Hidden features')
    parser.add_argument('--drop_rate', type=float, default=0.1, help='Dropout rate')
    # misc
    parser.add_argument('--loss_type', type=str, default='MAERankLoss', help='MSELoss or MAERankLoss')
    parser.add_argument('--optimizer_type', type=str, default='sgd', help='adam or sgd')
    parser.add_argument('--initial_lr', type=float, default=1e-1, help='Initial learning rate, e.g. 1e-2')
    parser.add_argument('--weight_decay', type=float, default=0.005, help='Weight decay (L2 penalty), e.g. 1e-4')
    parser.add_argument('--patience', type=int, default=5, help='Early stopping patience')
    # note: argparse's type=bool treats any non-empty string (including 'False') as True,
    # so parse the flag explicitly
    parser.add_argument('--use_swa', type=lambda v: str(v).lower() in ('true', '1', 'yes'), default=True, help='Use Stochastic Weight Averaging')
    parser.add_argument('--l1_w', type=float, default=0.6, help='MAE loss weight')
    parser.add_argument('--rank_w', type=float, default=1.0, help='Rank loss weight')

    args = parser.parse_args()
    config = vars(args)  # convert the args namespace to a dict
    print(config)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    if device.type == "cuda":
        torch.cuda.set_device(0)

    main(config)
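
# Example invocation (hypothetical; adjust --metadata_path and friends to your layout):
#   python model_regression.py --data_name konvid_1k --network_name relaxvqa --select_criteria byrmse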
relax_vqa.py
ADDED
@@ -0,0 +1,159 @@
import torch
import os
import cv2
import numpy as np
from extractor import visualise_resnet, visualise_resnet_layer, visualise_vit_layer


def get_deep_feature(network_name, video_name, frame, frame_number, model, device, layer_name):
    if network_name == 'resnet50':
        if layer_name == 'layerstack':
            all_layers = ['resnet50.conv1',
                          'resnet50.layer1[0]', 'resnet50.layer1[1]', 'resnet50.layer1[2]',
                          'resnet50.layer2[0]', 'resnet50.layer2[1]', 'resnet50.layer2[2]', 'resnet50.layer2[3]',
                          'resnet50.layer3[0]', 'resnet50.layer3[1]', 'resnet50.layer3[2]', 'resnet50.layer3[3]',
                          'resnet50.layer4[0]', 'resnet50.layer4[1]', 'resnet50.layer4[2]']
            resnet50 = model
            activations_dict, _, total_flops, total_params = visualise_resnet.process_video_frame(video_name, frame, frame_number, all_layers, resnet50, device)

        elif layer_name == 'pool':
            visual_layer = 'resnet50.avgpool'  # hook on the global average-pooling layer
            resnet50 = model
            activations_dict, _, total_flops, total_params = visualise_resnet_layer.process_video_frame(video_name, frame, frame_number, visual_layer, resnet50, device)

    elif network_name == 'vit':
        patch_size = 16
        activations_dict, _, total_flops, total_params = visualise_vit_layer.process_video_frame(video_name, frame, frame_number, model, patch_size, device)

    else:
        # guard against a silent UnboundLocalError on the return below
        raise ValueError(f'Unsupported network_name: {network_name}')

    return activations_dict, total_flops, total_params
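
# Usage sketch (hypothetical inputs; `frame` is a preprocessed frame tensor and the
# extractor modules above supply process_video_frame):
#   acts, flops, params = get_deep_feature('resnet50', 'video_0001', frame, 0,
#                                          resnet50_model, device, layer_name='pool')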


def process_video_feature(video_feature, network_name, layer_name):
    # initialize an empty list to store processed frames
    averaged_frames = []
    # iterate through each frame in the video_feature
    for frame in video_feature:
        frame_features = []

        if network_name == 'vit':
            # global mean, max and std pooling over the token dimension
            global_mean = torch.mean(frame, dim=0)
            global_max = torch.max(frame, dim=0)[0]
            global_std = torch.std(frame, dim=0)
            # concatenate all pooled statistics
            combined_features = torch.hstack([global_mean, global_max, global_std])
            frame_features.append(combined_features)

        elif network_name == 'resnet50':
            if layer_name == 'layerstack':
                # iterate through each layer in the current frame
                for layer_array in frame.values():
                    # calculate the mean along the spatial axes (1 and 2) for each layer
                    layer_mean = torch.mean(layer_array, dim=(1, 2))
                    # append the calculated mean to the list for the current frame
                    frame_features.append(layer_mean)
            elif layer_name == 'pool':
                frame = torch.squeeze(torch.tensor(frame))
                # global mean, max and std pooling
                global_mean = torch.mean(frame, dim=0)
                global_max = torch.max(frame, dim=0)[0]
                global_std = torch.std(frame, dim=0)
                # concatenate the raw features with all pooled statistics
                combined_features = torch.hstack([frame, global_mean, global_max, global_std])
                frame_features.append(combined_features)

        # concatenate the per-layer features horizontally to form the processed frame
        processed_frame = torch.hstack(frame_features)
        averaged_frames.append(processed_frame)

    averaged_frames = torch.stack(averaged_frames)
    return averaged_frames
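
# Shape sketch (hypothetical sizes): for ViT-Base token features of shape
# (num_tokens, 768) per frame, the mean/max/std pooling above yields a
# 3 * 768 = 2304-dim vector per frame, so the stacked output is (num_frames, 2304).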


def flow_to_rgb(flow):
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    # scale magnitude to [0, 255]
    mag = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
    # convert angle to hue (OpenCV's hue range is [0, 180])
    hue = ang * 180 / np.pi / 2

    # create HSV image: hue from flow direction, full saturation, value from magnitude
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    hsv[..., 0] = hue
    hsv[..., 1] = 255
    hsv[..., 2] = mag
    # convert HSV to a 3-channel colour image (BGR, OpenCV's native channel order)
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    return rgb
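
# Usage sketch (hypothetical frames): visualise dense optical flow, e.g. from Farneback:
#   flow = cv2.calcOpticalFlowFarneback(prev_gray, curr_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
#   flow_img = flow_to_rgb(flow)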


def get_patch_diff(residual_frame, patch_size):
    h, w = residual_frame.shape[2:]  # assuming (1, C, H, W) shape
    # crop to a multiple of the patch size
    h_adj = (h // patch_size) * patch_size
    w_adj = (w // patch_size) * patch_size
    residual_frame_adj = residual_frame[:, :, :h_adj, :w_adj]
    # calculate absolute patch difference
    diff = torch.zeros((h_adj // patch_size, w_adj // patch_size), device=residual_frame.device)
    for i in range(0, h_adj, patch_size):
        for j in range(0, w_adj, patch_size):
            patch = residual_frame_adj[:, :, i:i + patch_size, j:j + patch_size]
            # absolute sum as the patch activity score
            diff[i // patch_size, j // patch_size] = torch.sum(torch.abs(patch))
    return diff
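
# Toy example (hypothetical values): for a (1, 3, 64, 64) residual and patch_size=16,
# `get_patch_diff` returns a (4, 4) grid where each entry is the L1 energy of one patch;
# `extract_important_patches` below then keeps the `top_n` most active patches.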


def extract_important_patches(residual_frame, diff, patch_size=16, target_size=224, top_n=196):
    # find the indices of the top-n patches, sorted by descending activity
    patch_idx = torch.argsort(-diff.view(-1))
    top_patches = [(idx // diff.shape[1], idx % diff.shape[1]) for idx in patch_idx[:top_n]]
    sorted_idx = sorted(top_patches, key=lambda x: (x[0], x[1]))

    imp_patches_img = torch.zeros((residual_frame.shape[1], target_size, target_size), dtype=residual_frame.dtype, device=residual_frame.device)
    patches_per_row = target_size // patch_size  # 14 for the defaults
    # re-assemble the patches, preserving their original spatial ordering
    positions = []
    for idx, (y, x) in enumerate(sorted_idx):
        patch = residual_frame[:, :, y * patch_size:(y + 1) * patch_size, x * patch_size:(x + 1) * patch_size]
        # new patch location in the fragment image
        row_idx = idx // patches_per_row
        col_idx = idx % patches_per_row
        start_y = row_idx * patch_size
        start_x = col_idx * patch_size
        imp_patches_img[:, start_y:start_y + patch_size, start_x:start_x + patch_size] = patch
        positions.append((y.item(), x.item()))
    return imp_patches_img, positions
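
# Note: the defaults are self-consistent -- top_n=196 patches of 16x16 px tile a
# 224x224 fragment exactly (14 patches per row); `positions` records each kept patch's
# original grid location so `get_frame_patches` below can sample the same layout.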


def get_frame_patches(frame, positions, patch_size, target_size):
    imp_patches_img = torch.zeros((frame.shape[1], target_size, target_size), dtype=frame.dtype, device=frame.device)
    patches_per_row = target_size // patch_size

    for idx, (y, x) in enumerate(positions):
        start_y = y * patch_size
        start_x = x * patch_size
        end_y = start_y + patch_size
        end_x = start_x + patch_size

        patch = frame[:, :, start_y:end_y, start_x:end_x]
        row_idx = idx // patches_per_row
        col_idx = idx % patches_per_row
        target_start_y = row_idx * patch_size
        target_start_x = col_idx * patch_size

        imp_patches_img[:, target_start_y:target_start_y + patch_size,
                        target_start_x:target_start_x + patch_size] = patch.squeeze(0)
    return imp_patches_img


def process_patches(original_path, frag_name, residual, patch_size, target_size, top_n):
    diff = get_patch_diff(residual, patch_size)
    imp_patches, positions = extract_important_patches(residual, diff, patch_size, target_size, top_n)
    if frag_name == 'frame_diff':
        frag_path = original_path.replace('.png', '_residual_imp.png')
    elif frag_name == 'optical_flow':
        frag_path = original_path.replace('.png', '_residual_of_imp.png')
    else:
        # guard against returning an undefined frag_path
        raise ValueError(f'Unsupported frag_name: {frag_name}')
    # cv2.imwrite(frag_path, imp_patches)
    return frag_path, imp_patches, positions
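
# Pipeline sketch (hypothetical inputs): given a residual frame of shape (1, 3, H, W),
#   frag_path, frag_img, pos = process_patches('frame_0.png', 'frame_diff', residual, 16, 224, 196)
#   frame_frag = get_frame_patches(frame, pos, 16, 224)  # same patch layout from the raw frame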


def merge_fragments(diff_fragment, flow_fragment):
    # equal-weight blend of the frame-difference and optical-flow fragments
    alpha = 0.5
    merged_fragment = diff_fragment * alpha + flow_fragment * (1 - alpha)
    return merged_fragment


def concatenate_features(frame_feature, residual_feature):
    return torch.cat((frame_feature, residual_feature), dim=-1)
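
# Usage sketch (hypothetical tensors): fuse the two fragment images, then append the
# residual features to the sampled-frame features along the last axis:
#   merged = merge_fragments(torch.rand(3, 224, 224), torch.rand(3, 224, 224))
#   fused = concatenate_features(torch.rand(1, 2048), torch.rand(1, 2048))  # -> (1, 4096)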
requirements.txt
ADDED
@@ -0,0 +1,15 @@
gradio
torch
torchvision
torchaudio
opencv-python
joblib
scikit-learn
scipy
numpy
pandas
matplotlib
ipywidgets
thop
PyYAML
seaborn