prithivMLmods committed
Commit 1fdedf4 · verified · 1 Parent(s): a5f8631

update app

Files changed (1)
  1. app.py +173 -58
app.py CHANGED
@@ -1,16 +1,15 @@
 import gradio as gr
 import torch
 import torchaudio
+import numpy as np
 import os
 import tempfile
 import spaces
+
 from typing import Iterable
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
-# ==========================================
-# 1. Theme Definition (Orange Red)
-# ==========================================
 colors.orange_red = colors.Color(
     name="orange_red",
     c50="#FFF0E5",
@@ -31,7 +30,7 @@ class OrangeRedTheme(Soft):
         self,
         *,
         primary_hue: colors.Color | str = colors.gray,
-        secondary_hue: colors.Color | str = colors.orange_red,
+        secondary_hue: colors.Color | str = colors.orange_red,  # Use the new color
         neutral_hue: colors.Color | str = colors.slate,
         text_size: sizes.Size | str = sizes.text_lg,
         font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -55,28 +54,41 @@ class OrangeRedTheme(Soft):
             body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
             body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
             button_primary_text_color="white",
+            button_primary_text_color_hover="white",
             button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
+            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
+            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
+            button_secondary_text_color="black",
+            button_secondary_text_color_hover="white",
+            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
+            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
+            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
+            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
+            slider_color="*secondary_500",
+            slider_color_dark="*secondary_600",
             block_title_text_weight="600",
             block_border_width="3px",
             block_shadow="*shadow_drop_lg",
             button_primary_shadow="*shadow_drop_lg",
             button_large_padding="11px",
+            color_accent_soft="*primary_100",
+            block_label_background_fill="*primary_200",
         )
 
 orange_red_theme = OrangeRedTheme()
 
-# ==========================================
-# 2. Model Loading
-# ==========================================
 try:
     from sam_audio import SAMAudio, SAMAudioProcessor
 except ImportError as e:
-    print(f"Warning: 'sam_audio' library not found. Error: {e}")
+    print(f"Warning: 'sam_audio' library not found. Please install it to use this app. Error: {e}")
 
 MODEL_ID = "facebook/sam-audio-large"
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+DEFAULT_CHUNK_DURATION = 30.0
+OVERLAP_DURATION = 2.0
+MAX_DURATION_WITHOUT_CHUNKING = 30.0
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Loading {MODEL_ID} on {device}...")
 
 model = None
@@ -89,67 +101,173 @@ try:
 except Exception as e:
     print(f"❌ Error loading SAM-Audio: {e}")
 
-# ==========================================
-# 3. Processing Function
-# ==========================================
+def load_audio(file_path):
+    """Load audio from file (supports both audio and video files)."""
+    waveform, sample_rate = torchaudio.load(file_path)
+    if waveform.shape[0] > 1:
+        waveform = waveform.mean(dim=0, keepdim=True)
+    return waveform, sample_rate
+
+def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
+    """Split audio waveform into overlapping chunks."""
+    chunk_samples = int(chunk_duration * sample_rate)
+    overlap_samples = int(overlap_duration * sample_rate)
+    stride = chunk_samples - overlap_samples
+
+    chunks = []
+    total_samples = waveform.shape[1]
+
+    if total_samples <= chunk_samples:
+        return [waveform]
+
+    start = 0
+    while start < total_samples:
+        end = min(start + chunk_samples, total_samples)
+        chunk = waveform[:, start:end]
+        chunks.append(chunk)
+        if end >= total_samples:
+            break
+        start += stride
+
+    return chunks
+
+def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
+    """Merge audio chunks with crossfade on overlapping regions."""
+    if len(chunks) == 1:
+        chunk = chunks[0]
+        if chunk.dim() == 1:
+            chunk = chunk.unsqueeze(0)
+        return chunk
+
+    overlap_samples = int(overlap_duration * sample_rate)
+
+    processed_chunks = []
+    for chunk in chunks:
+        if chunk.dim() == 1:
+            chunk = chunk.unsqueeze(0)
+        processed_chunks.append(chunk)
+
+    result = processed_chunks[0]
+
+    for i in range(1, len(processed_chunks)):
+        prev_chunk = result
+        next_chunk = processed_chunks[i]
+
+        actual_overlap = min(overlap_samples, prev_chunk.shape[1], next_chunk.shape[1])
+
+        if actual_overlap <= 0:
+            result = torch.cat([prev_chunk, next_chunk], dim=1)
+            continue
+
+        fade_out = torch.linspace(1.0, 0.0, actual_overlap).to(prev_chunk.device)
+        fade_in = torch.linspace(0.0, 1.0, actual_overlap).to(next_chunk.device)
+
+        prev_overlap = prev_chunk[:, -actual_overlap:]
+        next_overlap = next_chunk[:, :actual_overlap]
+
+        crossfaded = prev_overlap * fade_out + next_overlap * fade_in
+
+        result = torch.cat([
+            prev_chunk[:, :-actual_overlap],
+            crossfaded,
+            next_chunk[:, actual_overlap:]
+        ], dim=1)
+
+    return result
+
 def save_audio(tensor, sample_rate):
     """Saves a tensor to a temporary WAV file and returns path."""
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
         tensor = tensor.cpu()
-        # torchaudio expects [channels, time]
         if tensor.dim() == 1:
             tensor = tensor.unsqueeze(0)
         torchaudio.save(tmp.name, tensor, sample_rate)
         return tmp.name
 
 @spaces.GPU(duration=120)
-def process_audio(file_path, text_prompt, rerank, progress=gr.Progress()):
+def process_audio(file_path, text_prompt, chunk_duration_val, progress=gr.Progress()):
     global model, processor
 
     if model is None or processor is None:
-        return None, None, "❌ Model not loaded correctly."
+        return None, None, "❌ Model not loaded correctly. Check logs."
+
+    progress(0.05, desc="Checking inputs...")
 
     if not file_path:
-        return None, None, "❌ Please upload an audio file."
+        return None, None, "❌ Please upload an audio or video file."
     if not text_prompt or not text_prompt.strip():
         return None, None, "❌ Please enter a text prompt."
 
     try:
-        progress(0.2, desc="Processing audio...")
-
-        # Prepare inputs
-        inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
-
-        progress(0.5, desc="Separating sound...")
-        with torch.inference_mode():
-            # Run separation
-            # Using reranking improves quality but adds latency
-            candidates = int(rerank) if rerank else 1
-            result = model.separate(inputs, predict_spans=True, reranking_candidates=candidates)
-
-        progress(0.9, desc="Saving results...")
-        sr = processor.audio_sampling_rate
-
-        # Save Target
-        target_path = save_audio(result.target[0], sr)
-
-        # Save Residual (Background)
-        residual_path = save_audio(result.residual[0], sr)
-
-        progress(1.0, desc="Done!")
-        return target_path, residual_path, f"✅ Successfully isolated '{text_prompt}'"
+        progress(0.15, desc="Loading audio...")
+        waveform, sample_rate = load_audio(file_path)
+        duration = waveform.shape[1] / sample_rate
+
+        c_dur = chunk_duration_val if chunk_duration_val else DEFAULT_CHUNK_DURATION
+        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING
+
+        if use_chunking:
+            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
+            chunks = split_audio_into_chunks(waveform, sample_rate, c_dur, OVERLAP_DURATION)
+            num_chunks = len(chunks)
+
+            target_chunks = []
+            residual_chunks = []
+
+            for i, chunk in enumerate(chunks):
+                chunk_progress = 0.2 + (i / num_chunks) * 0.6
+                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")
+
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+                    torchaudio.save(tmp.name, chunk, sample_rate)
+                    chunk_path = tmp.name
+
+                try:
+                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)
+
+                    with torch.inference_mode():
+                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
+
+                    target_chunks.append(result.target[0].detach().cpu())
+                    residual_chunks.append(result.residual[0].detach().cpu())
+                finally:
+                    if os.path.exists(chunk_path):
+                        os.unlink(chunk_path)
+
+            progress(0.85, desc="Merging chunks...")
+            target_merged = merge_chunks_with_crossfade(target_chunks, sample_rate, OVERLAP_DURATION)
+            residual_merged = merge_chunks_with_crossfade(residual_chunks, sample_rate, OVERLAP_DURATION)
+
+            progress(0.95, desc="Saving results...")
+            target_path = save_audio(target_merged, sample_rate)
+            residual_path = save_audio(residual_merged, sample_rate)
+
+            progress(1.0, desc="Done!")
+            return target_path, residual_path, f"✅ Isolated '{text_prompt}' ({num_chunks} chunks)"
+
+        else:
+            progress(0.3, desc="Processing audio...")
+            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
+
+            progress(0.6, desc="Separating sounds...")
+            with torch.inference_mode():
+                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
+
+            progress(0.9, desc="Saving results...")
+            sr = processor.audio_sampling_rate
+            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sr)
+            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sr)
+
+            progress(1.0, desc="Done!")
+            return target_path, residual_path, f"✅ Isolated '{text_prompt}'"
 
     except Exception as e:
         import traceback
         traceback.print_exc()
         return None, None, f"❌ Error: {str(e)}"
 
-# ==========================================
-# 4. Gradio Interface
-# ==========================================
 css = """
 #main-title h1 {font-size: 2.4em}
-#col-container {max-width: 1000px; margin: 0 auto;}
 """
 
 with gr.Blocks() as demo:
@@ -158,40 +276,37 @@ with gr.Blocks() as demo:
 
     with gr.Column(elem_id="col-container"):
         with gr.Row():
-            # Left Column: Inputs
             with gr.Column(scale=1):
                 input_file = gr.Audio(label="Input Audio", type="filepath")
                 text_prompt = gr.Textbox(label="Sound to Isolate", placeholder="e.g., 'A man speaking', 'Bird chirping'")
-
+
                 with gr.Accordion("Advanced Settings", open=False):
-                    rerank_slider = gr.Slider(
-                        minimum=1, maximum=8, value=3, step=1,
-                        label="Reranking Candidates",
-                        info="Higher values improve quality but take longer."
+                    chunk_duration_slider = gr.Slider(
+                        minimum=10, maximum=60, value=30, step=5,
+                        label="Chunk Duration (seconds)",
+                        info="Processing long audio in chunks prevents out-of-memory errors."
                     )
-
+
                 run_btn = gr.Button("Segment Audio", variant="primary")
 
-            # Right Column: Outputs
             with gr.Column(scale=1):
                 output_target = gr.Audio(label="Isolated Sound (Target)", type="filepath")
                 output_residual = gr.Audio(label="Background (Residual)", type="filepath")
-                status_out = gr.Textbox(label="Status", interactive=False, show_label=True, lines=2)
+                status_out = gr.Textbox(label="Status", interactive=False, show_label=True, lines=6)
 
-        # Examples
         gr.Examples(
             examples=[
-                ["example_audio/speech.mp3", "Music"],
-                ["example_audio/song.mp3", "Drum"],
-                ["example_audio/song2.mp3", "Vocals"],
+                ["example_audio/speech.mp3", "Music", 30],
+                ["example_audio/song.mp3", "Drum", 30],
+                ["example_audio/song2.mp3", "Music", 30],
             ],
-            inputs=[input_file, text_prompt],
+            inputs=[input_file, text_prompt, chunk_duration_slider],
            label="Audio Examples"
         )
 
        run_btn.click(
            fn=process_audio,
-            inputs=[input_file, text_prompt, rerank_slider],
+            inputs=[input_file, text_prompt, chunk_duration_slider],
            outputs=[output_target, output_residual, status_out]
        )
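The crossfade merge added in this commit works because the two linear fades sum to one at every sample: overlapping chunks cut from the same signal blend back into the original, and chunks that differ only inside the overlap get a smooth seam instead of a click. A minimal sanity check of that identity, assuming a toy sample rate and chunk layout (no sam_audio or model required):

import torch

sr = 16_000
overlap = int(2.0 * sr)                    # mirrors OVERLAP_DURATION = 2.0 s
signal = torch.randn(1, 10 * sr)           # 10 s of mono audio, [channels, time]

# Two chunks sharing a 2 s overlap, as split_audio_into_chunks would cut them
chunk_a = signal[:, :6 * sr]
chunk_b = signal[:, 6 * sr - overlap:]

# The same linear crossfade used in merge_chunks_with_crossfade
fade_out = torch.linspace(1.0, 0.0, overlap)
fade_in = torch.linspace(0.0, 1.0, overlap)
crossfaded = chunk_a[:, -overlap:] * fade_out + chunk_b[:, :overlap] * fade_in

merged = torch.cat([chunk_a[:, :-overlap], crossfaded, chunk_b[:, overlap:]], dim=1)
assert merged.shape == signal.shape
assert torch.allclose(merged, signal, atol=1e-6)  # fade_out + fade_in == 1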
 
 
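Each chunk after the first advances by stride = chunk_duration - overlap_duration, so the chunk count reported in the status message is easy to estimate up front. A back-of-envelope helper that mirrors the loop in split_audio_into_chunks, using this commit's 30 s chunk / 2 s overlap defaults (illustrative only, not part of app.py):

import math

def expected_chunks(duration_s, chunk_s=30.0, overlap_s=2.0):
    """Chunk count the splitting loop produces for a given duration."""
    if duration_s <= chunk_s:            # short clips are returned whole
        return 1
    stride = chunk_s - overlap_s         # 28 s of new audio per extra chunk
    return 1 + math.ceil((duration_s - chunk_s) / stride)

print(expected_chunks(25))   # 1  (also under MAX_DURATION_WITHOUT_CHUNKING, so no chunking)
print(expected_chunks(90))   # 4  = 1 + ceil(60 / 28)
print(expected_chunks(300))  # 11 = 1 + ceil(270 / 28)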
312