# Source: BATUTOchatbot / model_manager.py
# Author: ivanoctaviogaitansantos
# Commit 75c6930 ("Create model_manager.py", verified), 2.86 kB
# model_manager.py
import asyncio
import logging

from transformers import pipeline
class ModelManager:
    """Load local Hugging Face text-generation pipelines and pick one per task.

    Call :meth:`load_models_async` once from a running event loop to populate
    :attr:`models`; afterwards :meth:`get_model_for_task` returns the pipeline
    registered for a task type, or ``None`` if that model failed to load.
    """

    # Model slot -> Hugging Face model id. Slots that name the same id share a
    # single pipeline object, so a repeated model is downloaded/held only once.
    MODEL_IDS = {
        "chat": "microsoft/DialoGPT-small",
        "code": "microsoft/CodeGPT-small-py",
        "creative": "microsoft/DialoGPT-small",
    }

    # Task type accepted by get_model_for_task -> model slot.
    # Task types not listed here fall back to the "chat" slot.
    TASK_SLOTS = {
        "CONVERSATION": "chat",
        "CODE_GENERATION": "code",
        "CREATIVE_WRITING": "creative",
    }

    def __init__(self, cache_dir="./model_cache"):
        """Create an empty manager; nothing is loaded until load_models_async().

        Args:
            cache_dir: directory forwarded to the pipeline as
                ``model_kwargs["cache_dir"]``.
        """
        self.models = {}  # slot -> pipeline, or None if loading that slot failed
        self.models_loaded = False  # True once a load pass has finished (even if some loads failed)
        self.cache_dir = cache_dir
        self.load_errors = {}  # slot -> exception raised while loading it

    def _build_pipeline(self, model_name):
        """Build a CPU text-generation pipeline for *model_name*.

        This is a plain blocking call (download + weight loading), which is why
        load_models_async runs it in a worker thread.
        """
        return pipeline(
            "text-generation",
            model=model_name,
            device="cpu",
            model_kwargs={"cache_dir": self.cache_dir},
        )

    async def load_models_async(self):
        """Load every model listed in MODEL_IDS without blocking the event loop.

        Each distinct model id is built once, in the loop's default executor,
        and the builds run concurrently. Loading is best-effort: a model that
        fails is stored as ``None``, its exception is logged and kept in
        :attr:`load_errors`, and the other models are still loaded.
        ``models_loaded`` is set to True when every attempt has finished.
        """
        loop = asyncio.get_running_loop()
        logger = logging.getLogger(__name__)

        async def load(model_name):
            # Returns (pipeline, None) on success or (None, exception) on failure.
            try:
                built = await loop.run_in_executor(None, self._build_pipeline, model_name)
            except Exception as exc:
                logger.warning("Failed to load model %s", model_name, exc_info=True)
                return None, exc
            return built, None

        # dict.fromkeys keeps first-seen order while dropping duplicate ids.
        unique_ids = list(dict.fromkeys(self.MODEL_IDS.values()))
        outcomes = await asyncio.gather(*(load(model_id) for model_id in unique_ids))
        by_id = dict(zip(unique_ids, outcomes))

        for slot, model_id in self.MODEL_IDS.items():
            model, error = by_id[model_id]
            self.models[slot] = model
            if error is None:
                self.load_errors.pop(slot, None)
            else:
                self.load_errors[slot] = error
        self.models_loaded = True

    def get_model_for_task(self, task_type):
        """Return the pipeline for *task_type*.

        Known task types map to their slot via TASK_SLOTS; any other value uses
        the "chat" slot. Returns None when the slot's model failed to load or
        load_models_async has not run yet.
        """
        slot = self.TASK_SLOTS.get(task_type, "chat")
        return self.models.get(slot)
# api_agent.py
import requests
class APIAgent:
    """Minimal client for OpenAI-compatible chat-completion HTTP APIs.

    Each ``call_*`` method sends one system + user message pair and returns the
    text of the first choice. It returns ``None`` when that provider's API key
    is missing from the config or when the request fails (transport error,
    non-2xx status, or a response body without the expected shape).
    """

    OPENAI_URL = "https://api.openai.com/v1/chat/completions"
    DEEPSEEK_URL = "https://api.deepseek.com/v1/chat/completions"

    def __init__(self, config):
        """Keep *config* and open a reusable HTTP session.

        Args:
            config: mapping read via ``.get``; it must already contain the API
                keys loaded from /config. Keys read: ``openai_api_key``,
                ``deepseek_api_key``, ``temperature``, ``max_tokens``,
                ``timeout``.
        """
        self.config = config
        self.session = requests.Session()  # pooled connections shared by all calls

    def close(self):
        """Release the pooled connections held by the HTTP session."""
        self.session.close()

    @staticmethod
    def _messages(prompt, system_message):
        """Build the chat message list: system message first, then the user prompt."""
        return [
            {"role": "system", "content": system_message},
            {"role": "user", "content": prompt},
        ]

    def _chat_completion(self, url, api_key, payload):
        """POST *payload* to *url* and return the first choice's message content.

        Returns None (and logs the reason) on a transport error, a non-2xx
        status, or an unexpected response body, so every failure reaches the
        caller the same way a missing API key does.
        """
        logger = logging.getLogger(__name__)
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        try:
            response = self.session.post(
                url, headers=headers, json=payload, timeout=self.config.get("timeout", 30)
            )
        except requests.RequestException:
            logger.warning("Request to %s failed", url, exc_info=True)
            return None
        if not response.ok:
            logger.warning("Request to %s returned HTTP %s", url, response.status_code)
            return None
        try:
            return response.json()["choices"][0]["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError):
            # ValueError covers a body that is not valid JSON.
            logger.warning("Unexpected response body from %s", url, exc_info=True)
            return None

    def call_openai(self, prompt, system_message=""):
        """Ask OpenAI's gpt-3.5-turbo; return the reply text or None.

        ``temperature`` (default 0.7) and ``max_tokens`` (default 600) come
        from the config.
        """
        key = self.config.get("openai_api_key", "")
        if not key:
            return None
        payload = {
            "model": "gpt-3.5-turbo",
            "messages": self._messages(prompt, system_message),
            "temperature": self.config.get("temperature", 0.7),
            "max_tokens": self.config.get("max_tokens", 600),
        }
        return self._chat_completion(self.OPENAI_URL, key, payload)

    def call_deepseek(self, prompt, system_message=""):
        """Ask DeepSeek's deepseek-chat (OpenAI-compatible API); return the reply text or None.

        Only the model and messages are sent; temperature/max_tokens from the
        config are not forwarded, so the provider's defaults apply.
        """
        key = self.config.get("deepseek_api_key", "")
        if not key:
            return None
        payload = {
            "model": "deepseek-chat",
            "messages": self._messages(prompt, system_message),
        }
        return self._chat_completion(self.DEEPSEEK_URL, key, payload)