surena26 commited on
Commit
d36e9c5
·
verified ·
1 Parent(s): a90b104

Upload ComfyUI/latent_preview.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ComfyUI/latent_preview.py +98 -0
ComfyUI/latent_preview.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import struct
4
+ import numpy as np
5
+ from comfy.cli_args import args, LatentPreviewMethod
6
+ from comfy.taesd.taesd import TAESD
7
+ import folder_paths
8
+ import comfy.utils
9
+ import logging
10
+
11
+ MAX_PREVIEW_RESOLUTION = 512
12
+
13
+ class LatentPreviewer:
14
+ def decode_latent_to_preview(self, x0):
15
+ pass
16
+
17
+ def decode_latent_to_preview_image(self, preview_format, x0):
18
+ preview_image = self.decode_latent_to_preview(x0)
19
+ return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
20
+
21
+ class TAESDPreviewerImpl(LatentPreviewer):
22
+ def __init__(self, taesd):
23
+ self.taesd = taesd
24
+
25
+ def decode_latent_to_preview(self, x0):
26
+ x_sample = self.taesd.decode(x0[:1])[0].detach()
27
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
28
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
29
+ x_sample = x_sample.astype(np.uint8)
30
+
31
+ preview_image = Image.fromarray(x_sample)
32
+ return preview_image
33
+
34
+
35
+ class Latent2RGBPreviewer(LatentPreviewer):
36
+ def __init__(self, latent_rgb_factors):
37
+ self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
38
+
39
+ def decode_latent_to_preview(self, x0):
40
+ latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
41
+
42
+ latents_ubyte = (((latent_image + 1) / 2)
43
+ .clamp(0, 1) # change scale from -1..1 to 0..1
44
+ .mul(0xFF) # to 0..255
45
+ .byte()).cpu()
46
+
47
+ return Image.fromarray(latents_ubyte.numpy())
48
+
49
+
50
+ def get_previewer(device, latent_format):
51
+ previewer = None
52
+ method = args.preview_method
53
+ if method != LatentPreviewMethod.NoPreviews:
54
+ # TODO previewer methods
55
+ taesd_decoder_path = None
56
+ if latent_format.taesd_decoder_name is not None:
57
+ taesd_decoder_path = next(
58
+ (fn for fn in folder_paths.get_filename_list("vae_approx")
59
+ if fn.startswith(latent_format.taesd_decoder_name)),
60
+ ""
61
+ )
62
+ taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
63
+
64
+ if method == LatentPreviewMethod.Auto:
65
+ method = LatentPreviewMethod.Latent2RGB
66
+ if taesd_decoder_path:
67
+ method = LatentPreviewMethod.TAESD
68
+
69
+ if method == LatentPreviewMethod.TAESD:
70
+ if taesd_decoder_path:
71
+ taesd = TAESD(None, taesd_decoder_path).to(device)
72
+ previewer = TAESDPreviewerImpl(taesd)
73
+ else:
74
+ logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
75
+
76
+ if previewer is None:
77
+ if latent_format.latent_rgb_factors is not None:
78
+ previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
79
+ return previewer
80
+
81
+ def prepare_callback(model, steps, x0_output_dict=None):
82
+ preview_format = "JPEG"
83
+ if preview_format not in ["JPEG", "PNG"]:
84
+ preview_format = "JPEG"
85
+
86
+ previewer = get_previewer(model.load_device, model.model.latent_format)
87
+
88
+ pbar = comfy.utils.ProgressBar(steps)
89
+ def callback(step, x0, x, total_steps):
90
+ if x0_output_dict is not None:
91
+ x0_output_dict["x0"] = x0
92
+
93
+ preview_bytes = None
94
+ if previewer:
95
+ preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
96
+ pbar.update_absolute(step + 1, total_steps, preview_bytes)
97
+ return callback
98
+