import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
import logging
import gc
import warnings
import os
from huggingface_hub import login

from config import MODEL_CONFIGS, DEFAULT_MODEL, MODEL_SETTINGS, GENERATION_DEFAULTS, MEDICAL_SYSTEM_PROMPT, UI_CONFIG

# Login with the secret token
login(token=os.getenv("HF_TOKEN"))

# Suppress warnings
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

# Global variables for model and tokenizer
model = None
tokenizer = None
current_model_name = None


def load_model(model_key=None):
    """Load the specified medical model with optimizations for Hugging Face Spaces"""
    global model, tokenizer, current_model_name

    if model_key is None:
        model_key = DEFAULT_MODEL

    # Try to load models in order of preference - prioritize lightweight models
    model_keys_to_try = [model_key, "flan_t5_small", "dialogpt_medium", "meditron"]

    for key in model_keys_to_try:
        if key not in MODEL_CONFIGS:
            continue

        try:
            model_config = MODEL_CONFIGS[key]
            model_name = model_config["name"]
            print(f"Attempting to load model: {model_name} ({model_config['description']})")

            # Load tokenizer first
            print("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=MODEL_SETTINGS["trust_remote_code"],
                padding_side="left"
            )

            # Add pad token if it doesn't exist
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Configure quantization for memory efficiency (only for larger models)
            model_kwargs = {
                "trust_remote_code": MODEL_SETTINGS["trust_remote_code"],
                "low_cpu_mem_usage": MODEL_SETTINGS["low_cpu_mem_usage"]
            }

            # Optimized loading for CPU performance
            if MODEL_SETTINGS["use_quantization"] and torch.cuda.is_available() and key in ["medllama2", "meditron", "clinical_camel"]:
                # Only use quantization on GPU for larger models
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                )
                model_kwargs["quantization_config"] = quantization_config
                model_kwargs["torch_dtype"] = torch.float16
                model_kwargs["device_map"] = MODEL_SETTINGS["device_map"]
            else:
                # For CPU or smaller models, use optimized settings
                if torch.cuda.is_available():
                    model_kwargs["torch_dtype"] = torch.float16
                    model_kwargs["device_map"] = "auto"
                else:
                    # CPU-optimized settings
                    model_kwargs["torch_dtype"] = torch.float32  # Use float32 on CPU
                    model_kwargs["device_map"] = None  # Let it use CPU naturally

            print("Loading model...")
            # Use appropriate model class based on model type
            if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
                model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
            else:
                model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

            current_model_name = model_name
            print(f"āœ… Model loaded successfully: {model_name}")
            return True

        except Exception as e:
            print(f"āŒ Failed to load {key}: {str(e)}")
            # Clean up on failure
            model = None
            tokenizer = None
            continue

    print("āŒ All model loading attempts failed")
    return False


def generate_response(prompt, max_tokens=None, temperature=None, top_p=None):
    """Generate response using the loaded model"""
    global model, tokenizer, current_model_name

    print(f"Starting generation for prompt: {prompt}")

    if not prompt or not prompt.strip():
        return "Please enter a question. 😊"

    if model is None or tokenizer is None:
        return "āŒ Model not loaded. Please wait for initialization or try restarting the space."

    # Use defaults if not specified
    max_tokens = max_tokens or GENERATION_DEFAULTS["max_new_tokens"]
    temperature = temperature or GENERATION_DEFAULTS["temperature"]
    top_p = top_p or GENERATION_DEFAULTS["top_p"]

    try:
        # Format prompt based on model type
        if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
            # Use a concise instruction prefix for T5
            instruction = "You are a friendly medical assistant. Answer with short, clear health info. Use emojis like 😊. For serious issues, suggest seeing a doctor."
            full_input = f"{instruction}\nQuestion: {prompt} Answer:"
        else:
            # Causal LM format
            full_input = f"{MEDICAL_SYSTEM_PROMPT}\n\nPatient/User: {prompt}\n"

        print(f"Full input: {full_input}")

        # Tokenize input with proper truncation (reduced max_length for T5)
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )

        # Move to appropriate device
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generation parameters - optimized for T5
        generation_kwargs = {
            "max_new_tokens": min(max_tokens, 256),  # Capped at 256 for control
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": GENERATION_DEFAULTS["do_sample"],
            "repetition_penalty": GENERATION_DEFAULTS["repetition_penalty"],
            "no_repeat_ngram_size": GENERATION_DEFAULTS["no_repeat_ngram_size"]
        }

        # Add pad_token_id for non-T5 models
        if not ("flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower()):
            generation_kwargs["pad_token_id"] = tokenizer.eos_token_id

        print(f"Generating with kwargs: {generation_kwargs}")

        # Generate response
        print(f"šŸ¤– Generating response with {current_model_name}...")
        import time
        start_time = time.time()

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)

        generation_time = time.time() - start_time
        print(f"ā±ļø Generation completed in {generation_time:.2f} seconds")

        # Decode response - different handling for T5 vs causal models
        if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
            # T5 generates only the answer, no need to remove prompt
            response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        else:
            # Causal models generate prompt + answer, need to remove prompt
            full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = full_response.replace(full_input, "").strip()

        # Clean up response
        if not response or len(response.strip()) < 10:
            response = "Sorry, I couldn't process that. Try again or see a doctor. 😊"

        print(f"āœ… Generated response length: {len(response)} characters")
        print(f"šŸ“„ Response preview: {response[:150]}{'...' if len(response) > 150 else ''}")

        # Clean up memory
        del inputs, outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()  # Force garbage collection

        print(f"šŸ“œ Generated response: {response}")
        return response

    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        print(error_msg)
        return "āš ļø I encountered a technical issue while processing your request. Please try again or rephrase your question. If the problem persists, consider consulting a healthcare professional directly."


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Main response function for Gradio ChatInterface"""
    if not message or not message.strip():
        return "Please enter a medical question or concern."
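
    # Note: gr.ChatInterface passes `history` and `system_message` into this function,
    # but neither is forwarded to generate_response; the prompt format (T5 instruction
    # or MEDICAL_SYSTEM_PROMPT for causal models) is assembled inside generate_response.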
    # Add a disclaimer for first-time users
    disclaimer = "\n\nāš ļø **Medical Disclaimer**: This AI provides general health information only. Always consult healthcare professionals for medical advice, diagnosis, or treatment."

    try:
        # Generate response
        response = generate_response(
            message.strip(),
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p)
        )

        # Add disclaimer to response
        if "disclaimer" not in response.lower() and "consult" not in response.lower():
            response += disclaimer

        return response

    except Exception as e:
        error_msg = f"System error: {str(e)}"
        print(error_msg)
        return f"āš ļø System temporarily unavailable. Please try again later or consult a healthcare professional directly.{disclaimer}"


def get_model_info():
    """Get information about the currently loaded model"""
    if current_model_name:
        return f"Currently using: {current_model_name}"
    return "No model loaded"


# Load model on startup
print("šŸ„ Initializing MedLLaMA2 Medical Chatbot...")
print("šŸ“‹ Loading medical language model...")
model_loaded = load_model()

if model_loaded:
    print(f"āœ… Ready! {get_model_info()}")
else:
    print("āš ļø WARNING: Model failed to load. The app will run but responses may be limited.")

# Create Gradio interface with configuration
demo = gr.ChatInterface(
    respond,
    title=UI_CONFIG["title"],
    description=UI_CONFIG["description"],
    additional_inputs=[
        gr.Textbox(
            value=MEDICAL_SYSTEM_PROMPT,
            label="System Instructions",
            lines=4,
            interactive=False  # Make it read-only to prevent tampering
        ),
        gr.Slider(
            minimum=UI_CONFIG["max_tokens_range"][0],
            maximum=UI_CONFIG["max_tokens_range"][1],
            value=GENERATION_DEFAULTS["max_new_tokens"],
            step=10,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=UI_CONFIG["temperature_range"][0],
            maximum=UI_CONFIG["temperature_range"][1],
            value=GENERATION_DEFAULTS["temperature"],
            step=0.1,
            label="Temperature (creativity)"
        ),
        gr.Slider(
            minimum=UI_CONFIG["top_p_range"][0],
            maximum=UI_CONFIG["top_p_range"][1],
            value=GENERATION_DEFAULTS["top_p"],
            step=0.05,
            label="Top-p (focus)",
        ),
    ],
    examples=[[example] for example in UI_CONFIG["examples"]],
    cache_examples=False,
    theme=gr.themes.Soft(),
    css=".gradio-container {max-width: 900px; margin: auto;}"
)

# Add model info to the interface
with demo:
    gr.HTML(f"""
    <div>
        Model Status: {get_model_info()}
    </div>
    """)

if __name__ == "__main__":
    # For Hugging Face Spaces deployment
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
        debug=True
    )