import copy
import gc
import logging
import math
import os
import random
import sys
from contextlib import contextmanager

import numpy as np
import torch
import torch.cuda.amp as amp
import torchvision.transforms.functional as TF
from tqdm import tqdm

from .modules.clip import CLIPModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .modules.model import WanModel
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                               get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .utils.audio_utils import preprocess_audio, resample_audio
from .utils.infer_utils import (create_null_audio_ref_features,
                                expand_face_mask_flexible, gen_inference_masks,
                                expand_bbox_and_crop_image, count_parameters,
                                gen_smooth_transition_mask_for_dit,
                                process_audio_features)
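
# Typical usage (a minimal sketch, not part of this module's API; the exact
# config fields, checkpoint layout, and face-processor interface depend on the
# surrounding project and are assumptions here):
#
#   from PIL import Image
#
#   wan = WanAF2V(config, checkpoint_dir="./checkpoints", device_id=0,
#                 use_half=True)
#   video = wan.generate(
#       "a person speaking to the camera",
#       Image.open("ref.png").convert("RGB"),
#       audio_paths=["speech.wav"],
#       frame_num=81,
#       sampling_steps=40,
#       seed=42,
#   )  # -> (C, N, H, W) tensor on rank 0, None on other ranks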


class WanAF2V:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        t5_cpu=False,
        init_on_cpu=True,
        use_gradient_checkpointing=False,
        post_trained_checkpoint_path=None,
        dit_config=None,
        crop_image_size=224,
        use_half=False,
    ):
        r"""
        Initializes the image-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
            use_gradient_checkpointing (`bool`, *optional*, defaults to False):
                Enable gradient checkpointing on the DiT model to trade compute for memory.
            post_trained_checkpoint_path (`str`, *optional*, defaults to None):
                Path to the post-trained checkpoint file. If provided, the model is loaded from this checkpoint.
            dit_config (`dict`, *optional*, defaults to None):
                Model config used to build the DiT when loading a post-trained checkpoint.
            crop_image_size (`int`, *optional*, defaults to 224):
                Size used when cropping reference face images (stored for downstream use).
            use_half (`bool`, *optional*, defaults to False):
                Whether to use half precision (float16) for model inference. Reduces memory usage.
        """

        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{device_id}")
        else:
            self.device = torch.device("cpu")
        self.config = config
        self.rank = rank
        self.t5_cpu = t5_cpu
        self.use_half = use_half

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        if use_half:
            self.half_dtype = torch.float16
            logging.info(f"Half precision enabled, using dtype: {self.half_dtype} (forced float16 for faster inference)")
        else:
            self.half_dtype = torch.float32

        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=None,
        )

        model_param_count = count_parameters(self.text_encoder.model)
        logging.info(f"Text Model parameters: {model_param_count}M")

        self.crop_image_size = crop_image_size
        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        model_param_count = count_parameters(self.vae.model)
        logging.info(f"VAE Model parameters: {model_param_count}M")

        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir, config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        model_param_count = count_parameters(self.clip.model)
        logging.info(f"CLIP Model parameters: {model_param_count}M")

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        if post_trained_checkpoint_path:
            try:
                if rank == 0:
                    print(f"Loading post-trained model from {post_trained_checkpoint_path}")

                # Build the DiT from the provided config, then load the
                # post-trained weights (expects a dict with a 'model' state dict).
                config_dict = dit_config
                self.model = WanModel.from_config(config_dict)

                checkpoint = torch.load(post_trained_checkpoint_path, map_location='cpu', weights_only=True)
                model_state = checkpoint['model']
                self.model.load_state_dict(model_state)
                if rank == 0:
                    print(f"Checkpoint loaded: {post_trained_checkpoint_path}")
            except Exception as e:
                if rank == 0:
                    print(f"Error loading post-trained model: {e}")
                raise e
        else:
            self.model = WanModel.from_pretrained(checkpoint_dir)

        self.model.eval().requires_grad_(False)

        model_param_count = count_parameters(self.model)
        logging.info(f"DiT Model parameters: {model_param_count}M")

        if use_gradient_checkpointing:
            self.model.enable_gradient_checkpointing()
            logging.info("Gradient checkpointing enabled for WanModel")

        self.sp_size = 1

        if not init_on_cpu:
            self.model.to(self.device)

        if use_half:
            try:
                self.model = self.model.to(dtype=self.half_dtype)
                logging.info(f"Model converted to {self.half_dtype} precision")
            except Exception as e:
                logging.warning(f"Failed to convert model to half precision: {e}. Continuing with float32.")
                self.use_half = False
                self.half_dtype = torch.float32

        self.sample_neg_prompt = config.sample_neg_prompt

    def generate(
        self,
        input_prompt,
        img,
        audio=None,
        max_area=720 * 1280,
        frame_num=81,
        shift=5.0,
        sample_solver='unipc',
        sampling_steps=40,
        guide_scale=5.0,
        n_prompt="",
        seed=-1,
        offload_model=True,
        cfg_zero=False,
        zero_init_steps=0,
        face_processor=None,
        img_path=None,
        audio_paths=None,
        task_key=None,
        mode="pad",
    ):
        r"""
        Generates video frames from input image and text prompt using a diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation.
            img (PIL.Image.Image):
                Input image. Converted internally to a tensor of shape [3, H, W].
            audio (*optional*, defaults to None):
                Pre-loaded audio input forwarded to audio feature processing.
            max_area (`int`, *optional*, defaults to 720*1280):
                Maximum pixel area for latent space calculation. Controls video resolution scaling.
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1.
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics.
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 40):
                Number of diffusion sampling steps. Higher values improve quality but slow generation.
            guide_scale (`float`, *optional*, defaults to 5.0):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity.
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`.
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use a random seed.
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM.
            cfg_zero (`bool`, *optional*, defaults to False):
                Whether to use CFG-Zero guidance instead of a fixed guidance scale.
                (In the current implementation, enabling this skips the unconditional
                pass and uses the conditional prediction directly.)
            zero_init_steps (`int`, *optional*, defaults to 0):
                Number of initial steps to use zero guidance when using cfg_zero.
            face_processor (*optional*, defaults to None):
                Face detector providing per-face masks and bounding boxes for the reference image.
            img_path (`str`, *optional*, defaults to None):
                Path to the input image, used by `face_processor`.
            audio_paths (`list`, *optional*, defaults to None):
                Paths to audio files, one per speaking face.
            task_key (*optional*, defaults to None):
                Task identifier forwarded to audio feature processing.
            mode (`str`, *optional*, defaults to "pad"):
                Frame/audio alignment mode forwarded to audio feature processing.

        Returns:
            torch.Tensor:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames (frame_num)
                - H: Frame height (derived from max_area)
                - W: Frame width (derived from max_area)
        """
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

        if self.use_half:
            img = img.to(dtype=self.half_dtype)

        self.audio_paths = audio_paths

        F = frame_num
        print(f"Using frame number: {F} (mode: {mode})")
        h, w = img.shape[1:]
        print(f"Input image size: {h}, {w}, {max_area}")
        aspect_ratio = h / w
        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        lat_w = round(
            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])

        h = lat_h * self.vae_stride[1]
        w = lat_w * self.vae_stride[2]

        original_aspect_ratio = h / w

        # Candidate 1: snap the width down to a multiple of 32, then derive
        # the height from the aspect ratio.
        w_candidate1 = (w // 32) * 32
        if w_candidate1 == 0:
            w_candidate1 = 32
        h_candidate1 = (int(w_candidate1 * original_aspect_ratio) // 32) * 32
        if h_candidate1 == 0:
            h_candidate1 = 32

        # Candidate 2: snap the height down to a multiple of 32, then derive
        # the width from the aspect ratio.
        h_candidate2 = (h // 32) * 32
        if h_candidate2 == 0:
            h_candidate2 = 32
        w_candidate2 = (int(h_candidate2 / original_aspect_ratio) // 32) * 32
        if w_candidate2 == 0:
            w_candidate2 = 32

        # Pick the candidate closest to the original (h, w).
        diff1 = abs(w_candidate1 - w) + abs(h_candidate1 - h)
        diff2 = abs(w_candidate2 - w) + abs(h_candidate2 - h)

        if diff1 <= diff2:
            w, h = w_candidate1, h_candidate1
        else:
            w, h = w_candidate2, h_candidate2

        lat_h = h // self.vae_stride[1]
        lat_w = w // self.vae_stride[2]

        print(f"Processed image size: {h}, {w}, {lat_h}, {lat_w}")

        latent_frame_num = (F - 1) // self.vae_stride[0] + 1
        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)

        noise_dtype = self.half_dtype if self.use_half else torch.float32
        noise = torch.randn(16, latent_frame_num, lat_h, lat_w, dtype=noise_dtype, generator=seed_g, device=self.device)
        print(f"noise shape: {noise.shape}, dtype: {noise.dtype}")

        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
            self.patch_size[1] * self.patch_size[2])
        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
        print(f"Max seq_len: {max_seq_len}")

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        # Encode the positive and negative prompts with T5, either on GPU
        # (optionally offloading afterwards) or on CPU.
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        self.clip.model.to(self.device)
        with torch.no_grad():
            img_input = img[:, None, :, :]
            if self.use_half:
                img_input = img_input.to(dtype=self.half_dtype)
            clip_context = self.clip.visual([img_input])

            if self.use_half:
                clip_context = clip_context.to(dtype=self.half_dtype)

        # --- i2v mask and reference latent construction ---

        mask_dtype = self.half_dtype if self.use_half else torch.float32
        msk = torch.ones(1, F, lat_h, lat_w, device=self.device, dtype=mask_dtype)
        msk[:, 1:] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
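
        # Shape walk-through (with F = 81): the mask starts as (1, 81, lat_h, lat_w)
        # with only the first (reference) frame set to 1. Repeating that frame 4x
        # gives (1, 84, lat_h, lat_w), which reshapes to (1, 21, 4, lat_h, lat_w)
        # so each latent frame carries its 4 pixel-frame flags, and the final
        # transpose/index yields (4, 21, lat_h, lat_w) to match the latent layout.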

        with torch.no_grad():
            # The VAE runs in float32; encode the reference frame followed by
            # F-1 zero frames.
            img_for_vae = img[None].cpu()
            if img_for_vae.dtype != torch.float32:
                img_for_vae = img_for_vae.float()
            y = self.vae.encode([
                torch.concat([
                    torch.nn.functional.interpolate(
                        img_for_vae, size=(h, w), mode='bicubic').transpose(0, 1),
                    torch.zeros(3, F - 1, h, w, dtype=torch.float32)
                ], dim=1).to(self.device)
            ])[0]

            if self.use_half:
                y = y.to(dtype=self.half_dtype)
            print(f"y shape after VAE encode: {y.shape}, dtype: {y.dtype}")
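        # Expected conditioning shapes (a sketch under the same stride
        # assumptions): the encoded latent y is (16, 21, lat_h, lat_w), matching
        # the noise tensor, so concatenating the 4-channel first-frame mask
        # below yields a 20-channel conditioning input for the DiT.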

        face_clip_context = None
        face_clip_context_list = []
        inference_masks = None
        if face_processor is not None:
            w_scale_factor = 1.2
            h_scale_factor = 1.1

            # One face per audio track when multiple audio paths are given.
            if self.audio_paths:
                n_faces_needed = len(self.audio_paths)
                print(f"number of faces needed: {n_faces_needed}")
            else:
                n_faces_needed = 1
                print(f"only one audio or no audio, number of faces needed: {n_faces_needed}")

            face_info = face_processor.infer(img_path, n=n_faces_needed)

            n = len(face_info['masks'])
            print(f"number of faces detected: {n}")

            print(f"the face order is set to left-to-right: {list(range(n))}")

            masks = face_info['masks']

            print(f"video with face processor, scale up w={w_scale_factor}, h={h_scale_factor} ###")

            # Expand each face mask, then merge them into one global mask.
            expanded_masks = []
            for mask in masks:
                expanded_mask = expand_face_mask_flexible(
                    torch.from_numpy(mask).to(self.device, dtype=img.dtype).clone(),
                    width_scale_factor=w_scale_factor,
                    height_scale_factor=h_scale_factor
                )
                expanded_masks.append(expanded_mask)
            global_mask = torch.zeros_like(expanded_masks[0])
            for mask in expanded_masks:
                global_mask += mask

            # Replace the plain first-frame mask with a smooth transition mask
            # for the DiT conditioning.
            mask_dtype = self.half_dtype if self.use_half else torch.float32
            dit_mask = gen_smooth_transition_mask_for_dit(
                global_mask,
                lat_h,
                lat_w,
                F,
                device=self.device,
                mask_dtype=mask_dtype,
                target_translate=(0, 0),
                target_scale=1
            )
            y = torch.cat([dit_mask, y], dim=0)

            # Downsample each expanded mask to half the latent resolution.
            resized_masks = []
            with torch.no_grad():
                for mask in expanded_masks:
                    latent_mask = torch.nn.functional.interpolate(
                        mask.unsqueeze(0).unsqueeze(0),
                        size=(lat_h // 2, lat_w // 2),
                        mode='bilinear',
                        align_corners=False
                    ).squeeze(0).squeeze(0)
                    resized_masks.append(latent_mask)

            latent_frame_num = (F - 1) // self.vae_stride[0] + 1

            inference_masks = gen_inference_masks(
                resized_masks,
                (lat_h // 2, lat_w // 2),
                num_frames=latent_frame_num
            )

            bboxes = face_info['bboxes']

            # Crop each detected face and extract its CLIP features.
            face_clip_context_list = []
            for i, bbox in enumerate(bboxes):
                try:
                    # Convert (x, y, w, h) to (x1, y1, x2, y2).
                    bbox_x, bbox_y, bbox_w, bbox_h = bbox
                    bbox_converted = [bbox_x, bbox_y, bbox_x + bbox_w, bbox_y + bbox_h]

                    cropped_face, adjusted_bbox = expand_bbox_and_crop_image(
                        img,
                        bbox_converted,
                        width_scale_factor=1,
                        height_scale_factor=1
                    )

                    cropped_face = cropped_face.to(self.device)

                    if self.use_half:
                        cropped_face = cropped_face.to(dtype=self.half_dtype)

                    with torch.no_grad():
                        face_clip_context = self.clip.visual([cropped_face[:, None, :, :]])

                    if self.use_half:
                        face_clip_context = face_clip_context.to(dtype=self.half_dtype)

                    face_clip_context_list.append(face_clip_context)

                except Exception as e:
                    print(f"error on face {i + 1}: {e}")
                    continue

            print(f"face feature extraction loop completed, successfully processed {len(face_clip_context_list)} faces")

            face_clip_context = face_clip_context_list[0] if face_clip_context_list else None

        else:
            y = torch.concat([msk, y])

        # Extract per-track audio features (wav2vec), aligned to the frame count.
        audio_feat_list = process_audio_features(
            audio_paths=audio_paths,
            audio=audio,
            mode=mode,
            F=F,
            frame_num=frame_num,
            task_key=task_key,
            fps=self.config.fps,
            wav2vec_model=self.config.wav2vec,
            vocal_separator_model=self.config.vocal_separator_path,
            audio_output_dir=self.config.audio_output_dir,
            device=self.device,
            use_half=self.use_half,
            half_dtype=self.half_dtype,
            preprocess_audio=preprocess_audio,
            resample_audio=resample_audio,
        )

        # Bundle the reference-face CLIP features and audio features that the
        # DiT consumes alongside the latent input.
        audio_ref_features = {
            "ref_face": face_clip_context,
        }

        ref_face_list = face_clip_context_list.copy() if face_clip_context_list else []
        audio_list = audio_feat_list.copy() if audio_feat_list else []

        audio_ref_features["ref_face_list"] = ref_face_list
        audio_ref_features["audio_list"] = audio_list

        @contextmanager
        def noop_no_sync():
            yield

        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

        if offload_model:
            self.clip.model.cpu()
            torch.cuda.empty_cache()

        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=5,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            latent = noise

            if self.use_half and latent.dtype != self.half_dtype:
                latent = latent.to(dtype=self.half_dtype)

            if self.use_half:
                context = [c.to(dtype=self.half_dtype) for c in context]
                context_null = [c.to(dtype=self.half_dtype) for c in context_null]

            # Conditional and unconditional model arguments for CFG.
            arg_c = {
                'context': [context[0]],
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
                'audio_ref_features': audio_ref_features
            }

            arg_null = {
                'context': context_null,
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
                'audio_ref_features': create_null_audio_ref_features(audio_ref_features)
            }

            if face_processor is not None and inference_masks is not None:
                arg_c['face_mask_list'] = copy.deepcopy(inference_masks['face_mask_list'])
                arg_null['face_mask_list'] = copy.deepcopy(inference_masks['face_mask_list'])

            if offload_model:
                torch.cuda.empty_cache()

            self.model.to(self.device)
            masks_flattened = False

            for i, t in enumerate(tqdm(timesteps)):
                latent = latent.to(self.device)
                if self.use_half and latent.dtype != self.half_dtype:
                    latent = latent.to(dtype=self.half_dtype)
                latent_model_input = [latent]
                timestep = [t]

                timestep = torch.stack(timestep).to(self.device)

                model_output_cond = self.model(
                    latent_model_input,
                    t=timestep,
                    masks_flattened=masks_flattened,
                    **arg_c
                )
                noise_pred_cond = model_output_cond[0]

                if self.use_half and noise_pred_cond.dtype != self.half_dtype:
                    noise_pred_cond = noise_pred_cond.to(dtype=self.half_dtype)
                noise_pred_cond = noise_pred_cond.to(
                    torch.device('cpu') if offload_model else self.device)

                if offload_model:
                    torch.cuda.empty_cache()

                if not cfg_zero:
                    model_output_uncond = self.model(
                        latent_model_input,
                        t=timestep,
                        masks_flattened=masks_flattened,
                        **arg_null
                    )
                    noise_pred_uncond = model_output_uncond[0]

                    if self.use_half and noise_pred_uncond.dtype != self.half_dtype:
                        noise_pred_uncond = noise_pred_uncond.to(dtype=self.half_dtype)
                    noise_pred_uncond = noise_pred_uncond.to(
                        torch.device('cpu') if offload_model else self.device)
                else:
                    noise_pred_uncond = None

                masks_flattened = True

                if offload_model:
                    torch.cuda.empty_cache()
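                # Classifier-free guidance combines the two predictions as
                #   eps = eps_uncond + s * (eps_cond - eps_uncond),
                # where s = guide_scale; s = 1 recovers the conditional
                # prediction. With cfg_zero the unconditional pass is skipped
                # above and eps = eps_cond is used directly.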
                if cfg_zero:
                    noise_pred = noise_pred_cond
                else:
                    noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)

                if self.use_half and noise_pred.dtype != self.half_dtype:
                    noise_pred = noise_pred.to(dtype=self.half_dtype)

                latent = latent.to(
                    torch.device('cpu') if offload_model else self.device
                )

                # Run the scheduler step in float32.
                noise_pred_for_scheduler = noise_pred.float() if self.use_half else noise_pred
                latent_for_scheduler = latent.float() if self.use_half else latent

                temp_x0 = sample_scheduler.step(
                    noise_pred_for_scheduler.unsqueeze(0),
                    t,
                    latent_for_scheduler.unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]

                if self.use_half:
                    temp_x0 = temp_x0.to(dtype=self.half_dtype)
                latent = temp_x0.squeeze(0)
                x0 = [latent.to(self.device)]
                del latent_model_input, timestep

            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()

            # Decode the final latent back to pixel space (VAE decodes in float32).
            x0_for_decode = [x.float() if self.use_half else x for x in x0]
            videos = self.vae.decode(x0_for_decode)
            result_videos = videos[0]

        del noise, latent
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()

        return result_videos if self.rank == 0 else None