Update barks.py
Browse files
barks.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
import os
|
| 3 |
import torch
|
| 4 |
import torchaudio
|
|
@@ -36,18 +35,16 @@ if device != "cuda":
|
|
| 36 |
sys.exit(1)
|
| 37 |
print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
|
| 38 |
|
| 39 |
-
# Initialize accelerator
|
| 40 |
accelerator = Accelerator(mixed_precision="fp16")
|
| 41 |
|
| 42 |
-
#
|
| 43 |
-
def
|
| 44 |
torch.cuda.empty_cache()
|
| 45 |
gc.collect()
|
| 46 |
-
|
| 47 |
-
torch.cuda.synchronize()
|
| 48 |
-
print("Performed aggressive memory cleanup.")
|
| 49 |
|
| 50 |
-
|
| 51 |
|
| 52 |
# 2) LOAD MODELS
|
| 53 |
try:
|
|
@@ -86,7 +83,7 @@ def print_resource_usage(stage: str):
|
|
| 86 |
print("---------------")
|
| 87 |
|
| 88 |
# Check available GPU memory
|
| 89 |
-
def check_vram_availability(required_gb=3.0):
|
| 90 |
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
| 91 |
allocated_vram = torch.cuda.memory_allocated() / (1024**3)
|
| 92 |
available_vram = total_vram - allocated_vram
|
|
@@ -298,7 +295,7 @@ def generate_vocals(vocal_prompt: str, total_duration: int):
|
|
| 298 |
|
| 299 |
# Move Bark model back to CPU
|
| 300 |
bark_model = bark_model.to("cpu")
|
| 301 |
-
|
| 302 |
|
| 303 |
return vocal_segment, "✅ Vocals generated successfully."
|
| 304 |
except Exception as e:
|
|
@@ -370,12 +367,12 @@ def generate_music(instrumental_prompt: str, vocal_prompt: str, cfg_scale: float
|
|
| 370 |
os.remove(temp_wav_path)
|
| 371 |
audio_segments.append(segment)
|
| 372 |
|
| 373 |
-
|
| 374 |
print_resource_usage(f"After Chunk {i+1} Generation")
|
| 375 |
|
| 376 |
# Move MusicGen back to CPU
|
| 377 |
musicgen_model = musicgen_model.to("cpu")
|
| 378 |
-
|
| 379 |
|
| 380 |
print("Combining instrumental chunks...")
|
| 381 |
final_segment = audio_segments[0]
|
|
@@ -415,7 +412,7 @@ def generate_music(instrumental_prompt: str, vocal_prompt: str, cfg_scale: float
|
|
| 415 |
except Exception as e:
|
| 416 |
return None, f"❌ Generation failed: {e}"
|
| 417 |
finally:
|
| 418 |
-
|
| 419 |
|
| 420 |
# Function to clear inputs
|
| 421 |
def clear_inputs():
|
|
@@ -692,4 +689,4 @@ try:
|
|
| 692 |
fastapi_app.redoc_url = None
|
| 693 |
fastapi_app.openapi_url = None
|
| 694 |
except Exception:
|
| 695 |
-
pass
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
import torchaudio
|
|
|
|
| 35 |
sys.exit(1)
|
| 36 |
print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
|
| 37 |
|
| 38 |
+
# Initialize accelerator
|
| 39 |
accelerator = Accelerator(mixed_precision="fp16")
|
| 40 |
|
| 41 |
+
# Simplified memory cleanup to avoid freezes
|
| 42 |
+
def memory_cleanup():
|
| 43 |
torch.cuda.empty_cache()
|
| 44 |
gc.collect()
|
| 45 |
+
print("Performed memory cleanup.")
|
|
|
|
|
|
|
| 46 |
|
| 47 |
+
memory_cleanup()
|
| 48 |
|
| 49 |
# 2) LOAD MODELS
|
| 50 |
try:
|
|
|
|
| 83 |
print("---------------")
|
| 84 |
|
| 85 |
# Check available GPU memory
|
| 86 |
+
def check_vram_availability(required_gb=3.0):
|
| 87 |
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
| 88 |
allocated_vram = torch.cuda.memory_allocated() / (1024**3)
|
| 89 |
available_vram = total_vram - allocated_vram
|
|
|
|
| 295 |
|
| 296 |
# Move Bark model back to CPU
|
| 297 |
bark_model = bark_model.to("cpu")
|
| 298 |
+
memory_cleanup()
|
| 299 |
|
| 300 |
return vocal_segment, "✅ Vocals generated successfully."
|
| 301 |
except Exception as e:
|
|
|
|
| 367 |
os.remove(temp_wav_path)
|
| 368 |
audio_segments.append(segment)
|
| 369 |
|
| 370 |
+
memory_cleanup()
|
| 371 |
print_resource_usage(f"After Chunk {i+1} Generation")
|
| 372 |
|
| 373 |
# Move MusicGen back to CPU
|
| 374 |
musicgen_model = musicgen_model.to("cpu")
|
| 375 |
+
memory_cleanup()
|
| 376 |
|
| 377 |
print("Combining instrumental chunks...")
|
| 378 |
final_segment = audio_segments[0]
|
|
|
|
| 412 |
except Exception as e:
|
| 413 |
return None, f"❌ Generation failed: {e}"
|
| 414 |
finally:
|
| 415 |
+
memory_cleanup()
|
| 416 |
|
| 417 |
# Function to clear inputs
|
| 418 |
def clear_inputs():
|
|
|
|
| 689 |
fastapi_app.redoc_url = None
|
| 690 |
fastapi_app.openapi_url = None
|
| 691 |
except Exception:
|
| 692 |
+
pass
|