# Source: mmesa-gitex / app/au_processing.py
# Commit b20a621 ("12 Oct Gitex 2024", uploaded by vitorcalvi), 2.41 kB
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
from PIL import Image
from app.model import pth_model_static, cam, pth_processing
from app.face_utils import get_box
import mediapipe as mp
# Short alias for MediaPipe's face-mesh solution module; FaceMesh detectors
# are constructed from it in preprocess_frame_and_predict_aus().
mp_face_mesh = mp.solutions.face_mesh
def _overlay_cam_on_image(image, mask, image_weight=0.5):
    """Blend an activation mask over an RGB image as a JET-coloured heatmap.

    Stand-in for the ``show_cam_on_image(..., use_rgb=True)`` helper that this
    module previously called without importing or defining it.

    Args:
        image: float RGB image with values in [0, 1], shape (H, W, 3).
        mask: 2-D activation map with values expected in [0, 1].
        image_weight: contribution of ``image`` to the blend; the heatmap
            gets ``1 - image_weight``.

    Returns:
        ``uint8`` RGB image of shape (H, W, 3).
    """
    # The mask must align pixel-for-pixel with the image before blending.
    if mask.shape[:2] != image.shape[:2]:
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
    # applyColorMap needs uint8 input; clip so stray values cannot wrap around.
    mask_u8 = np.uint8(255 * np.clip(mask, 0.0, 1.0))
    heatmap = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; the frame handled here is RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    blended = (1 - image_weight) * heatmap + image_weight * image
    peak = blended.max()
    if peak > 0:
        # Rescale so the brightest pixel maps to 255.
        blended = blended / peak
    return np.uint8(255 * blended)


def preprocess_frame_and_predict_aus(frame):
    """Detect a face in ``frame`` and compute AU intensities plus a CAM heatmap.

    Args:
        frame: image as a NumPy array; 2-D (grayscale), or 3-D with 3 or 4
            channels. Grayscale and 4-channel inputs are converted to
            3-channel RGB before detection.
            NOTE(review): 3-channel input is passed through unchanged, so it
            is presumably expected to already be RGB - confirm with callers.

    Returns:
        ``(cur_face, au_intensities, heatmap)`` for the first detected face
        with a non-empty crop, where ``cur_face`` is the cropped face region
        of ``frame``, ``au_intensities`` is the output of
        ``features_to_au_intensities`` and ``heatmap`` is a 224x224 ``uint8``
        RGB overlay of the CAM on the face. ``(None, None, None)`` when no
        usable face is found.
    """
    # Normalise the channel layout to 3-channel RGB.
    if frame.ndim == 2:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif frame.shape[2] == 4:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        results = face_mesh.process(frame)
        if not results.multi_face_landmarks:
            return None, None, None

        h, w = frame.shape[:2]
        for fl in results.multi_face_landmarks:
            startX, startY, endX, endY = get_box(fl, w, h)
            cur_face = frame[startY:endY, startX:endX]
            if cur_face.size == 0:
                # A degenerate box yields an empty crop that cannot be turned
                # into an image; treat it like "no face" for this landmark set.
                continue

            # presumably converts the PIL image to a model-ready batch tensor
            # - see app.model.pth_processing
            cur_face_n = pth_processing(Image.fromarray(cur_face))

            # Gradients are not needed for the plain forward pass.
            with torch.no_grad():
                features = pth_model_static(cur_face_n)
            au_intensities = features_to_au_intensities(features)

            # The CAM call is kept outside no_grad, as in the original code.
            grayscale_cam = cam(input_tensor=cur_face_n)
            grayscale_cam = grayscale_cam[0, :]

            cur_face_hm = cv2.resize(cur_face, (224, 224))
            cur_face_hm = np.float32(cur_face_hm) / 255
            heatmap = _overlay_cam_on_image(cur_face_hm, grayscale_cam)

            return cur_face, au_intensities, heatmap

    return None, None, None
def features_to_au_intensities(features):
    """Min-max normalise a feature vector and keep its first 24 entries.

    Args:
        features: tensor-like object supporting ``.detach().cpu().numpy()``;
            only the first element along the leading axis is used.

    Returns:
        NumPy array with at most 24 values scaled to [0, 1]. The scaling uses
        the min and max over the whole feature vector, not just the returned
        slice. When every feature value is identical the result is all zeros
        rather than NaN.
    """
    features_np = features.detach().cpu().numpy()[0]
    lowest = features_np.min()
    span = features_np.max() - lowest
    if span == 0:
        # Constant vector: min-max scaling would divide by zero and give NaNs.
        au_intensities = np.zeros_like(features_np)
    else:
        au_intensities = (features_np - lowest) / span
    return au_intensities[:24]  # Assuming we want 24 AUs
def au_statistics_plot(frames, au_intensities_list):
    """Plot each action unit's intensity against the frame index.

    Args:
        frames: x-axis values, one per entry in ``au_intensities_list``.
        au_intensities_list: sequence of equal-length per-frame intensity
            vectors. An empty sequence produces an empty, labelled plot
            instead of raising.

    Returns:
        The Matplotlib figure containing one line per action unit, labelled
        ``AU1``, ``AU2``, ... in column order.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    au_intensities_array = np.array(au_intensities_list)

    # Only a 2-D (frame, AU) array has per-AU columns to draw; an empty list
    # becomes a 1-D array, for which indexing shape[1] would fail.
    if au_intensities_array.ndim == 2:
        for au_number, series in enumerate(au_intensities_array.T, start=1):
            ax.plot(frames, series, label=f'AU{au_number}')

    ax.set_xlabel('Frame')
    ax.set_ylabel('AU Intensity')
    ax.set_title('Action Unit Intensities Over Time')
    if ax.lines:
        # A legend without any plotted line only triggers a warning.
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    return fig