Update app.py
app.py CHANGED
@@ -100,7 +100,6 @@ def teacache_wrapper_forward(self, hidden_states: torch.Tensor, **kwargs):
     # COMPUTE: Call the original, stored method, passing 'self' explicitly
     self.previous_hidden_states = hidden_states.clone()
     output = original_transformer_forward(self, hidden_states=hidden_states, **kwargs)
-
     # Handle both tuple and object return types from the original function
     if isinstance(output, tuple):
         output_tensor = output[0]
@@ -110,13 +109,9 @@ def teacache_wrapper_forward(self, hidden_states: torch.Tensor, **kwargs):
     self.previous_residual = output_tensor - hidden_states
     return output
 
-# 3. Apply the patch
 Transformer3DModel.forward = teacache_wrapper_forward
 print("✅ Transformer3DModel patched with robust TeaCache Wrapper.")
 
-# --- End TeaCache Integration ---
-
-
 MAX_SEED = np.iinfo(np.int32).max
 
 upscaler = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(torch.device("cuda:0"))
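For context, the two hunks above show only fragments of the TeaCache-style wrapper. A minimal sketch of the pattern they rely on is given below; the skip logic and its threshold are assumptions (they are not visible in this diff), and Transformer3DModel is assumed to be imported earlier in app.py, as the diff implies.

import torch

# 1. Keep a reference to the unpatched forward so the wrapper can still call it.
#    (Transformer3DModel is assumed to come from app.py's existing imports.)
original_transformer_forward = Transformer3DModel.forward

def teacache_wrapper_forward(self, hidden_states: torch.Tensor, **kwargs):
    # SKIP path (assumed, not shown in the diff): if the input barely changed since
    # the previous step, reuse the cached residual instead of re-running the model.
    if getattr(self, "previous_hidden_states", None) is not None:
        change = (hidden_states - self.previous_hidden_states).abs().mean()
        if change < 0.05:  # illustrative threshold, not taken from the diff
            # NOTE: callers expecting a tuple would need this wrapped accordingly.
            return hidden_states + self.previous_residual

    # COMPUTE path (shown in the diff): run the original forward and cache the residual.
    self.previous_hidden_states = hidden_states.clone()
    output = original_transformer_forward(self, hidden_states=hidden_states, **kwargs)
    output_tensor = output[0] if isinstance(output, tuple) else output
    self.previous_residual = output_tensor - hidden_states
    return output

# 2. Apply the patch, exactly as the diff does.
Transformer3DModel.forward = teacache_wrapper_forward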
@@ -159,6 +154,45 @@ print(f"Target inference device: {target_inference_device}")
 pipeline_instance.to(target_inference_device)
 if latent_upsampler_instance: latent_upsampler_instance.to(target_inference_device)
 
+from diffusers.models.attention_processor import AttnProcessor2_0
+
+from kernels import get_kernel
+
+fa3_kernel = get_kernel("kernels-community/flash-attn3")
+
+class FlashAttentionProcessor(AttnProcessor2_0):
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        **kwargs,
+    ):
+        query = attn.to_q(hidden_states)
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        scale = attn.scale
+        query = query * scale
+        b, t, c = query.shape
+        h = attn.heads
+        d = c // h
+        q_reshaped = query.reshape(b, t, h, d).permute(0, 2, 1, 3)
+        k_reshaped = key.reshape(b, t, h, d).permute(0, 2, 1, 3)
+        v_reshaped = value.reshape(b, t, h, d).permute(0, 2, 1, 3)
+        out_reshaped = torch.empty_like(q_reshaped)
+        fa3_kernel.attention(q_reshaped, k_reshaped, v_reshaped, out_reshaped)
+        out = out_reshaped.permute(0, 2, 1, 3).reshape(b, t, c)
+        out = attn.to_out(out)
+        return out
+
+fa_processor = FlashAttentionProcessor()
+
+# Iterate through the pipeline's UNet and apply the custom processor
+for name, module in pipeline_instance.transformer.named_modules():
+    if isinstance(module, AttnProcessor2_0):
+        module.processor = fa_processor
 
 def upload_to_sftp(local_filepath):
     if not all([FTP_HOST, FTP_USER, FTP_PASS, FTP_DIR]):
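One caveat on the new registration loop: in current diffusers releases, AttnProcessor2_0 is a plain callable rather than an nn.Module, so named_modules() will not yield instances of it and the loop above may never assign fa_processor; likewise, Attention.to_out is usually a ModuleList, so attn.to_out(out) would typically need to be attn.to_out[0](out). A more conventional registration targets the Attention modules themselves. A minimal sketch follows, assuming pipeline_instance.transformer is built from diffusers Attention blocks (which may not hold for this particular model).

from diffusers.models.attention_processor import Attention

fa_processor = FlashAttentionProcessor()  # class defined in the diff above

# Assign the custom processor on every Attention block found in the transformer.
for name, module in pipeline_instance.transformer.named_modules():
    if isinstance(module, Attention):
        module.set_processor(fa_processor)

# Many diffusers models also expose a set_attn_processor(...) helper for the same purpose:
# pipeline_instance.transformer.set_attn_processor(fa_processor)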