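# Gradio Space: generate short LTX-Video clips (text-, image-, or video-conditioned),
# optionally re-detail the last frame with an SDXL refiner and a tiled upscaler, and
# stitch the generated clips into a single longer video.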
import spaces
import os
import sys

os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'
os.environ['TORCH_LINALG_PREFER_CUSOLVER'] = '1'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,pinned_use_background_threads:True'
os.environ["SAFETENSORS_FAST_GPU"] = "1"
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

import torch

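# Global matmul/cuDNN defaults favor full precision. The frame-enhancement helpers
# below temporarily relax these settings on the GPU, and generate() restores them.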
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.set_float32_matmul_precision("highest")

torch.backends.cudnn.benchmark = False
# preferred_blas_library / preferred_linalg_library are functions, not attributes;
# assigning a string to them silently does nothing, so call them instead.
torch.backends.cuda.preferred_blas_library(backend="cublas")
torch.backends.cuda.preferred_linalg_library(backend="cusolver")

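# Optional SFTP upload credentials (read from environment / Space secrets); uploads
# are skipped when any of these are missing.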
FTP_HOST = os.getenv("FTP_HOST")
FTP_USER = os.getenv("FTP_USER")
FTP_PASS = os.getenv("FTP_PASS")
FTP_DIR = os.getenv("FTP_DIR")

import cv2
import gc
import subprocess
import paramiko
from image_gen_aux import UpscaleWithModel
import numpy as np
import gradio as gr
import random
import yaml
from pathlib import Path
import imageio
import tempfile
from PIL import Image
from huggingface_hub import hf_hub_download
import shutil
from diffusers import StableDiffusionXLImg2ImgPipeline, AutoencoderKL
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from inference import (
    create_ltx_video_pipeline,
    create_latent_upsampler,
    load_image_to_tensor_with_resize_and_crop,
    seed_everething,
    get_device,
    calculate_padding,
    load_media_file,
)
from moviepy.editor import VideoFileClip, concatenate_videoclips
from typing import Any, Dict, Optional, Tuple

from ltx_video.models.transformers.transformer3d import Transformer3DModel, Transformer3DModelOutput
from diffusers.utils import logging
import re

logger = logging.get_logger(__name__)

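# --- TeaCache monkey-patch ----------------------------------------------------------
# Wrap Transformer3DModel.forward so that, between the first and last denoising steps,
# the relative L1 change of the incoming hidden states is accumulated; while it stays
# below rel_l1_thresh, the transformer pass is skipped and the previously cached
# residual is re-applied instead. Per-run state (cnt, caches, accumulator) is reset
# inside generate().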
original_transformer_forward = Transformer3DModel.forward


def teacache_wrapper_forward(self, hidden_states: torch.Tensor, **kwargs):
    if not hasattr(self, "enable_teacache") or not self.enable_teacache:
        return original_transformer_forward(self, hidden_states=hidden_states, **kwargs)

    should_calc = True
    if self.cnt > 0 and self.cnt < self.num_steps - 1:
        if (hasattr(self, "previous_hidden_states") and
                self.previous_hidden_states is not None and
                self.previous_hidden_states.shape == hidden_states.shape):
            rel_l1_dist = ((hidden_states - self.previous_hidden_states).abs().mean() / self.previous_hidden_states.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += rel_l1_dist
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                self.accumulated_rel_l1_distance = 0
        else:
            self.accumulated_rel_l1_distance = 0

    self.cnt += 1

    if not should_calc and hasattr(self, "previous_residual") and self.previous_residual is not None and self.previous_residual.shape == hidden_states.shape:
        return Transformer3DModelOutput(sample=self.previous_residual + hidden_states)
    else:
        self.previous_hidden_states = hidden_states.clone()
        output = original_transformer_forward(self, hidden_states=hidden_states, **kwargs)
        if isinstance(output, tuple):
            output_tensor = output[0]
        else:
            output_tensor = output.sample
        self.previous_residual = output_tensor - hidden_states
        return output


Transformer3DModel.forward = teacache_wrapper_forward
print("✅ Transformer3DModel patched with robust TeaCache Wrapper.")

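# Auxiliary models: a tiled super-resolution upscaler kept on the GPU, and an SDXL
# img2img refiner kept on the CPU until a frame actually needs enhancement.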
MAX_SEED = np.iinfo(np.int32).max

upscaler = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(torch.device("cuda:0"))

print("Loading SDXL Image-to-Image pipeline...")
enhancer_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "ford442/stable-diffusion-xl-refiner-1.0-bf16",
    use_safetensors=True,
    requires_aesthetics_score=True,
)
enhancer_pipeline.vae.set_default_attn_processor()
enhancer_pipeline.to("cpu")
print("SDXL Image-to-Image pipeline loaded successfully.")

config_file_path = "configs/ltxv-13b-0.9.8-distilled.yaml"
with open(config_file_path, "r") as file:
    PIPELINE_CONFIG_YAML = yaml.safe_load(file)

LTX_REPO = "Lightricks/LTX-Video"
MAX_IMAGE_SIZE = PIPELINE_CONFIG_YAML.get("max_resolution", 1280)
MAX_NUM_FRAMES = 900

pipeline_instance = None
latent_upsampler_instance = None

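# Download the distilled LTX-Video checkpoint and the spatial latent upsampler, build
# both on CPU first, then move them to the inference device.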
models_dir = "downloaded_models_gradio_cpu_init"
Path(models_dir).mkdir(parents=True, exist_ok=True)

print("Downloading models (if not present)...")
distilled_model_actual_path = hf_hub_download(
    repo_id=LTX_REPO,
    filename=PIPELINE_CONFIG_YAML["checkpoint_path"],
    local_dir=models_dir,
    local_dir_use_symlinks=False,
)
PIPELINE_CONFIG_YAML["checkpoint_path"] = distilled_model_actual_path

SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
spatial_upscaler_actual_path = hf_hub_download(
    repo_id=LTX_REPO,
    filename=SPATIAL_UPSCALER_FILENAME,
    local_dir=models_dir,
    local_dir_use_symlinks=False,
)
PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] = spatial_upscaler_actual_path

print("Creating LTX Video pipeline on CPU...")
pipeline_instance = create_ltx_video_pipeline(
    ckpt_path=PIPELINE_CONFIG_YAML["checkpoint_path"],
    precision=PIPELINE_CONFIG_YAML["precision"],
    text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"],
    sampler=PIPELINE_CONFIG_YAML["sampler"],
    device="cpu",
    enhance_prompt=False,
    prompt_enhancer_image_caption_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_image_caption_model_name_or_path"],
    prompt_enhancer_llm_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_llm_model_name_or_path"],
)

if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
    print("Creating latent upsampler on CPU...")
    latent_upsampler_instance = create_latent_upsampler(PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"], device="cpu")

target_inference_device = "cuda"
print(f"Target inference device: {target_inference_device}")
pipeline_instance.to(target_inference_device)
if latent_upsampler_instance:
    latent_upsampler_instance.to(target_inference_device)

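# --- FlashAttention-3 processor -------------------------------------------------------
# Replace the default attention processors in the LTX transformer with a processor that
# routes Q/K/V through the community flash-attn3 kernel.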
from diffusers.models.attention_processor import AttnProcessor2_0
from kernels import get_kernel

fa3_kernel = get_kernel("kernels-community/flash-attn3")


class FlashAttentionProcessor(AttnProcessor2_0):
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        query = attn.to_q(hidden_states)
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        scale = attn.scale
        query = query * scale
        b, t, c = query.shape
        h = attn.heads
        d = c // h
        q_reshaped = query.reshape(b, t, h, d).permute(0, 2, 1, 3)
        k_reshaped = key.reshape(b, t, h, d).permute(0, 2, 1, 3)
        v_reshaped = value.reshape(b, t, h, d).permute(0, 2, 1, 3)
        out_reshaped = torch.empty_like(q_reshaped)
        fa3_kernel.attention(q_reshaped, k_reshaped, v_reshaped, out_reshaped)
        out = out_reshaped.permute(0, 2, 1, 3).reshape(b, t, c)
        # to_out is a ModuleList([Linear, Dropout]); apply both instead of calling the list itself.
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out


fa_processor = FlashAttentionProcessor()

# Attention processors are plain callables, not nn.Modules, so they never appear in
# named_modules(); look for attention blocks whose current processor is the default
# AttnProcessor2_0 and swap in the FlashAttention processor on those.
for name, module in pipeline_instance.transformer.named_modules():
    if hasattr(module, "processor") and isinstance(module.processor, AttnProcessor2_0):
        module.processor = fa_processor

def upload_to_sftp(local_filepath):
    if not all([FTP_HOST, FTP_USER, FTP_PASS, FTP_DIR]):
        print("SFTP credentials not set. Skipping upload.")
        return
    try:
        transport = paramiko.Transport((FTP_HOST, 22))
        transport.connect(username=FTP_USER, password=FTP_PASS)
        sftp = paramiko.SFTPClient.from_transport(transport)
        remote_filename = os.path.basename(local_filepath)
        remote_filepath = os.path.join(FTP_DIR, remote_filename)
        print(f"Uploading {local_filepath} to {remote_filepath}...")
        sftp.put(local_filepath, remote_filepath)
        print("Upload successful.")
        sftp.close()
        transport.close()
    except Exception as e:
        print(f"SFTP upload failed: {e}")
        gr.Warning(f"SFTP upload failed: {e}")

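# Fit arbitrary input media to a 1024-based resolution, rounded to multiples of 32 and
# clamped to [256, MAX_IMAGE_SIZE].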
def calculate_new_dimensions(orig_w, orig_h):
    if orig_w == 0 or orig_h == 0: return int(1024), int(1024)
    if orig_w >= orig_h:
        new_h, new_w = 1024, round((1024 * (orig_w / orig_h)) / 32) * 32
    else:
        new_w, new_h = 1024, round((1024 * (orig_h / orig_w)) / 32) * 32
    return int(max(256, min(new_h, MAX_IMAGE_SIZE))), int(max(256, min(new_w, MAX_IMAGE_SIZE)))

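# ZeroGPU duration hint: reserve more GPU seconds for longer requested clips.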
def get_duration(*args, **kwargs):
    duration_ui = kwargs.get('duration_ui', 5.0)
    if duration_ui > 7.0: return 110
    if duration_ui > 5.0: return 100
    if duration_ui > 4.0: return 90
    if duration_ui > 3.0: return 70
    if duration_ui > 2.0: return 60
    if duration_ui > 1.5: return 50
    if duration_ui > 1.0: return 45
    if duration_ui > 0.5: return 30
    return 90

@spaces.GPU(duration=20)
def superres_image(image_to_enhance: Image.Image):
    print("Doing super-resolution.")
    # Relax the global precision settings for the upscaler pass.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.deterministic = True
    torch.set_float32_matmul_precision("medium")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        # Two tiled upscaler passes followed by two 4x LANCZOS downscales: the image
        # ends up near its original size, but with detail re-synthesized by the model.
        upscale_a = upscaler(image_to_enhance, tiling=True, tile_width=256, tile_height=256)
        upscale = upscaler(upscale_a, tiling=True, tile_width=256, tile_height=256)
        enhanced_image_a = upscale.resize((upscale.width // 4, upscale.height // 4), Image.LANCZOS)
        enhanced_image = enhanced_image_a.resize((enhanced_image_a.width // 4, enhanced_image_a.height // 4), Image.LANCZOS)
    return enhanced_image

@spaces.GPU(duration=30)
def enhance_frame(prompt, image_to_enhance: Image.Image):
    try:
        print("Moving enhancer pipeline to GPU...")
        seed = random.randint(0, MAX_SEED)
        generator = torch.Generator(device='cuda').manual_seed(seed)
        enhancer_pipeline.to("cuda", torch.bfloat16)
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        refine_prompt = prompt + " high detail, sharp focus, 1024x1024, professional"
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.deterministic = True
        torch.set_float32_matmul_precision("high")
        enhanced_image = enhancer_pipeline(prompt=refine_prompt, image=image_to_enhance, strength=0.07, generator=generator, num_inference_steps=180).images[0]
        print("Frame enhancement successful.")
    except Exception as e:
        print(f"Error during frame enhancement: {e}")
        gr.Warning("Frame enhancement failed. Using original frame.")
        return image_to_enhance
    finally:
        print("Moving enhancer pipeline to CPU...")
        enhancer_pipeline.to("cpu")
        gc.collect()
        torch.cuda.empty_cache()
    return enhanced_image

def use_last_frame_as_input(prompt, video_filepath, do_enhance, do_superres):
    if not video_filepath or not os.path.exists(video_filepath):
        gr.Warning("No video clip available.")
        # This function is a generator, so a bare `return value` would never reach Gradio.
        yield None, gr.update()
        return
    cap = None
    try:
        cap = cv2.VideoCapture(video_filepath)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        ret, frame = cap.read()
        if not ret:
            raise ValueError("Failed to read frame.")
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        print("Displaying original last frame...")
        yield pil_image, gr.update()
        if do_superres:
            pil_image = superres_image(pil_image)
        if do_enhance:
            enhanced_image = enhance_frame(prompt, pil_image)
            if do_superres:
                enhanced_image = superres_image(enhanced_image)
            print("Displaying enhanced frame and switching tab...")
            yield enhanced_image, gr.update(selected="i2v_tab")
        else:
            if do_superres:
                pil_image = superres_image(pil_image)
            yield pil_image, gr.update(selected="i2v_tab")
    except Exception as e:
        # gr.Error is an exception type; it must be raised to surface in the UI.
        raise gr.Error(f"Failed to extract frame: {e}")
    finally:
        if cap:
            cap.release()

def stitch_videos(clips_list):
    if not clips_list or len(clips_list) < 2:
        raise gr.Error("You need at least two clips to stitch them together!")
    print(f"Stitching {len(clips_list)} clips...")
    try:
        video_clips = [VideoFileClip(clip_path) for clip_path in clips_list]
        final_clip = concatenate_videoclips(video_clips, method="compose")
        final_output_path = os.path.join(tempfile.mkdtemp(), f"stitched_video_{random.randint(10000, 99999)}.mp4")
        final_clip.write_videofile(final_output_path, codec="libx264", audio=False, threads=4, preset='ultrafast')
        for clip in video_clips:
            clip.close()
        return final_output_path
    except Exception as e:
        raise gr.Error(f"Failed to stitch videos: {e}")

def clear_clips():
    return [], "Clips created: 0", None, None

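# Main generation entry point: configures TeaCache, pads the requested resolution and
# frame count to model-friendly sizes, builds the pipeline kwargs (including optional
# image/video conditioning), runs either the multi-scale or single-pass pipeline, crops
# the padding, writes an H.264 mp4, and appends it to the clip list.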
@spaces.GPU(duration=get_duration)
def generate(prompt, negative_prompt, clips_list, input_image_filepath, input_video_filepath,
             height_ui, width_ui, mode, duration_ui, ui_frames_to_use,
             seed_ui, randomize_seed, ui_guidance_scale, improve_texture_flag, num_steps, fps,
             enable_teacache, teacache_threshold,
             progress=gr.Progress(track_tqdm=True)):

    try:
        pipeline_instance.transformer.enable_teacache = enable_teacache
        if enable_teacache:
            print(f"✅ TeaCache is ENABLED with threshold: {teacache_threshold}")
            pipeline_instance.transformer.rel_l1_thresh = teacache_threshold
    except AttributeError:
        print("⚠️ Could not configure TeaCache on transformer.")

    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.deterministic = False
    torch.set_float32_matmul_precision("highest")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    if mode not in ["text-to-video", "image-to-video", "video-to-video"]:
        raise gr.Error(f"Invalid mode: {mode}.")
    if mode == "image-to-video" and not input_image_filepath:
        raise gr.Error("input_image_filepath is required for image-to-video mode")
    elif mode == "video-to-video" and not input_video_filepath:
        raise gr.Error("input_video_filepath is required for video-to-video mode")
    if randomize_seed:
        seed_ui = random.randint(0, 2**32 - 1)
    seed_everething(int(seed_ui))

    actual_num_frames = max(9, min(MAX_NUM_FRAMES, int(round((max(1, round(duration_ui * fps)) - 1.0) / 8.0) * 8 + 1)))
    actual_height, actual_width = int(height_ui), int(width_ui)
    height_padded, width_padded = ((actual_height - 1) // 32 + 1) * 32, ((actual_width - 1) // 32 + 1) * 32
    padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)
    num_frames_padded = max(9, ((actual_num_frames - 2) // 8 + 1) * 8 + 1)

    call_kwargs = {
        "prompt": prompt, "negative_prompt": negative_prompt, "height": height_padded, "width": width_padded,
        "num_frames": num_frames_padded, "num_inference_steps": num_steps, "frame_rate": int(fps),
        "generator": torch.Generator(device=target_inference_device).manual_seed(int(seed_ui)),
        "output_type": "pt", "conditioning_items": None, "media_items": None,
        "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"],
        "decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"],
        "stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"],
        "image_cond_noise_scale": 0.15, "is_video": True, "vae_per_channel_normalize": True,
        "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"),
        "offload_to_cpu": False, "enhance_prompt": False
    }

    stg_mode_str = PIPELINE_CONFIG_YAML.get("stg_mode", "attention_values").lower()
    stg_map = {
        "stg_av": SkipLayerStrategy.AttentionValues, "attention_values": SkipLayerStrategy.AttentionValues,
        "stg_as": SkipLayerStrategy.AttentionSkip, "attention_skip": SkipLayerStrategy.AttentionSkip,
        "stg_r": SkipLayerStrategy.Residual, "residual": SkipLayerStrategy.Residual,
        "stg_t": SkipLayerStrategy.TransformerBlock, "transformer_block": SkipLayerStrategy.TransformerBlock
    }
    call_kwargs["skip_layer_strategy"] = stg_map.get(stg_mode_str, SkipLayerStrategy.AttentionValues)

    if mode == "image-to-video":
        media_tensor = load_image_to_tensor_with_resize_and_crop(input_image_filepath, actual_height, actual_width)
        call_kwargs["conditioning_items"] = [ConditioningItem(torch.nn.functional.pad(media_tensor, padding_values).to(target_inference_device), 0, 1.0)]
    elif mode == "video-to-video":
        call_kwargs["media_items"] = load_media_file(media_path=input_video_filepath, height=actual_height, width=actual_width, max_frames=int(ui_frames_to_use), padding=padding_values).to(target_inference_device)

    if improve_texture_flag and latent_upsampler_instance:
        multi_scale_pipeline = LTXMultiScalePipeline(pipeline_instance, latent_upsampler_instance)
        pass_args = {"guidance_scale": float(ui_guidance_scale)}

        multi_scale_kwargs = {
            **call_kwargs,
            "downscale_factor": PIPELINE_CONFIG_YAML["downscale_factor"],
            "first_pass": {**PIPELINE_CONFIG_YAML.get("first_pass", {}), **pass_args},
            "second_pass": {**PIPELINE_CONFIG_YAML.get("second_pass", {}), **pass_args}
        }

        first_pass_steps = multi_scale_kwargs.get("first_pass", {}).get("num_inference_steps", num_steps)
        pipeline_instance.transformer.num_steps = first_pass_steps
        pipeline_instance.transformer.cnt = 0
        pipeline_instance.transformer.previous_hidden_states = None
        pipeline_instance.transformer.previous_residual = None
        pipeline_instance.transformer.accumulated_rel_l1_distance = 0

        result_images_tensor = multi_scale_pipeline(**multi_scale_kwargs).images
    else:
        pipeline_instance.transformer.num_steps = num_steps
        pipeline_instance.transformer.cnt = 0
        pipeline_instance.transformer.previous_hidden_states = None
        pipeline_instance.transformer.previous_residual = None
        pipeline_instance.transformer.accumulated_rel_l1_distance = 0

        single_pass_kwargs = {**call_kwargs, "guidance_scale": float(ui_guidance_scale), **PIPELINE_CONFIG_YAML.get("first_pass", {})}
        result_images_tensor = pipeline_instance(**single_pass_kwargs).images

    if result_images_tensor is None:
        raise gr.Error("Generation failed.")
    pad_l, pad_r, pad_t, pad_b = padding_values
    result_images_tensor = result_images_tensor[:, :, :actual_num_frames, pad_t:(-pad_b or None), pad_l:(-pad_r or None)]
    video_np = (np.clip(result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy(), 0, 1) * 255).astype(np.uint8)
    output_video_path = os.path.join(tempfile.mkdtemp(), f"output_{random.randint(10000, 99999)}.mp4")
    with imageio.get_writer(output_video_path, format='FFMPEG', fps=call_kwargs["frame_rate"], codec='libx264', quality=10, pixelformat='yuv420p') as video_writer:
        for idx, frame in enumerate(video_np):
            progress(idx / len(video_np), desc="Saving video clip...")
            video_writer.append_data(frame)

    updated_clips_list = clips_list + [output_video_path]
    counter_text = f"Clips created: {len(updated_clips_list)}"
    return output_video_path, seed_ui, gr.update(visible=True), updated_clips_list, counter_text

def update_task_image():
    return "image-to-video"


def update_task_text():
    return "text-to-video"


def update_task_video():
    return "video-to-video"

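# --- Gradio UI ------------------------------------------------------------------------
# Three generation tabs (image/text/video to video), per-clip controls, a "use last
# frame" hand-off back into image-to-video, and stitching controls for the clip list.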
css="""#col-container{margin:0 auto;max-width:900px;}""" |
|
|
|
|
|
with gr.Blocks(css=css) as demo: |
|
|
clips_state = gr.State([]) |
|
|
gr.Markdown("# LTX Video Clip Stitcher") |
|
|
gr.Markdown("Generate short video clips and stitch them together to create a longer animation.") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
with gr.Tabs() as tabs: |
|
|
with gr.Tab("image-to-video", id="i2v_tab") as image_tab: |
|
|
video_i_hidden = gr.Textbox(visible=False); |
|
|
image_i2v = gr.Image(label="Input Image", type="filepath", sources=["upload", "webcam", "clipboard"]); |
|
|
i2v_prompt = gr.Textbox(label="Prompt", value="The creature from the image starts to move", lines=3); |
|
|
i2v_button = gr.Button("Generate Image-to-Video Clip", variant="primary") |
|
|
with gr.Tab("text-to-video", id="t2v_tab") as text_tab: |
|
|
image_n_hidden = gr.Textbox(visible=False); |
|
|
video_n_hidden = gr.Textbox(visible=False); t2v_prompt = gr.Textbox(label="Prompt", value="A majestic dragon flying over a medieval castle", lines=3); |
|
|
t2v_button = gr.Button("Generate Text-to-Video Clip", variant="primary") |
|
|
with gr.Tab("video-to-video", id="v2v_tab") as video_tab: |
|
|
image_v_hidden = gr.Textbox(visible=False); |
|
|
video_v2v = gr.Video(label="Input Video", sources=["upload", "webcam"]); |
|
|
frames_to_use = gr.Slider(label="Frames to use from input video", minimum=9, maximum=120, value=9, step=8, info="Must be N*8+1."); |
|
|
v2v_prompt = gr.Textbox(label="Prompt", value="Change the style to cinematic anime", lines=3); |
|
|
v2v_button = gr.Button("Generate Video-to-Video Clip", variant="primary") |
|
|
duration_input = gr.Slider(label="Clip Duration (seconds)", minimum=1.0, maximum=10.0, value=2.0, step=0.1) |
|
|
improve_texture = gr.Checkbox(label="Improve Texture (multi-scale)", value=True) |
|
|
enhance_checkbox = gr.Checkbox(label="Improve Frame (SDXL Refiner)", value=True) |
|
|
superres_checkbox = gr.Checkbox(label="Upscale Frame (ClearRealityV1)", value=True) |
|
|
with gr.Column(): |
|
|
output_video = gr.Video(label="Last Generated Clip", interactive=False) |
|
|
use_last_frame_button = gr.Button("Use Last Frame as Input Image", visible=False) |
|
|
with gr.Accordion("Stitching Controls", open=True): |
|
|
clip_counter_display = gr.Markdown("Clips created: 0") |
|
|
with gr.Row(): stitch_button = gr.Button("🎬 Stitch All Clips"); clear_button = gr.Button("🗑️ Clear All Clips") |
|
|
final_video_output = gr.Video(label="Final Stitched Video", interactive=False) |
|
|
with gr.Accordion("Advanced settings", open=False): |
|
|
mode = gr.Dropdown(["text-to-video", "image-to-video", "video-to-video"], label="task", value="image-to-video", visible=False); |
|
|
negative_prompt_input = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted", lines=2) |
|
|
with gr.Row(): |
|
|
teacache_checkbox = gr.Checkbox(label="Enable TeaCache Acceleration", value=True) |
|
|
teacache_slider = gr.Slider( |
|
|
minimum=0.01, |
|
|
maximum=0.1, |
|
|
step=0.01, |
|
|
value=0.05, |
|
|
label="TeaCache Threshold (Higher = Faster)" |
|
|
) |
|
|
with gr.Row(): |
|
|
seed_input = gr.Number(label="Seed", value=42, precision=0); |
|
|
randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True) |
|
|
with gr.Row(visible=False): |
|
|
guidance_scale_input = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0), step=0.1) |
|
|
with gr.Row(): |
|
|
height_input = gr.Slider(label="Height", value=1024, step=32, minimum=32, maximum=MAX_IMAGE_SIZE); |
|
|
width_input = gr.Slider(label="Width", value=1024, step=32, minimum=32, maximum=MAX_IMAGE_SIZE); |
|
|
num_steps = gr.Slider(label="Steps", value=30, step=1, minimum=1, maximum=420); |
|
|
fps = gr.Slider(label="FPS", value=30.0, step=1.0, minimum=4.0, maximum=60.0) |
|
|
def handle_image_upload_for_dims(f, h, w): |
|
|
if not f: return gr.update(value=h), gr.update(value=w) |
|
|
img = Image.open(f); new_h, new_w = calculate_new_dimensions(img.width, img.height); return gr.update(value=new_h), gr.update(value=new_w) |
|
|
def handle_video_upload_for_dims(f, h, w): |
|
|
if not f or not os.path.exists(str(f)): return gr.update(value=h), gr.update(value=w) |
|
|
with imageio.get_reader(str(f)) as reader: |
|
|
meta = reader.get_meta_data(); orig_w, orig_h = meta.get('size', (reader.get_data(0).shape[1], reader.get_data(0).shape[0])); |
|
|
new_h, new_w = calculate_new_dimensions(orig_w, orig_h); return gr.update(value=new_h), gr.update(value=new_w) |
|
|
image_i2v.upload(handle_image_upload_for_dims, [image_i2v, height_input, width_input], [height_input, width_input]); |
|
|
video_v2v.upload(handle_video_upload_for_dims, [video_v2v, height_input, width_input], [height_input, width_input]); |
|
|
image_tab.select(update_task_image, outputs=[mode]); text_tab.select(update_task_text, outputs=[mode]); |
|
|
video_tab.select(update_task_video, outputs=[mode]) |
|
|
common_params = [height_input, width_input, mode, duration_input, frames_to_use, seed_input, randomize_seed_input, guidance_scale_input, improve_texture, num_steps, fps, teacache_checkbox, teacache_slider] |
|
|
t2v_inputs = [t2v_prompt, negative_prompt_input, clips_state, image_n_hidden, video_n_hidden] + common_params; |
|
|
i2v_inputs = [i2v_prompt, negative_prompt_input, clips_state, image_i2v, video_i_hidden] + common_params; |
|
|
v2v_inputs = [v2v_prompt, negative_prompt_input, clips_state, image_v_hidden, video_v2v] + common_params |
|
|
gen_outputs = [output_video, seed_input, use_last_frame_button, clips_state, clip_counter_display] |
|
|
hide_btn = lambda: gr.update(visible=False) |
|
|
t2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=t2v_inputs, outputs=gen_outputs, api_name="text_to_video") |
|
|
i2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=i2v_inputs, outputs=gen_outputs, api_name="image_to_video") |
|
|
v2v_button.click(hide_btn, outputs=[use_last_frame_button], queue=False).then(fn=generate, inputs=v2v_inputs, outputs=gen_outputs, api_name="video_to_video") |
|
|
use_last_frame_button.click(fn=use_last_frame_as_input, inputs=[i2v_prompt,output_video,enhance_checkbox, superres_checkbox], outputs=[image_i2v, tabs]) |
|
|
stitch_button.click(fn=stitch_videos, inputs=[clips_state], outputs=[final_video_output]) |
|
|
clear_button.click(fn=clear_clips, outputs=[clips_state, clip_counter_display, output_video, final_video_output]) |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    if os.path.exists(models_dir):
        print(f"Model directory: {Path(models_dir).resolve()}")
    demo.queue().launch(debug=True, share=True, mcp_server=True)