cover comparison

Slowpics comparison: generated images, decoded with the original VAE vs the 2x VAE

Slowpics comparison: highres fix over the UltraSharp upscaler vs highres fix over the 2x decoder

A decoder-only finetune of the Wan2.1 VAE, with 2x upscaling integrated directly into the decoder. The main purpose is to kill the dreaded Wan speckles/polka dots/grain, but it's also convenient for highres fix workflows. The outputs of the 2x decoder are usually much better than what you would get by running the outputs of the original decoder through an image upscale model, and, even better, the upscale is effectively free, since the compute cost of decoding is virtually unchanged. If you don't want the extra resolution, a slight blur and 2x downsample will give you an original-resolution image with much higher quality than the original decoder can produce.
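The blur-and-downsample route back to the original resolution can be sketched as below. The Gaussian `sigma`/`radius` values are hypothetical defaults, not settings from this card, and the 2x2 area average is just one reasonable downsampling choice.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel1d(sigma: float, radius: int) -> torch.Tensor:
    """Normalized 1D Gaussian kernel of length 2 * radius + 1."""
    x = torch.arange(-radius, radius + 1, dtype=torch.float32)
    k = torch.exp(-(x ** 2) / (2 * sigma ** 2))
    return k / k.sum()


def blur_and_downsample(img: torch.Tensor, sigma: float = 0.6, radius: int = 2) -> torch.Tensor:
    """Slight Gaussian blur followed by a 2x area downsample.

    img: [B, C, H, W] output of the 2x decoder (H and W even).
    sigma and radius are hypothetical defaults; tune them by eye.
    """
    c = img.shape[1]
    k = gaussian_kernel1d(sigma, radius).to(img.device, img.dtype)
    # separable depthwise blur: one horizontal pass, then one vertical pass
    kh = k.view(1, 1, 1, -1).repeat(c, 1, 1, 1)
    kv = k.view(1, 1, -1, 1).repeat(c, 1, 1, 1)
    x = F.pad(img, (radius, radius, radius, radius), mode="reflect")
    x = F.conv2d(x, kh, groups=c)
    x = F.conv2d(x, kv, groups=c)
    # average each 2x2 block to get back to the original resolution
    return F.avg_pool2d(x, kernel_size=2)
```

For a decoded frame `img` of shape `[1, 3, 1024, 1024]`, `blur_and_downsample(img)` returns `[1, 3, 512, 512]`.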

In particular, this VAE improves skin detail and hair very significantly. It is trained almost exclusively on real images, so it may struggle with anime/lineart and text. Finetuning on anime/lineart would be possible, but I'm not aware of a suitable dataset that is correctly licensed rather than full of scraped media with massive copyright violations. If you know of an appropriate dataset sourced from CC-BY or similarly licensed materials, let me know and I'd be happy to try training on it.

The first released version is trained on images only, and is compatible with both Wan and Qwen since they share the same latent space. A video version is planned, but video training is more complex than image training, so it will take some time.

How to use (ComfyUI):

Install these custom nodes, which allow loading and decoding latents in a native-compatible way. Wan wrapper latents can also be used if they are de-normalized using the correct wrapper node.

https://github.com/spacepxl/ComfyUI-VAE-Utils

How to use (diffusers):

import torch
import torch.nn.functional as F
from diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained("spacepxl/Wan2.1-VAE-upscale2x", subfolder="diffusers/Wan2.1_VAE_upscale2x_imageonly_real_v1")
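# latents: [B, 16, F, H/8, W/8], already de-normalized with the VAE's latents_mean/latents_std
# (the same step the diffusers Wan pipelines apply before calling vae.decode)
decoder_out = vae.decode(latents, return_dict=False)[0]  # [B, 12, F, H, W]: 3 RGB channels x 2x2 sub-pixels
# pixel_shuffle needs [..., C, H, W], so move the channel dim next to H, W and back again
decoded_video = F.pixel_shuffle(decoder_out.movedim(2, 1), upscale_factor=2).movedim(1, 2)  # [B, 3, F, 2H, 2W]
# for images (F == 1), squeeze the frame dim instead to convert BCFHW to BCHW before shuffling
decoded_image = F.pixel_shuffle(decoder_out.squeeze(2), upscale_factor=2)  # [B, 3, 2H, 2W]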

Training details

The VAE is initialized from the pretrained Wan2.1 VAE. The encoder is unchanged, but the final Conv3d in the decoder is replaced to expand its output channels from 3 to 12. The key idea behind this change comes from image upscale models, many of which do all of their processing at the original resolution and only produce the high-res image through a pixel shuffle in the last layer. The decoder does the same here: its 12 output channels are 3 RGB channels times 2x2 sub-pixels, which the pixel shuffle rearranges into an image with twice the width and height. For training, the encoder is frozen to preserve the latent space, and only the decoder is trained.
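A minimal sketch of the head swap, assuming a plain `nn.Conv3d` stands in for Wan's causal output conv (the real layer pads the time axis itself, so its `padding` differs). The nearest-neighbor initialization is an assumption for illustration, not something stated in this card: copying each original filter into its four sub-pixel slots makes the expanded decoder start as an exact nearest-neighbor upscale of the original one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def expand_head_for_pixel_shuffle(head: nn.Conv3d, scale: int = 2) -> nn.Conv3d:
    """Copy of `head` with out_channels * scale**2 outputs (3 -> 12 for scale 2).

    Hypothetical init: pixel_shuffle maps channel c * scale**2 + k to color channel c,
    so repeating each output filter scale**2 times makes every sub-pixel start as a
    copy of the original pixel.
    """
    r2 = scale * scale
    new_head = nn.Conv3d(
        head.in_channels,
        head.out_channels * r2,
        kernel_size=head.kernel_size,
        stride=head.stride,
        padding=head.padding,
        dilation=head.dilation,
        bias=head.bias is not None,
    )
    with torch.no_grad():
        new_head.weight.copy_(head.weight.repeat_interleave(r2, dim=0))
        if head.bias is not None:
            new_head.bias.copy_(head.bias.repeat_interleave(r2, dim=0))
    return new_head


def shuffle_video(x: torch.Tensor, scale: int = 2) -> torch.Tensor:
    """[B, C * scale**2, F, H, W] -> [B, C, F, H * scale, W * scale]."""
    return F.pixel_shuffle(x.movedim(2, 1), scale).movedim(1, 2)
```

Whether the released model used this init or a random one is not stated; the sketch only shows that the 12-channel head plus pixel shuffle is a strict generalization of the 3-channel head.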

Training uses a combination of L1 loss, LPIPS, FDL, and GAN loss:

  • LPIPS is useful to regularize the other losses, but its weight needs to be kept low to prevent the checkerboard/speckle artifacts that come from over-reliance on it. Too much LPIPS loss weight was probably the primary mistake made by the Wan team.

  • Frequency Distribution Loss (FDL) is also a perceptual loss using VGG features, but it has very different properties from LPIPS and encourages more realistic textures and local statistics. It's not a replacement: on its own it generates different artifacts, but together the two work quite well. Not well enough to eliminate the GAN loss, but close.

  • For the GAN loss, I used a PatchGAN discriminator with spectral norm, and a non-saturating LSGAN loss. I tried R3GAN, but its gradient penalties are slow and memory intensive, and it didn't converge well for me anyway. The GAN loss weight is fairly high, to improve sharpness and realism, possibly even at the expense of metric accuracy.
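A minimal sketch of the adversarial and combined objectives described above, assuming a pix2pix-style PatchGAN with three stride-2 convs and the standard least-squares GAN objectives. The loss weights, layer widths, and the `lpips_fn`/`fdl_fn` callables are placeholders (for example, the `lpips` package's `lpips.LPIPS(net="vgg")` could fill the first slot), not the values or implementations used for this model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class PatchDiscriminator(nn.Module):
    """PatchGAN-style discriminator with spectral norm on every conv.
    Outputs a grid of realism scores (one per receptive-field patch)."""

    def __init__(self, in_ch: int = 3, base: int = 64, n_layers: int = 3):
        super().__init__()
        layers = [spectral_norm(nn.Conv2d(in_ch, base, 4, 2, 1)), nn.LeakyReLU(0.2)]
        ch = base
        for _ in range(n_layers - 1):
            layers += [spectral_norm(nn.Conv2d(ch, ch * 2, 4, 2, 1)), nn.LeakyReLU(0.2)]
            ch *= 2
        layers.append(spectral_norm(nn.Conv2d(ch, 1, 4, 1, 1)))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def d_loss_lsgan(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    # least-squares GAN: push real patch scores to 1 and fake patch scores to 0
    return 0.5 * ((d_real - 1).pow(2).mean() + d_fake.pow(2).mean())


def g_loss_lsgan(d_fake: torch.Tensor) -> torch.Tensor:
    # the decoder wants its outputs scored as real
    return 0.5 * (d_fake - 1).pow(2).mean()


def decoder_loss(pred, target, disc, lpips_fn, fdl_fn,
                 w_l1=1.0, w_lpips=0.1, w_fdl=1.0, w_gan=0.5):
    """Weighted L1 + LPIPS + FDL + GAN loss. All weights are hypothetical;
    the card only says LPIPS is kept low and the GAN weight fairly high."""
    l1 = F.l1_loss(pred, target)
    perceptual = w_lpips * lpips_fn(pred, target).mean() + w_fdl * fdl_fn(pred, target).mean()
    gan = g_loss_lsgan(disc(pred))
    return w_l1 * l1 + perceptual + w_gan * gan
```

In a training loop the discriminator would be updated separately with `d_loss_lsgan(disc(target), disc(pred.detach()))`.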

Speaking of metrics, I have none to share, because the focus was on human perceptual quality rather than metrics. Most common metrics used to report VAE reconstruction accuracy (PSNR/SSIM/LPIPS/etc.) are very poorly aligned with human preference, so all tuning was done by eye instead. As a result, decoded images may have slight color shifts and can sometimes be overly sharp. Both can be corrected with color matching and blur, so in my opinion this is an acceptable tradeoff for fixing the metrically accurate but subjectively poor output of the original Wan VAE decoder.
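Color matching can be as simple as per-channel mean/std transfer. This sketch assumes you have a reference with the intended colors, such as the original decoder's output for the same latents; the card does not say which color-matching method it has in mind, so treat this as one simple option.

```python
import torch


def match_color(img: torch.Tensor, ref: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Per-image, per-channel mean/std transfer from `ref` onto `img`.

    img, ref: [B, C, H, W]; their resolutions may differ (e.g. 2x decode vs original decode).
    """
    dims = (-2, -1)
    mu_i = img.mean(dim=dims, keepdim=True)
    sd_i = img.std(dim=dims, keepdim=True)
    mu_r = ref.mean(dim=dims, keepdim=True)
    sd_r = ref.std(dim=dims, keepdim=True)
    # normalize img's channel statistics, then re-apply the reference statistics
    return (img - mu_i) / (sd_i + eps) * sd_r + mu_r
```

Global statistics only fix overall shifts; local color drift would need a more elaborate method.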

Final training for the image-only model took about 40h on a single 5090, although that is not counting the dozens of test runs needed to dial in loss functions and data preprocessing.

  • base resolution: 256 (upscaled to 512)
  • batch size: 4
  • total steps: 300k
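The listed settings can be written as a config with a few derived quantities. The 8x spatial compression and 16 latent channels are the standard Wan2.1 VAE figures; the throughput line is plain arithmetic from the 40 h and 300k step numbers above and ignores any overhead.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ImageTrainConfig:
    base_resolution: int = 256     # encoder input size
    upscale: int = 2               # decoder output is 2x the input
    batch_size: int = 4
    total_steps: int = 300_000
    spatial_compression: int = 8   # Wan2.1 VAE downsamples H and W by 8
    latent_channels: int = 16

    @property
    def target_resolution(self) -> int:
        return self.base_resolution * self.upscale

    @property
    def latent_size(self) -> int:
        return self.base_resolution // self.spatial_compression

    @property
    def images_seen(self) -> int:
        return self.batch_size * self.total_steps


cfg = ImageTrainConfig()
hours = 40.0
steps_per_second = cfg.total_steps / (hours * 3600)  # about 2.08 steps/s on one GPU
```

So each step decodes four 32x32x16 latents into 512x512 images, and the run sees 1.2M images in total.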

Latent degradation

In addition to all that, there is one other critical piece to improving decoder quality: latent degradation. A VAE trained only to reconstruct real encoder latents can still struggle with latents generated by the diffusion model, because diffusion models usually fail to generate the highest-frequency information in the latent space. In other words, they have a low-frequency bias. This is probably caused by diffusion/flow training using only an MSE loss, which is well known in image model training to have a low-frequency bias. One piece of evidence supporting this theory is that distilled models which use distribution matching and/or adversarial losses in the distillation process, like DMD2 (lightning) distillation, tend to improve the quality of fine details compared to the base diffusion model.
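One way to check the low-frequency bias on real data is to compare how much spectral energy sits above some cutoff frequency in encoded versus generated latents. This is a hypothetical diagnostic, not a measurement from this card, and `cutoff` is arbitrary.

```python
import torch
import torch.nn.functional as F


def high_freq_fraction(latents: torch.Tensor, cutoff: float = 0.25) -> float:
    """Fraction of spectral energy above `cutoff` (cycles/sample, Nyquist = 0.5).

    latents: [..., H, W]; the FFT is taken over the last two (spatial) dims.
    """
    spec = torch.fft.fft2(latents).abs().pow(2)
    fy = torch.fft.fftfreq(latents.shape[-2])
    fx = torch.fft.fftfreq(latents.shape[-1])
    radius = torch.sqrt(fy[:, None] ** 2 + fx[None, :] ** 2)
    mask = (radius > cutoff).to(spec.dtype)  # broadcasts over the leading dims
    return ((spec * mask).sum() / spec.sum()).item()
```

A blurred ("softer") latent scores lower than the original, which is the pattern the card describes for generated versus encoded latents.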

There is probably a whole branch of ideas to explore on the side of diffusion model training, like incorporating latent perceptual or GAN loss into the diffusion training process, but another option is to just make the VAE decoder robust against the types of degradations that are generated by the diffusion model. This is what I did.

To simulate generated latents, you can simply add noise to real latents, and use the diffusion model to predict the clean latents in one or a few steps, using an empty or negative caption. I've seen this referred to in papers as "SD-edit degradation", aka img2img (or vid2vid). Crucially, the degradation is NOT noise. LTXV tried to train the decoder with noisy latents to fix this problem, and it did not work. If you inspect real vs generated latents, it's clear that the generated latents are softer than the real ones, not noisy.
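A sketch of SD-edit degradation under a plain rectified-flow parameterization, x_t = (1 - t) * x0 + t * noise with a velocity target of noise - x0, integrated back to t = 0 with a few Euler steps. `model(x, t, context)` is a hypothetical wrapper: the actual Wan2.1 transformer takes timesteps on a 0 to 1000 scale and its sampler applies a timestep shift, both of which are ignored here.

```python
import torch


def sdedit_degrade(x0, model, t: float, context=None, steps: int = 1, generator=None):
    """Noise clean latents x0 to level t, then denoise back to t = 0.

    model(x, t, context) is a hypothetical callable returning the predicted
    velocity (noise - x0); context would hold the empty/negative caption embedding.
    """
    noise = torch.randn(x0.shape, generator=generator, dtype=x0.dtype, device=x0.device)
    x = (1 - t) * x0 + t * noise
    ts = torch.linspace(t, 0.0, steps + 1)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        v = model(x, t_cur, context)
        x = x + (t_next - t_cur) * v  # Euler step toward t = 0
    return x
```

With a perfect velocity predictor this returns x0 exactly; the softness the card describes comes from the real model's imperfect prediction of high-frequency content.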

For my purposes, I didn't want to keep a diffusion model loaded during VAE training, so I trained a small convolution-only model to match the SD-edit degradations generated by Wan2.1-1.3B, since it tends to generate lower quality images than the 14B models. I simply used the preset negative caption, and trained with random timesteps from a lognormal distribution centered around ~0.2. The convolution model is also conditioned on timestep, so it learns to simulate whatever degradations the diffusion model produces towards the end of the generation process. This is mostly similar to blurring, but it's much easier to learn it directly than to manually tune Gaussian blur filters to match. It also doesn't need to be perfect, since the whole purpose is to make the encoded latents less accurate.
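A minimal sketch of a timestep-conditioned, convolution-only degradation proxy and a lognormal timestep sampler. The architecture (FiLM conditioning, residual output, zero-initialized last layer), the lognormal `sigma`, and the per-frame 2D treatment are all assumptions; the card only says the proxy is small, convolutional, conditioned on timestep, and trained to match Wan2.1-1.3B SD-edit outputs.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class DegradationProxy(nn.Module):
    """Hypothetical conv-only net mapping (clean latent, t) -> degraded latent.
    Works on single frames [B, C, H, W]; t is injected via per-channel scale/shift."""

    def __init__(self, ch: int = 16, hidden: int = 64, n_blocks: int = 4):
        super().__init__()
        self.inp = nn.Conv2d(ch, hidden, 3, padding=1)
        self.blocks = nn.ModuleList(nn.Conv2d(hidden, hidden, 3, padding=1) for _ in range(n_blocks))
        self.film = nn.ModuleList(nn.Linear(1, 2 * hidden) for _ in range(n_blocks))
        self.out = nn.Conv2d(hidden, ch, 3, padding=1)
        # zero init so the untrained proxy is the identity (output = input)
        nn.init.zeros_(self.out.weight)
        nn.init.zeros_(self.out.bias)

    def forward(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        h = self.inp(z)
        t = t.view(-1, 1).to(z.dtype)
        for conv, film in zip(self.blocks, self.film):
            scale, shift = film(t).chunk(2, dim=1)
            h = h + conv(F.silu(h * (1 + scale[..., None, None]) + shift[..., None, None]))
        return z + self.out(F.silu(h))  # predict a residual on top of the clean latent


def sample_lognormal_t(n: int, median: float = 0.2, sigma: float = 0.5, generator=None):
    """Lognormal timesteps with the given median, clipped into (0, 1]. sigma is a guess."""
    normal = torch.randn(n, generator=generator)
    return torch.exp(math.log(median) + sigma * normal).clamp(1e-3, 1.0)


def proxy_train_step(proxy, opt, z_clean, z_sdedit, t) -> float:
    """One MSE step toward precomputed SD-edit outputs z_sdedit at timesteps t."""
    loss = F.mse_loss(proxy(z_clean, t), z_sdedit)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```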

The degradation proxy model is trained first and kept frozen during VAE training, with timesteps sampled randomly in the 0 to 0.12 range to approximate the typical level of degradation in generated images. Degrading the latents during decoder training makes the decoder more robust at decoding plausible details from generated latents, instead of producing artifacts like the uniform speckle grid of the original decoder. Without this, the decoder learns to reconstruct good details from encoded latents, but still sometimes produces artifacts on generated latents.
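Putting the pieces together, one decoder update with latent degradation might look like the sketch below. `encoder`, `decoder`, `proxy` and `loss_fn` are stand-in callables with the shapes noted in the docstring; the real Wan VAE works on 5D video tensors and its encoder returns a distribution, which this sketch glosses over. Only the uniform 0 to 0.12 timestep range comes from the text.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def decoder_train_step(encoder, decoder, proxy, opt, img_lr, img_hr, loss_fn, t_max=0.12):
    """One decoder update with latent degradation (hypothetical helper).

    img_lr: [B, 3, H, W] input at base resolution; img_hr: [B, 3, 2H, 2W] target.
    encoder and proxy are frozen; decoder maps latents to [B, 12, H, W].
    """
    with torch.no_grad():
        z = encoder(img_lr)                                   # frozen encoder keeps the latent space fixed
        t = torch.rand(z.shape[0], device=z.device) * t_max   # uniform in [0, t_max)
        z = proxy(z, t)                                       # simulate the softness of generated latents
    pred = F.pixel_shuffle(decoder(z), 2)                     # [B, 12, H, W] -> [B, 3, 2H, 2W]
    loss = loss_fn(pred, img_hr)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```

Only the decoder's parameters are passed to the optimizer, so the encoder and proxy stay untouched.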
