mujahid1214 committed on
Commit 8a97291 · verified · 1 Parent(s): 324b96a

Create app.py

Files changed (1)
app.py +217 -0
app.py ADDED
@@ -0,0 +1,217 @@
#!/usr/bin/env python3

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC
import soundfile as sf
import numpy as np

# Special token IDs used by the Maya1 prompt/code format
CODE_START_TOKEN_ID = 128257
CODE_END_TOKEN_ID = 128258
CODE_TOKEN_OFFSET = 128266
SNAC_MIN_ID = 128266
SNAC_MAX_ID = 156937
SNAC_TOKENS_PER_FRAME = 7

SOH_ID = 128259
EOH_ID = 128260
SOA_ID = 128261
BOS_ID = 128000
TEXT_EOT_ID = 128009


def build_prompt(tokenizer, description: str, text: str) -> str:
    """Build formatted prompt for Maya1."""
    soh_token = tokenizer.decode([SOH_ID])
    eoh_token = tokenizer.decode([EOH_ID])
    soa_token = tokenizer.decode([SOA_ID])
    sos_token = tokenizer.decode([CODE_START_TOKEN_ID])
    eot_token = tokenizer.decode([TEXT_EOT_ID])
    bos_token = tokenizer.bos_token

    formatted_text = f'<description="{description}"> {text}'

    prompt = (
        soh_token + bos_token + formatted_text + eot_token +
        eoh_token + soa_token + sos_token
    )

    return prompt


def extract_snac_codes(token_ids: list) -> list:
    """Extract SNAC codes from generated tokens."""
    try:
        eos_idx = token_ids.index(CODE_END_TOKEN_ID)
    except ValueError:
        eos_idx = len(token_ids)

    snac_codes = [
        token_id for token_id in token_ids[:eos_idx]
        if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID
    ]

    return snac_codes


def unpack_snac_from_7(snac_tokens: list) -> list:
    """Unpack 7-token SNAC frames to 3 hierarchical levels."""
    if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID:
        snac_tokens = snac_tokens[:-1]

    frames = len(snac_tokens) // SNAC_TOKENS_PER_FRAME
    snac_tokens = snac_tokens[:frames * SNAC_TOKENS_PER_FRAME]

    if frames == 0:
        return [[], [], []]

    l1, l2, l3 = [], [], []

    for i in range(frames):
        slots = snac_tokens[i*7:(i+1)*7]
        l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096)
        l2.extend([
            (slots[1] - CODE_TOKEN_OFFSET) % 4096,
            (slots[4] - CODE_TOKEN_OFFSET) % 4096,
        ])
        l3.extend([
            (slots[2] - CODE_TOKEN_OFFSET) % 4096,
            (slots[3] - CODE_TOKEN_OFFSET) % 4096,
            (slots[5] - CODE_TOKEN_OFFSET) % 4096,
            (slots[6] - CODE_TOKEN_OFFSET) % 4096,
        ])

    return [l1, l2, l3]


def main():
    # Load the best open source voice AI model
    print("\n[1/4] Loading Maya1 model...")
    model = AutoModelForCausalLM.from_pretrained(
        "maya-research/maya1",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "maya-research/maya1",
        trust_remote_code=True
    )
    print(f"Model loaded: {len(tokenizer)} tokens in vocabulary")

    # Load SNAC audio decoder (24kHz)
    print("\n[2/4] Loading SNAC audio decoder...")
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    if torch.cuda.is_available():
        snac_model = snac_model.to("cuda")
    print("SNAC decoder loaded")

    # Design your voice with natural language
    description = "Realistic male voice in the 30s age with american accent. Normal pitch, warm timbre, conversational pacing."
    text = "Hello! This is Maya1 <laugh_harder> the best open source voice AI model with emotions."

    print("\n[3/4] Generating speech...")
    print(f"Description: {description}")
    print(f"Text: {text}")

    # Create prompt with proper formatting
    prompt = build_prompt(tokenizer, description, text)

    # Debug: show prompt details
    print("\nPrompt preview (first 200 chars):")
    print(f"  {repr(prompt[:200])}")
    print(f"  Prompt length: {len(prompt)} chars")

    # Generate emotional speech
    inputs = tokenizer(prompt, return_tensors="pt")
    print(f"  Input token count: {inputs['input_ids'].shape[1]} tokens")
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=2048,   # Increase to let the model finish naturally
            min_new_tokens=28,     # At least 4 SNAC frames
            temperature=0.4,
            top_p=0.9,
            repetition_penalty=1.1,  # Prevent loops
            do_sample=True,
            eos_token_id=CODE_END_TOKEN_ID,  # Stop at the end-of-speech token
            pad_token_id=tokenizer.pad_token_id,
        )

    # Extract generated tokens (everything after the input prompt)
    generated_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()

    print(f"Generated {len(generated_ids)} tokens")

    # Debug: check what tokens we got
    print(f"  First 20 tokens: {generated_ids[:20]}")
    print(f"  Last 20 tokens: {generated_ids[-20:]}")

    # Check whether EOS was generated
    if CODE_END_TOKEN_ID in generated_ids:
        eos_position = generated_ids.index(CODE_END_TOKEN_ID)
        print(f"  EOS token found at position {eos_position}/{len(generated_ids)}")

    # Extract SNAC audio tokens
    snac_tokens = extract_snac_codes(generated_ids)

    print(f"Extracted {len(snac_tokens)} SNAC tokens")

    # Debug: analyze token types
    snac_count = sum(1 for t in generated_ids if SNAC_MIN_ID <= t <= SNAC_MAX_ID)
    other_count = sum(1 for t in generated_ids if t < SNAC_MIN_ID or t > SNAC_MAX_ID)
    print(f"  SNAC tokens in output: {snac_count}")
    print(f"  Other tokens in output: {other_count}")

    # Check for the SOS token
    if CODE_START_TOKEN_ID in generated_ids:
        sos_pos = generated_ids.index(CODE_START_TOKEN_ID)
        print(f"  SOS token at position: {sos_pos}")
    else:
        print("  No SOS token found in generated output!")

    if len(snac_tokens) < 7:
        print("Error: Not enough SNAC tokens generated")
        return

    # Unpack SNAC tokens to 3 hierarchical levels
    levels = unpack_snac_from_7(snac_tokens)
    frames = len(levels[0])

    print(f"Unpacked to {frames} frames")
    print(f"  L1: {len(levels[0])} codes")
    print(f"  L2: {len(levels[1])} codes")
    print(f"  L3: {len(levels[2])} codes")

    # Convert to tensors
    device = "cuda" if torch.cuda.is_available() else "cpu"
    codes_tensor = [
        torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0)
        for level in levels
    ]

    # Generate final audio with the SNAC decoder
    print("\n[4/4] Decoding to audio...")
    with torch.inference_mode():
        z_q = snac_model.quantizer.from_codes(codes_tensor)
        audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()

    # Trim warmup samples (first 2048 samples)
    if len(audio) > 2048:
        audio = audio[2048:]

    duration_sec = len(audio) / 24000
    print(f"Audio generated: {len(audio)} samples ({duration_sec:.2f}s)")

    # Save the generated voice output
    output_file = "output.wav"
    sf.write(output_file, audio, 24000)
    print(f"\nVoice generated successfully: {output_file}")


if __name__ == "__main__":
    main()
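
For reference, a minimal sanity check of the frame layout that unpack_snac_from_7 assumes (slot 0 feeds level 1, slots 1 and 4 feed level 2, slots 2, 3, 5, 6 feed level 3). This snippet is illustrative only and is not part of the committed app.py:

# Illustrative check (not in app.py): one fake frame whose codes are 0..6
# after subtracting CODE_TOKEN_OFFSET should unpack to the three SNAC levels.
frame = [CODE_TOKEN_OFFSET + i for i in range(SNAC_TOKENS_PER_FRAME)]
assert unpack_snac_from_7(frame) == [[0], [1, 4], [2, 3, 5, 6]]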