import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
import time
import re

# Global model storage - all loaded simultaneously on L40S
models = {}
tokenizers = {}
model_loaded = {}

# Speed-optimized configurations for L40S
MODEL_CONFIGS = {
    "Llama-1 7B": {
        "model_id": "huggyllama/llama-7b",
        "load_in_4bit": False,  # Use full precision for speed
        "torch_dtype": torch.bfloat16,
        "device_map": {"": 0},  # Force to GPU 0
    },
    "Llama-2 7B Chat": {
        "model_id": "meta-llama/Llama-2-7b-chat-hf",
        "load_in_4bit": False,
        "torch_dtype": torch.bfloat16,
        "device_map": {"": 0},
    },
    "Llama-3.2 3B": {
        "model_id": "meta-llama/Llama-3.2-3B-Instruct",
        "load_in_4bit": False,
        "torch_dtype": torch.bfloat16,
        "device_map": {"": 0},
    },
}


def load_all_models():
    """Load all models simultaneously - L40S has enough VRAM"""
    global models, tokenizers, model_loaded

    print("Loading all models simultaneously on L40S...")
    start_time = time.time()

    for model_name, config in MODEL_CONFIGS.items():
        print(f"Loading {model_name}...")
        try:
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(config["model_id"])

            # Handle tokenizer setup differently for Llama-1
            if "Llama-1" in model_name:
                # Llama-1 doesn't have a pad token by default
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.padding_side = "left"  # Important for generation
            else:
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token

            # Load model with basic optimizations (ZeroGPU compatible)
            model = AutoModelForCausalLM.from_pretrained(
                config["model_id"],
                torch_dtype=config["torch_dtype"],
                device_map=config["device_map"],
                trust_remote_code=True,
                use_cache=True,
                low_cpu_mem_usage=True,
            )

            models[model_name] = model
            tokenizers[model_name] = tokenizer
            model_loaded[model_name] = True
            print(f"✅ {model_name} loaded successfully")

        except Exception as e:
            print(f"❌ Error loading {model_name}: {e}")
            model_loaded[model_name] = False

    total_time = time.time() - start_time
    print(f"All models loaded in {total_time:.2f} seconds")
    print(f"GPU Memory used: {torch.cuda.memory_allocated()/1024**3:.2f}GB")


def format_prompt(input_question, model_name):
    """Format prompt based on model type"""
    if "Llama-2" in model_name:
        system_msg = "You are a helpful assistant. Answer questions clearly and concisely."
        # Llama-2 chat format with <<SYS>> / <</SYS>> system-prompt markers
        return f"[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n{input_question} [/INST]"
    elif "3.2" in model_name:
        messages = [
            {"role": "system", "content": "You are a helpful assistant. Answer questions clearly and concisely."},
            {"role": "user", "content": input_question},
        ]
        return tokenizers[model_name].apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:  # Llama-1
        # Better prompt format for Llama-1 to encourage stopping
        return f"Question: {input_question}\n\nResponse: "


def clean_llama1_response(response_text, original_prompt):
    """Clean up Llama-1 response to prevent loops and cut at natural stopping points"""
    # Remove the original prompt if it appears in the response
    if original_prompt in response_text:
        response_text = response_text.replace(original_prompt, "").strip()

    # Split by common conversation markers and take the first part
    stop_markers = [
        "\n\nHuman:", "\n\nUser:", "\n\nQuestion:", "\n\nQ:", "\n\nA:",
        "\nHuman:", "\nUser:", "Human:", "User:", "Question:", "###", "Answer:"
    ]

    for marker in stop_markers:
        if marker in response_text:
            response_text = response_text.split(marker)[0].strip()
            break

    # Remove repetitive patterns (simple heuristic)
    lines = response_text.split('\n')
    cleaned_lines = []
    seen_lines = set()

    for line in lines:
        line = line.strip()
        if line and line not in seen_lines:
            cleaned_lines.append(line)
            seen_lines.add(line)
        elif line in seen_lines:
            # Stop if we see repetition
            break

    response_text = '\n'.join(cleaned_lines)

    # Truncate if too long (another safety measure)
    if len(response_text) > 1000:
        response_text = response_text[:1000] + "..."

    return response_text.strip()


def generate_single_response(model_name, input_question):
    """Generate response from a single model - optimized for speed"""
    if not model_loaded.get(model_name, False):
        return f"❌ {model_name} not available"

    try:
        model = models[model_name]
        tokenizer = tokenizers[model_name]

        # Format prompt
        formatted_prompt = format_prompt(input_question, model_name)

        # Tokenize with speed optimizations
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=False,
        ).to(model.device)

        # Generate with speed-focused settings
        with torch.no_grad():
            start_time = time.time()
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                pad_token_id=tokenizer.pad_token_id,  # avoid pad-token warning during generation
            )
            generation_time = time.time() - start_time

        # Extract only the newly generated tokens
        response_tokens = outputs[0][inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(response_tokens, skip_special_tokens=True)

        # Special cleaning for Llama-1
        if "Llama-1" in model_name:
            response = clean_llama1_response(response, formatted_prompt)

        return response

    except Exception as e:
        return f"❌ Error with {model_name}: {str(e)}"


@spaces.GPU
def process_all_models_parallel(input_question):
    """Query all models sequentially (the ThreadPoolExecutor parallel version is kept commented out below)"""
    # if not input_question.strip():
    #     return "❌ Please enter a question", "❌ Please enter a question", "❌ Please enter a question"

    # start_time = time.time()

    # # Use ThreadPoolExecutor for parallel processing
    # with ThreadPoolExecutor(max_workers=3) as executor:
    #     # Submit all tasks simultaneously
    #     futures = {
    #         executor.submit(generate_single_response, model_name, input_question): model_name
    #         for model_name in MODEL_CONFIGS.keys()
    #     }

    #     # Collect results as they complete
    #     results = {}
    #     for future in futures:
    #         model_name = futures[future]
    #         try:
    #             result = future.result(timeout=45)  # Longer timeout for Llama-1
    #             results[model_name] = result
    #         except Exception as e:
    #             results[model_name] = f"❌ Timeout or error for {model_name}: {str(e)}"

    # total_time = time.time() - start_time

    # # Add total timing to first response
    # llama1_response = results.get("Llama-1 7B", "❌ Error")

    # return (
    #     llama1_response,
    #     results.get("Llama-2 7B Chat", "❌ Error"),
    #     results.get("Llama-3.2 3B", "❌ Error")
    # )

    llama1_response = generate_single_response("Llama-1 7B", input_question)
    llama2_response = generate_single_response("Llama-2 7B Chat", input_question)
    llama3_response = generate_single_response("Llama-3.2 3B", input_question)
    return llama1_response, llama2_response, llama3_response


# def benchmark_models():
#     """Benchmark all models with a test question"""
#     test_question = "What is 2+2? Please provide a brief answer."

#     print("🏃‍♂️ Running benchmark...")
#     start_time = time.time()

#     results = process_all_models_parallel(test_question)

#     total_time = time.time() - start_time
#     print(f"Benchmark completed in {total_time:.2f}s")

#     return f"Benchmark completed! All models ready. Total time: {total_time:.2f}s"


def create_interface():
    """Create speed-optimized Gradio interface"""
    with gr.Blocks(title="Speed-Optimized Multi-Llama", theme=gr.themes.Glass()) as demo:
        gr.Markdown("NOTE: Llama-1 7b is NOT a chat model - behaviour in Question-Answering tasks is erratic!")

        # with gr.Row():
        #     benchmark_btn = gr.Button("🏃‍♂️ Run Benchmark", variant="secondary", size="sm")
        #     benchmark_output = gr.Textbox(label="Benchmark Results", visible=False)

        with gr.Row():
            question_input = gr.Textbox(
                label="Enter your question",
                placeholder="What is the meaning of life?",
                lines=2,
                max_lines=4,
            )

        with gr.Row():
            submit_btn = gr.Button("Ask All Models", variant="primary", size="lg")
            clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Real-time responses in columns for better UX
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🦙 Llama-1 7B")
                output1 = gr.Textbox(
                    label="Response",
                    interactive=False,
                    lines=8,
                    max_lines=15,
                    show_label=False,
                )
            with gr.Column():
                gr.Markdown("### 🦙 Llama-2 7B Chat")
                output2 = gr.Textbox(
                    label="Response",
                    interactive=False,
                    lines=8,
                    max_lines=15,
                    show_label=False,
                )
            with gr.Column():
                gr.Markdown("### 🦙 Llama-3.2 3B")
                output3 = gr.Textbox(
                    label="Response",
                    interactive=False,
                    lines=8,
                    max_lines=15,
                    show_label=False,
                )

        # Event handlers
        submit_btn.click(
            fn=process_all_models_parallel,
            inputs=[question_input],
            outputs=[output1, output2, output3],
        )

        clear_btn.click(
            fn=lambda: ("", "", "", ""),
            outputs=[question_input, output1, output2, output3],
        )

    return demo


# Load all models at startup
load_all_models()

if __name__ == "__main__":
    demo = create_interface()
    demo.launch()