# BATUTOchatbot/model_manager.py
import logging

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelManager:
    """Manages local Hugging Face models for conversation and code generation."""

    def __init__(self):
        self.dialo_model = None
        self.code_model = None
        self.dialo_tokenizer = None
        self.code_tokenizer = None
        self.loaded = False
        self.config = {}
    def load_models(self):
        """Load the local models on CPU, optimized for HF Spaces."""
        try:
            logger.info('Loading local models on CPU (HF optimized)...')
            # Load DialoGPT-small: faster and lighter than the larger variants
            self.dialo_model = pipeline(
                'text-generation',
                model='microsoft/DialoGPT-small',
                device='cpu',
                torch_dtype=torch.float32,
                model_kwargs={'low_cpu_mem_usage': True}
            )
            logger.info('DialoGPT-small loaded')
            # Load CodeGPT-small-py for code generation
            self.code_tokenizer = AutoTokenizer.from_pretrained(
                'microsoft/CodeGPT-small-py',
                trust_remote_code=True
            )
            self.code_model = AutoModelForCausalLM.from_pretrained(
                'microsoft/CodeGPT-small-py',
                device_map='cpu',
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )
            logger.info('CodeGPT-small-py loaded')
            self.loaded = True
            logger.info('All local models loaded successfully')
            return True
        except Exception as e:
            logger.error(f'Error loading local models: {e}')
            # Fall back to a simpler load if the optimized one fails
            try:
                logger.info('Attempting fallback load...')
                self.dialo_model = pipeline(
                    'text-generation',
                    model='microsoft/DialoGPT-small',
                    device='cpu'
                )
                self.loaded = True
                logger.info('Fallback load succeeded')
                return True
            except Exception as e2:
                logger.error(f'Fallback load failed: {e2}')
                self.loaded = False
                return False
    def set_config(self, config):
        """Store configuration for external APIs."""
        self.config = config or {}
    def generate_local_response(self, prompt, is_code=False, max_length=150):
        """Generate a response using the optimized local models."""
        if not self.loaded:
            if not self.load_models():
                return 'Error: the local models could not be loaded'
        try:
            if is_code and self.code_model and self.code_tokenizer:
                # Use CodeGPT for code generation
                inputs = self.code_tokenizer.encode(prompt, return_tensors='pt')
                attention_mask = torch.ones_like(inputs)
                with torch.no_grad():
                    outputs = self.code_model.generate(
                        inputs,
                        attention_mask=attention_mask,
                        max_length=max_length,
                        num_return_sequences=1,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=self.code_tokenizer.eos_token_id
                    )
                response = self.code_tokenizer.decode(outputs[0], skip_special_tokens=True)
            else:
                # Use DialoGPT for general conversation
                result = self.dialo_model(
                    prompt,
                    max_length=max_length,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.dialo_model.tokenizer.eos_token_id
                )
                response = result[0]['generated_text']
            # Strip the echoed prompt so only the generated reply remains
            if response.startswith(prompt):
                response = response[len(prompt):].strip()
            return response if response else 'I could not generate a response. Try rephrasing your question.'
        except Exception as e:
            logger.error(f'Error in local generation: {e}')
            return f'Temporary error: {str(e)}. Please try again.'
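
# Minimal usage sketch (not part of the original module): exercises both the
# conversational and the code path. The prompts are illustrative placeholders,
# and the first call downloads both models from the Hugging Face Hub.
if __name__ == '__main__':
    manager = ModelManager()
    if manager.load_models():
        # General conversation via DialoGPT-small
        print(manager.generate_local_response('Hello, how are you?'))
        # Code generation via CodeGPT-small-py
        print(manager.generate_local_response('def fibonacci(n):', is_code=True))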