ford442 committed
Commit a70e549 (verified) · Parent(s): 792ec5a

Update app.py

Files changed (1): app.py (+39 -5)
app.py CHANGED
@@ -100,7 +100,6 @@ def teacache_wrapper_forward(self, hidden_states: torch.Tensor, **kwargs):
     # COMPUTE: Call the original, stored method, passing 'self' explicitly
     self.previous_hidden_states = hidden_states.clone()
     output = original_transformer_forward(self, hidden_states=hidden_states, **kwargs)
-
     # Handle both tuple and object return types from the original function
     if isinstance(output, tuple):
         output_tensor = output[0]
@@ -110,13 +109,9 @@ def teacache_wrapper_forward(self, hidden_states: torch.Tensor, **kwargs):
     self.previous_residual = output_tensor - hidden_states
     return output
 
-# 3. Apply the patch
 Transformer3DModel.forward = teacache_wrapper_forward
 print("✅ Transformer3DModel patched with robust TeaCache Wrapper.")
 
-# --- End TeaCache Integration ---
-
-
 MAX_SEED = np.iinfo(np.int32).max
 
 upscaler = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(torch.device("cuda:0"))
@@ -159,6 +154,45 @@ print(f"Target inference device: {target_inference_device}")
 pipeline_instance.to(target_inference_device)
 if latent_upsampler_instance: latent_upsampler_instance.to(target_inference_device)
 
+from diffusers.models.attention_processor import AttnProcessor2_0
+
+from kernels import get_kernel
+
+fa3_kernel = get_kernel("kernels-community/flash-attn3")
+
+class FlashAttentionProcessor(AttnProcessor2_0):
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        **kwargs,
+    ):
+        query = attn.to_q(hidden_states)
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        scale = attn.scale
+        query = query * scale
+        b, t, c = query.shape
+        h = attn.heads
+        d = c // h
+        q_reshaped = query.reshape(b, t, h, d).permute(0, 2, 1, 3)
+        k_reshaped = key.reshape(b, t, h, d).permute(0, 2, 1, 3)
+        v_reshaped = value.reshape(b, t, h, d).permute(0, 2, 1, 3)
+        out_reshaped = torch.empty_like(q_reshaped)
+        fa3_kernel.attention(q_reshaped, k_reshaped, v_reshaped, out_reshaped)
+        out = out_reshaped.permute(0, 2, 1, 3).reshape(b, t, c)
+        out = attn.to_out(out)
+        return out
+
+fa_processor = FlashAttentionProcessor()
+
+# Iterate through the pipeline's UNet and apply the custom processor
+for name, module in pipeline_instance.transformer.named_modules():
+    if isinstance(module, AttnProcessor2_0):
+        module.processor = fa_processor
 
 def upload_to_sftp(local_filepath):
     if not all([FTP_HOST, FTP_USER, FTP_PASS, FTP_DIR]):
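Review note on the TeaCache hunks: the diff touches only the COMPUTE branch of teacache_wrapper_forward, so the cache-skip logic that motivates previous_hidden_states and previous_residual is not visible here. For orientation, a minimal sketch of the wrapper pattern being patched, assuming the usual TeaCache skip rule; the relative-L1 test and the 0.05 threshold below are illustrative assumptions, not part of this commit:

# Sketch of the monkey-patch pattern the first two hunks modify.
# `Transformer3DModel` comes from the app's own imports; the skip rule is assumed.
import torch

original_transformer_forward = Transformer3DModel.forward  # keep the unpatched method

def teacache_wrapper_forward(self, hidden_states: torch.Tensor, **kwargs):
    prev = getattr(self, "previous_hidden_states", None)
    residual = getattr(self, "previous_residual", None)
    if prev is not None and residual is not None and prev.shape == hidden_states.shape:
        # SKIP: if the input barely moved since the last step, reuse the cached residual.
        rel_change = (hidden_states - prev).abs().mean() / (prev.abs().mean() + 1e-8)
        if rel_change < 0.05:  # assumed threshold; returns a bare tensor on this path
            return hidden_states + residual
    # COMPUTE: call the original, stored method, passing 'self' explicitly
    self.previous_hidden_states = hidden_states.clone()
    output = original_transformer_forward(self, hidden_states=hidden_states, **kwargs)
    # Handle both tuple and object return types from the original function
    output_tensor = output[0] if isinstance(output, tuple) else output
    self.previous_residual = output_tensor - hidden_states
    return output

Transformer3DModel.forward = teacache_wrapper_forward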
 
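Review note on the added FlashAttentionProcessor: as committed it has three likely bugs. First, in diffusers attn.to_out is a ModuleList of [Linear, Dropout], so attn.to_out(out) raises a TypeError. Second, key and value are reshaped with the query's sequence length t, which breaks cross-attention, where encoder_hidden_states has a different length than hidden_states. Third, pre-multiplying the query by attn.scale double-scales the logits when the kernel applies its own 1/sqrt(d) softmax scaling, as FlashAttention kernels do by default. A corrected sketch follows, with PyTorch's scaled_dot_product_attention standing in for the FA3 kernel, whose exact fa3_kernel.attention(q, k, v, out) in-place signature cannot be verified here:

# Corrected sketch of the processor. SDPA stands in for the FA3 kernel;
# the commit's fa3_kernel.attention(q, k, v, out) call would slot in at
# the marked line if its in-place signature is as the diff uses it.
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0

class FlashAttentionProcessor(AttnProcessor2_0):
    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        b, t_q, c = query.shape
        t_kv = key.shape[1]  # may differ from t_q under cross-attention
        h = attn.heads
        d = c // h
        q = query.reshape(b, t_q, h, d).permute(0, 2, 1, 3)
        k = key.reshape(b, t_kv, h, d).permute(0, 2, 1, 3)
        v = value.reshape(b, t_kv, h, d).permute(0, 2, 1, 3)

        # No manual `query * attn.scale`: the kernel applies softmax scaling itself.
        # attention_mask is assumed already broadcastable to (b, h, t_q, t_kv).
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        # FA3 alternative, per the diff (unverified signature):
        #   out = torch.empty_like(q); fa3_kernel.attention(q, k, v, out)

        out = out.permute(0, 2, 1, 3).reshape(b, t_q, c)
        out = attn.to_out[0](out)  # linear projection
        out = attn.to_out[1](out)  # dropout
        return out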
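Finally, the patch loop at the end of the hunk never fires: named_modules() yields nn.Module instances, while attention processors are plain objects stored on Attention modules, so isinstance(module, AttnProcessor2_0) never matches; the comment also says "UNet" although the loop walks pipeline_instance.transformer. A sketch of the assignment as diffusers usually does it:

# Match Attention modules (which hold the processor) rather than processor
# instances, and use the Attention.set_processor method to install it.
from diffusers.models.attention_processor import Attention

fa_processor = FlashAttentionProcessor()
for name, module in pipeline_instance.transformer.named_modules():
    if isinstance(module, Attention):
        module.set_processor(fa_processor)

If the transformer class inherits diffusers' set_attn_processor helper, pipeline_instance.transformer.set_attn_processor(fa_processor) does the same in one call.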