Veena commited on
Commit
d1c3c57
·
1 Parent(s): 06301dc

Update Maya1 Gradio app with preset characters

Browse files
Files changed (2) hide show
  1. app.py +123 -67
  2. requirements.txt +0 -2
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
2
- import asyncio
3
  import io
4
- import sys
5
- sys.path.insert(0, '.')
 
 
6
 
7
  # Mock spaces module for local testing
8
  try:
@@ -14,11 +16,18 @@ except ImportError:
14
  return func
15
  spaces = SpacesMock()
16
 
17
- from maya1.model_loader import Maya1Model
18
- from maya1.pipeline import Maya1Pipeline
19
- from maya1.prompt_builder import Maya1PromptBuilder
20
- from maya1.snac_decoder import SNACDecoder
21
- from maya1.constants import AUDIO_SAMPLE_RATE
 
 
 
 
 
 
 
22
 
23
  # Preset characters (2 realistic + 2 creative)
24
  PRESET_CHARACTERS = {
@@ -40,53 +49,77 @@ PRESET_CHARACTERS = {
40
  }
41
  }
42
 
43
- # Global pipeline variables
44
  model = None
45
- prompt_builder = None
46
- snac_decoder = None
47
- pipeline = None
48
  models_loaded = False
49
 
50
- def load_models():
51
- """Load Maya1 vLLM model and pipeline (runs once)."""
52
- global model, prompt_builder, snac_decoder, pipeline, models_loaded
 
 
 
 
 
53
 
54
- if models_loaded:
55
- return
 
 
 
 
 
 
 
 
 
56
 
57
- import torch
58
- import os
59
 
60
- # Ensure CUDA is available for HF Spaces
61
- if not torch.cuda.is_available():
62
- print("Warning: CUDA not available, using CPU")
63
- device = "cpu"
64
- else:
65
- device = "cuda"
66
- print(f"CUDA available: {torch.cuda.get_device_name(0)}")
67
 
68
- # Set environment variable for vLLM
69
- os.environ.setdefault("VLLM_USE_V1", "0")
70
 
71
- print("Loading Maya1 model with vLLM...")
72
- model = Maya1Model(
73
- model_path="maya-research/maya1",
74
- dtype="bfloat16",
75
- max_model_len=8192,
76
- gpu_memory_utilization=0.85,
77
- )
 
 
 
 
 
 
78
 
79
- print("Initializing prompt builder...")
80
- prompt_builder = Maya1PromptBuilder(model.tokenizer, model)
 
 
 
81
 
82
- print("Loading SNAC decoder...")
83
- snac_decoder = SNACDecoder(
84
- device=device,
85
- enable_batching=False,
 
 
 
 
 
86
  )
 
87
 
88
- print("Initializing pipeline...")
89
- pipeline = Maya1Pipeline(model, prompt_builder, snac_decoder)
 
 
90
 
91
  models_loaded = True
92
  print("Models loaded successfully!")
@@ -100,7 +133,7 @@ def preset_selected(preset_name):
100
 
101
  @spaces.GPU
102
  def generate_speech(preset_name, description, text, temperature, max_tokens):
103
- """Generate emotional speech from description and text using vLLM."""
104
  try:
105
  # Load models if not already loaded
106
  load_models()
@@ -115,42 +148,65 @@ def generate_speech(preset_name, description, text, temperature, max_tokens):
115
 
116
  print(f"Generating with temperature={temperature}, max_tokens={max_tokens}...")
117
 
118
- # Generate audio using vLLM pipeline (async wrapper)
119
- loop = asyncio.new_event_loop()
120
- asyncio.set_event_loop(loop)
121
- audio_bytes = loop.run_until_complete(
122
- pipeline.generate_speech(
123
- description=description,
124
- text=text,
125
- temperature=temperature,
126
- top_p=0.9,
127
- max_tokens=max_tokens,
 
 
 
 
 
128
  repetition_penalty=1.1,
129
- seed=None,
 
 
130
  )
131
- )
132
- loop.close()
133
 
134
- if audio_bytes is None:
135
- return None, "Error: Audio generation failed. Try different text or increase max_tokens."
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- # Convert bytes to WAV file
138
- import wave
 
 
 
 
 
 
 
 
 
 
 
139
  wav_buffer = io.BytesIO()
140
  with wave.open(wav_buffer, 'wb') as wav_file:
141
  wav_file.setnchannels(1)
142
  wav_file.setsampwidth(2)
143
  wav_file.setframerate(AUDIO_SAMPLE_RATE)
144
- wav_file.writeframes(audio_bytes)
145
 
146
  wav_buffer.seek(0)
147
-
148
- # Calculate duration
149
- duration = len(audio_bytes) // 2 / AUDIO_SAMPLE_RATE
150
- frames = len(audio_bytes) // 2 // (AUDIO_SAMPLE_RATE // 6.86) // 7
151
 
152
  status_msg = f"Generated {duration:.2f}s of emotional speech!"
153
-
154
  return wav_buffer, status_msg
155
 
156
  except Exception as e:
 
1
  import gradio as gr
2
+ import torch
3
  import io
4
+ import wave
5
+ import numpy as np
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ from snac import SNAC
8
 
9
  # Mock spaces module for local testing
10
  try:
 
16
  return func
17
  spaces = SpacesMock()
18
 
19
+ # Constants
20
+ CODE_START_TOKEN_ID = 128257
21
+ CODE_END_TOKEN_ID = 128258
22
+ CODE_TOKEN_OFFSET = 128266
23
+ SNAC_MIN_ID = 128266
24
+ SNAC_MAX_ID = 156937
25
+ SOH_ID = 128259
26
+ EOH_ID = 128260
27
+ SOA_ID = 128261
28
+ BOS_ID = 128000
29
+ TEXT_EOT_ID = 128009
30
+ AUDIO_SAMPLE_RATE = 24000
31
 
32
  # Preset characters (2 realistic + 2 creative)
33
  PRESET_CHARACTERS = {
 
49
  }
50
  }
51
 
52
+ # Global model variables
53
  model = None
54
+ tokenizer = None
55
+ snac_model = None
 
56
  models_loaded = False
57
 
58
+ def build_prompt(tokenizer, description: str, text: str) -> str:
59
+ """Build formatted prompt for Maya1."""
60
+ soh_token = tokenizer.decode([SOH_ID])
61
+ eoh_token = tokenizer.decode([EOH_ID])
62
+ soa_token = tokenizer.decode([SOA_ID])
63
+ sos_token = tokenizer.decode([CODE_START_TOKEN_ID])
64
+ eot_token = tokenizer.decode([TEXT_EOT_ID])
65
+ bos_token = tokenizer.bos_token
66
 
67
+ formatted_text = f'<description="{description}"> {text}'
68
+ prompt = (
69
+ soh_token + bos_token + formatted_text + eot_token +
70
+ eoh_token + soa_token + sos_token
71
+ )
72
+ return prompt
73
+
74
+ def unpack_snac_from_7(snac_tokens: list) -> list:
75
+ """Unpack 7-token SNAC frames to 3 hierarchical levels."""
76
+ if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID:
77
+ snac_tokens = snac_tokens[:-1]
78
 
79
+ frames = len(snac_tokens) // 7
80
+ snac_tokens = snac_tokens[:frames * 7]
81
 
82
+ if frames == 0:
83
+ return [[], [], []]
 
 
 
 
 
84
 
85
+ l1, l2, l3 = [], [], []
 
86
 
87
+ for i in range(frames):
88
+ slots = snac_tokens[i*7:(i+1)*7]
89
+ l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096)
90
+ l2.extend([
91
+ (slots[1] - CODE_TOKEN_OFFSET) % 4096,
92
+ (slots[4] - CODE_TOKEN_OFFSET) % 4096,
93
+ ])
94
+ l3.extend([
95
+ (slots[2] - CODE_TOKEN_OFFSET) % 4096,
96
+ (slots[3] - CODE_TOKEN_OFFSET) % 4096,
97
+ (slots[5] - CODE_TOKEN_OFFSET) % 4096,
98
+ (slots[6] - CODE_TOKEN_OFFSET) % 4096,
99
+ ])
100
 
101
+ return [l1, l2, l3]
102
+
103
+ def load_models():
104
+ """Load Maya1 Transformers model (runs once)."""
105
+ global model, tokenizer, snac_model, models_loaded
106
 
107
+ if models_loaded:
108
+ return
109
+
110
+ print("Loading Maya1 model with Transformers...")
111
+ model = AutoModelForCausalLM.from_pretrained(
112
+ "maya-research/maya1",
113
+ torch_dtype=torch.bfloat16,
114
+ device_map="auto",
115
+ trust_remote_code=True
116
  )
117
+ tokenizer = AutoTokenizer.from_pretrained("maya-research/maya1", trust_remote_code=True)
118
 
119
+ print("Loading SNAC decoder...")
120
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
121
+ if torch.cuda.is_available():
122
+ snac_model = snac_model.to("cuda")
123
 
124
  models_loaded = True
125
  print("Models loaded successfully!")
 
133
 
134
  @spaces.GPU
135
  def generate_speech(preset_name, description, text, temperature, max_tokens):
136
+ """Generate emotional speech from description and text using Transformers."""
137
  try:
138
  # Load models if not already loaded
139
  load_models()
 
148
 
149
  print(f"Generating with temperature={temperature}, max_tokens={max_tokens}...")
150
 
151
+ # Build prompt
152
+ prompt = build_prompt(tokenizer, description, text)
153
+ inputs = tokenizer(prompt, return_tensors="pt")
154
+
155
+ if torch.cuda.is_available():
156
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
157
+
158
+ # Generate tokens
159
+ with torch.inference_mode():
160
+ outputs = model.generate(
161
+ **inputs,
162
+ max_new_tokens=max_tokens,
163
+ min_new_tokens=28,
164
+ temperature=temperature,
165
+ top_p=0.9,
166
  repetition_penalty=1.1,
167
+ do_sample=True,
168
+ eos_token_id=CODE_END_TOKEN_ID,
169
+ pad_token_id=tokenizer.pad_token_id,
170
  )
 
 
171
 
172
+ # Extract SNAC tokens
173
+ generated_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()
174
+
175
+ # Find EOS and extract SNAC codes
176
+ eos_idx = generated_ids.index(CODE_END_TOKEN_ID) if CODE_END_TOKEN_ID in generated_ids else len(generated_ids)
177
+ snac_tokens = [t for t in generated_ids[:eos_idx] if SNAC_MIN_ID <= t <= SNAC_MAX_ID]
178
+
179
+ if len(snac_tokens) < 7:
180
+ return None, "Error: Not enough tokens generated. Try different text or increase max_tokens."
181
+
182
+ # Unpack and decode
183
+ levels = unpack_snac_from_7(snac_tokens)
184
+ frames = len(levels[0])
185
 
186
+ device = "cuda" if torch.cuda.is_available() else "cpu"
187
+ codes_tensor = [torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0) for level in levels]
188
+
189
+ with torch.inference_mode():
190
+ z_q = snac_model.quantizer.from_codes(codes_tensor)
191
+ audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()
192
+
193
+ # Trim warmup
194
+ if len(audio) > 2048:
195
+ audio = audio[2048:]
196
+
197
+ # Convert to WAV
198
+ audio_int16 = (audio * 32767).astype(np.int16)
199
  wav_buffer = io.BytesIO()
200
  with wave.open(wav_buffer, 'wb') as wav_file:
201
  wav_file.setnchannels(1)
202
  wav_file.setsampwidth(2)
203
  wav_file.setframerate(AUDIO_SAMPLE_RATE)
204
+ wav_file.writeframes(audio_int16.tobytes())
205
 
206
  wav_buffer.seek(0)
207
+ duration = len(audio) / AUDIO_SAMPLE_RATE
 
 
 
208
 
209
  status_msg = f"Generated {duration:.2f}s of emotional speech!"
 
210
  return wav_buffer, status_msg
211
 
212
  except Exception as e:
requirements.txt CHANGED
@@ -1,10 +1,8 @@
1
  torch>=2.5.0
2
  transformers>=4.57.0
3
  gradio>=5.0.0
4
- vllm>=0.11.0
5
  snac>=1.2.1
6
  soundfile>=0.13.0
7
  numpy>=2.1.0
8
  accelerate>=1.10.0
9
- xformers>=0.0.32
10
 
 
1
  torch>=2.5.0
2
  transformers>=4.57.0
3
  gradio>=5.0.0
 
4
  snac>=1.2.1
5
  soundfile>=0.13.0
6
  numpy>=2.1.0
7
  accelerate>=1.10.0
 
8