prithivMLmods committed
Commit a5f8631 · verified · 1 Parent(s): db15e8b

Update app.py

Files changed (1)
  1. app.py +58 -173
app.py CHANGED
@@ -1,15 +1,16 @@
 import gradio as gr
 import torch
 import torchaudio
-import numpy as np
 import os
 import tempfile
 import spaces
-
 from typing import Iterable
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
+# ==========================================
+# 1. Theme Definition (Orange Red)
+# ==========================================
 colors.orange_red = colors.Color(
     name="orange_red",
     c50="#FFF0E5",
@@ -30,7 +31,7 @@ class OrangeRedTheme(Soft):
         self,
         *,
         primary_hue: colors.Color | str = colors.gray,
-        secondary_hue: colors.Color | str = colors.orange_red, # Use the new color
+        secondary_hue: colors.Color | str = colors.orange_red,
         neutral_hue: colors.Color | str = colors.slate,
         text_size: sizes.Size | str = sizes.text_lg,
         font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -54,41 +55,28 @@ class OrangeRedTheme(Soft):
             body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
             body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
             button_primary_text_color="white",
-            button_primary_text_color_hover="white",
             button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
-            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
-            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
-            button_secondary_text_color="black",
-            button_secondary_text_color_hover="white",
-            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
-            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
-            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
-            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
-            slider_color="*secondary_500",
-            slider_color_dark="*secondary_600",
             block_title_text_weight="600",
             block_border_width="3px",
             block_shadow="*shadow_drop_lg",
             button_primary_shadow="*shadow_drop_lg",
             button_large_padding="11px",
-            color_accent_soft="*primary_100",
-            block_label_background_fill="*primary_200",
         )
 
 orange_red_theme = OrangeRedTheme()
 
+# ==========================================
+# 2. Model Loading
+# ==========================================
 try:
     from sam_audio import SAMAudio, SAMAudioProcessor
 except ImportError as e:
-    print(f"Warning: 'sam_audio' library not found. Please install it to use this app. Error: {e}")
+    print(f"Warning: 'sam_audio' library not found. Error: {e}")
 
 MODEL_ID = "facebook/sam-audio-large"
-DEFAULT_CHUNK_DURATION = 30.0
-OVERLAP_DURATION = 2.0
-MAX_DURATION_WITHOUT_CHUNKING = 30.0
-
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 print(f"Loading {MODEL_ID} on {device}...")
 
 model = None
@@ -101,173 +89,67 @@ try:
 except Exception as e:
     print(f"❌ Error loading SAM-Audio: {e}")
 
-def load_audio(file_path):
-    """Load audio from file (supports both audio and video files)."""
-    waveform, sample_rate = torchaudio.load(file_path)
-    if waveform.shape[0] > 1:
-        waveform = waveform.mean(dim=0, keepdim=True)
-    return waveform, sample_rate
-
-def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
-    """Split audio waveform into overlapping chunks."""
-    chunk_samples = int(chunk_duration * sample_rate)
-    overlap_samples = int(overlap_duration * sample_rate)
-    stride = chunk_samples - overlap_samples
-
-    chunks = []
-    total_samples = waveform.shape[1]
-
-    if total_samples <= chunk_samples:
-        return [waveform]
-
-    start = 0
-    while start < total_samples:
-        end = min(start + chunk_samples, total_samples)
-        chunk = waveform[:, start:end]
-        chunks.append(chunk)
-        if end >= total_samples:
-            break
-        start += stride
-
-    return chunks
-
-def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
-    """Merge audio chunks with crossfade on overlapping regions."""
-    if len(chunks) == 1:
-        chunk = chunks[0]
-        if chunk.dim() == 1:
-            chunk = chunk.unsqueeze(0)
-        return chunk
-
-    overlap_samples = int(overlap_duration * sample_rate)
-
-    processed_chunks = []
-    for chunk in chunks:
-        if chunk.dim() == 1:
-            chunk = chunk.unsqueeze(0)
-        processed_chunks.append(chunk)
-
-    result = processed_chunks[0]
-
-    for i in range(1, len(processed_chunks)):
-        prev_chunk = result
-        next_chunk = processed_chunks[i]
-
-        actual_overlap = min(overlap_samples, prev_chunk.shape[1], next_chunk.shape[1])
-
-        if actual_overlap <= 0:
-            result = torch.cat([prev_chunk, next_chunk], dim=1)
-            continue
-
-        fade_out = torch.linspace(1.0, 0.0, actual_overlap).to(prev_chunk.device)
-        fade_in = torch.linspace(0.0, 1.0, actual_overlap).to(next_chunk.device)
-
-        prev_overlap = prev_chunk[:, -actual_overlap:]
-        next_overlap = next_chunk[:, :actual_overlap]
-
-        crossfaded = prev_overlap * fade_out + next_overlap * fade_in
-
-        result = torch.cat([
-            prev_chunk[:, :-actual_overlap],
-            crossfaded,
-            next_chunk[:, actual_overlap:]
-        ], dim=1)
-
-    return result
-
+# ==========================================
+# 3. Processing Function
+# ==========================================
 def save_audio(tensor, sample_rate):
     """Saves a tensor to a temporary WAV file and returns path."""
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
         tensor = tensor.cpu()
+        # torchaudio expects [channels, time]
         if tensor.dim() == 1:
             tensor = tensor.unsqueeze(0)
         torchaudio.save(tmp.name, tensor, sample_rate)
         return tmp.name
 
 @spaces.GPU(duration=120)
-def process_audio(file_path, text_prompt, chunk_duration_val, progress=gr.Progress()):
+def process_audio(file_path, text_prompt, rerank, progress=gr.Progress()):
     global model, processor
 
     if model is None or processor is None:
-        return None, None, "❌ Model not loaded correctly. Check logs."
-
-    progress(0.05, desc="Checking inputs...")
+        return None, None, "❌ Model not loaded correctly."
 
     if not file_path:
-        return None, None, "❌ Please upload an audio or video file."
+        return None, None, "❌ Please upload an audio file."
     if not text_prompt or not text_prompt.strip():
         return None, None, "❌ Please enter a text prompt."
 
     try:
-        progress(0.15, desc="Loading audio...")
-        waveform, sample_rate = load_audio(file_path)
-        duration = waveform.shape[1] / sample_rate
-
-        c_dur = chunk_duration_val if chunk_duration_val else DEFAULT_CHUNK_DURATION
-        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING
-
-        if use_chunking:
-            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
-            chunks = split_audio_into_chunks(waveform, sample_rate, c_dur, OVERLAP_DURATION)
-            num_chunks = len(chunks)
-
-            target_chunks = []
-            residual_chunks = []
-
-            for i, chunk in enumerate(chunks):
-                chunk_progress = 0.2 + (i / num_chunks) * 0.6
-                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")
-
-                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-                    torchaudio.save(tmp.name, chunk, sample_rate)
-                    chunk_path = tmp.name
-
-                try:
-                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)
-
-                    with torch.inference_mode():
-                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
-
-                    target_chunks.append(result.target[0].detach().cpu())
-                    residual_chunks.append(result.residual[0].detach().cpu())
-                finally:
-                    if os.path.exists(chunk_path):
-                        os.unlink(chunk_path)
-
-            progress(0.85, desc="Merging chunks...")
-            target_merged = merge_chunks_with_crossfade(target_chunks, sample_rate, OVERLAP_DURATION)
-            residual_merged = merge_chunks_with_crossfade(residual_chunks, sample_rate, OVERLAP_DURATION)
-
-            progress(0.95, desc="Saving results...")
-            target_path = save_audio(target_merged, sample_rate)
-            residual_path = save_audio(residual_merged, sample_rate)
-
-            progress(1.0, desc="Done!")
-            return target_path, residual_path, f"✅ Isolated '{text_prompt}' ({num_chunks} chunks)"
-
-        else:
-            progress(0.3, desc="Processing audio...")
-            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
-
-            progress(0.6, desc="Separating sounds...")
-            with torch.inference_mode():
-                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
-
-            progress(0.9, desc="Saving results...")
-            sr = processor.audio_sampling_rate
-            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sr)
-            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sr)
-
-            progress(1.0, desc="Done!")
-            return target_path, residual_path, f"✅ Isolated '{text_prompt}'"
+        progress(0.2, desc="Processing audio...")
+
+        # Prepare inputs
+        inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
+
+        progress(0.5, desc="Separating sound...")
+        with torch.inference_mode():
+            # Run separation
+            # Using reranking improves quality but adds latency
+            candidates = int(rerank) if rerank else 1
+            result = model.separate(inputs, predict_spans=True, reranking_candidates=candidates)
+
+        progress(0.9, desc="Saving results...")
+        sr = processor.audio_sampling_rate
+
+        # Save Target
+        target_path = save_audio(result.target[0], sr)
+
+        # Save Residual (Background)
+        residual_path = save_audio(result.residual[0], sr)
+
+        progress(1.0, desc="Done!")
+        return target_path, residual_path, f"✅ Successfully isolated '{text_prompt}'"
 
     except Exception as e:
         import traceback
         traceback.print_exc()
        return None, None, f"❌ Error: {str(e)}"
 
+# ==========================================
+# 4. Gradio Interface
+# ==========================================
 css = """
 #main-title h1 {font-size: 2.4em}
+#col-container {max-width: 1000px; margin: 0 auto;}
 """
 
 with gr.Blocks() as demo:
@@ -276,37 +158,40 @@ with gr.Blocks() as demo:
 
     with gr.Column(elem_id="col-container"):
         with gr.Row():
+            # Left Column: Inputs
             with gr.Column(scale=1):
                 input_file = gr.Audio(label="Input Audio", type="filepath")
                 text_prompt = gr.Textbox(label="Sound to Isolate", placeholder="e.g., 'A man speaking', 'Bird chirping'")
-
+
                 with gr.Accordion("Advanced Settings", open=False):
-                    chunk_duration_slider = gr.Slider(
-                        minimum=10, maximum=60, value=30, step=5,
-                        label="Chunk Duration (seconds)",
-                        info="Processing long audio in chunks prevents out-of-memory errors."
+                    rerank_slider = gr.Slider(
+                        minimum=1, maximum=8, value=3, step=1,
+                        label="Reranking Candidates",
+                        info="Higher values improve quality but take longer."
                     )
-
+
                 run_btn = gr.Button("Segment Audio", variant="primary")
 
+            # Right Column: Outputs
             with gr.Column(scale=1):
                 output_target = gr.Audio(label="Isolated Sound (Target)", type="filepath")
                 output_residual = gr.Audio(label="Background (Residual)", type="filepath")
-                status_out = gr.Textbox(label="Status", interactive=False, show_label=True, lines=6)
+                status_out = gr.Textbox(label="Status", interactive=False, show_label=True, lines=2)
 
+        # Examples
         gr.Examples(
             examples=[
-                ["example_audio/speech.mp3", "Music", 30],
-                ["example_audio/song.mp3", "Drum", 30],
-                ["example_audio/song2.mp3", "Music", 30],
+                ["example_audio/speech.mp3", "Music"],
+                ["example_audio/song.mp3", "Drum"],
+                ["example_audio/song2.mp3", "Vocals"],
             ],
-            inputs=[input_file, text_prompt, chunk_duration_slider],
+            inputs=[input_file, text_prompt],
             label="Audio Examples"
         )
 
         run_btn.click(
             fn=process_audio,
-            inputs=[input_file, text_prompt, chunk_duration_slider],
+            inputs=[input_file, text_prompt, rerank_slider],
             outputs=[output_target, output_residual, status_out]
         )
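
In short, the commit drops the manual chunk-split/crossfade-merge pipeline and runs separation as a single `model.separate` call, with the reranking candidate count exposed as the new Advanced Setting. A minimal standalone sketch of the new path follows; it assumes the `sam_audio` package is installed, and the `from_pretrained` loading calls and `clip.wav` input are placeholders (assumptions), since the app's actual loading code sits in unchanged lines this diff does not show:

```python
# Minimal sketch of the new single-pass separation path.
import torch
import torchaudio
from sam_audio import SAMAudio, SAMAudioProcessor

MODEL_ID = "facebook/sam-audio-large"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumed loading API -- the app's real loading code is in unchanged
# lines that this diff does not show.
model = SAMAudio.from_pretrained(MODEL_ID).to(device).eval()
processor = SAMAudioProcessor.from_pretrained(MODEL_ID)

# Confirmed by the diff: text-prompted separation with reranking.
inputs = processor(audios=["clip.wav"], descriptions=["A man speaking"]).to(device)
with torch.inference_mode():
    result = model.separate(inputs, predict_spans=True, reranking_candidates=3)

# Write target (isolated sound) and residual (background) to WAV.
sr = processor.audio_sampling_rate
for path, audio in (("target.wav", result.target[0]), ("residual.wav", result.residual[0])):
    audio = audio.cpu()
    if audio.dim() == 1:  # torchaudio expects [channels, time]
        audio = audio.unsqueeze(0)
    torchaudio.save(path, audio, sr)
```

Note the behavioral change hiding in the rewrite: the old code always called `model.separate(..., predict_spans=False, reranking_candidates=1)`, while the new code passes `predict_spans=True` and a user-chosen candidate count (1 to 8, default 3), trading latency for separation quality.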