import gc

import torch
from transformers import activations

# Compatibility shim: some transformers versions do not export PytorchGELUTanh;
# alias it to NewGELUActivation, an equivalent tanh-based GELU approximation.
if not hasattr(activations, 'PytorchGELUTanh'):
    activations.PytorchGELUTanh = activations.NewGELUActivation

from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_gif, export_to_video
from PIL import Image
from qwen_vl_utils import process_vision_info
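
# Assumed environment: torch, transformers, bitsandbytes, accelerate, diffusers,
# qwen-vl-utils, and Pillow installed; autoawq is needed for the AWQ vision model.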
|
|
class BrainBus: |
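    """Orchestrator that routes user queries to on-demand specialist models."""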
|
|
def __init__(self): |
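        """Set up the shared quantization config and load the resident orchestrator."""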
|
|
print("Initializing Brain Bus Orchestrator...") |
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
|
|
|
self.bnb_config = BitsAndBytesConfig( |
|
|
load_in_4bit=True, |
|
|
bnb_4bit_quant_type="nf4", |
|
|
bnb_4bit_compute_dtype=torch.float32, |
|
|
) |
|
|
|
|
|
|
|
|
self.orchestrator_path = "merged_models/math" |
|
|
self.tokenizer = None |
|
|
self.orchestrator = None |
|
|
self._load_orchestrator() |
|
|
|
|
|
def _load_orchestrator(self): |
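        """Load the resident orchestrator model and tokenizer in 4-bit."""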
|
|
print(f"Loading Orchestrator from {self.orchestrator_path}...") |
|
|
try: |
|
|
self.tokenizer = AutoTokenizer.from_pretrained(self.orchestrator_path) |
|
|
self.orchestrator = AutoModelForCausalLM.from_pretrained( |
|
|
self.orchestrator_path, |
|
|
quantization_config=self.bnb_config, |
|
|
device_map="auto", |
|
|
trust_remote_code=True |
|
|
) |
|
|
except Exception as e: |
|
|
print(f"Failed to load orchestrator: {e}") |
|
|
|
|
|
def _clean_memory(self): |
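        """Release GPU memory after an on-demand expert is unloaded."""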
|
|
        gc.collect()  # drop dangling references first so the cache release is effective
        torch.cuda.empty_cache()
|
def determine_intent(self, user_input): |
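        """Classify the query with the orchestrator; default to GENERAL on failure."""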
|
|
|
|
|
        prompt = (
            "Classify the following user query into one of these categories: "
            "[CODE, MATH, GENERAL, VISION, VIDEO, 3D]. "
            "Return ONLY the category name.\n\n"
            f"Query: {user_input}\nCategory:"
        )

        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            outputs = self.orchestrator.generate(**inputs, max_new_tokens=10)
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Decoder-only models echo the prompt, so strip it before matching.
            if prompt in response:
                response = response.replace(prompt, "")
            response = response.strip().upper()

            for category in ['CODE', 'MATH', 'GENERAL', 'VISION', 'VIDEO', '3D']:
                if category in response:
                    return category
            return "GENERAL"
        except Exception as e:
            print(f"Error determining intent: {e}")
            return "GENERAL"
|
def run_code_expert(self, query): |
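        """Load the code expert, answer the query, then free its memory."""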
|
|
print("Loading Code Expert...") |
|
|
model = None |
|
|
try: |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
"merged_models/code", |
|
|
quantization_config=self.bnb_config, |
|
|
device_map="auto", |
|
|
trust_remote_code=True |
|
|
) |
|
|
inputs = self.tokenizer(query, return_tensors="pt").to(self.device) |
|
|
outputs = model.generate(**inputs, max_new_tokens=256) |
|
|
result = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
if query in result: |
|
|
result = result.replace(query, "").strip() |
|
|
return result |
|
|
except Exception as e: |
|
|
return f"Code Expert Error: {e}" |
|
|
finally: |
|
|
if model is not None: |
|
|
del model |
|
|
self._clean_memory() |
|
|
|
|
|
def run_general_expert(self, query): |
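        """Load the general-purpose expert, answer the query, then free its memory."""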
|
|
print("Loading General Expert...") |
|
|
model = None |
|
|
try: |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
"merged_models/normal", |
|
|
quantization_config=self.bnb_config, |
|
|
device_map="auto", |
|
|
trust_remote_code=True |
|
|
) |
|
|
inputs = self.tokenizer(query, return_tensors="pt").to(self.device) |
|
|
outputs = model.generate(**inputs, max_new_tokens=256) |
|
|
result = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
if query in result: |
|
|
result = result.replace(query, "").strip() |
|
|
return result |
|
|
except Exception as e: |
|
|
return f"General Expert Error: {e}" |
|
|
finally: |
|
|
if model is not None: |
|
|
del model |
|
|
self._clean_memory() |
|
|
|
|
|
def run_math_expert(self, query): |
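        """Answer with the resident orchestrator, which is the math-tuned model."""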
|
|
print("Using Orchestrator (Math Expert)...") |
|
|
|
|
|
try: |
|
|
inputs = self.tokenizer(query, return_tensors="pt").to(self.device) |
|
|
outputs = self.orchestrator.generate(**inputs, max_new_tokens=256) |
|
|
result = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
if query in result: |
|
|
result = result.replace(query, "").strip() |
|
|
return result |
|
|
except Exception as e: |
|
|
return f"Math Expert Error: {e}" |
|
|
|
|
|
def run_vision_expert(self, query, image_path=None): |
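        """Answer a query about an optional image with Qwen2.5-VL."""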
|
|
print("Loading Vision Expert...") |
|
|
model = None |
|
|
try: |
|
|
|
|
|
model_id = "Qwen/Qwen2.5-VL-3B-Instruct-AWQ" |
|
|
|
|
|
model = AutoModelForVision2Seq.from_pretrained( |
|
|
model_id, |
|
|
torch_dtype=torch.float16, |
|
|
device_map="auto" |
|
|
) |
|
|
processor = AutoProcessor.from_pretrained(model_id) |
|
|
|
|
|
|
|
|
messages = [] |
|
|
content = [] |
|
|
if image_path: |
|
|
try: |
|
|
image = Image.open(image_path) |
|
|
content.append({"type": "image", "image": image}) |
|
|
except: |
|
|
return "Error loading image." |
|
|
|
|
|
content.append({"type": "text", "text": query}) |
|
|
messages.append({"role": "user", "content": content}) |
|
|
|
|
|
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) |
|
|
image_inputs, video_inputs = process_vision_info(messages) |
|
|
inputs = processor( |
|
|
text=[text], |
|
|
images=image_inputs, |
|
|
videos=video_inputs, |
|
|
padding=True, |
|
|
return_tensors="pt", |
|
|
).to(self.device) |
|
|
|
|
|
generated_ids = model.generate(**inputs, max_new_tokens=128) |
|
|
generated_ids_trimmed = [ |
|
|
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) |
|
|
] |
|
|
result = processor.batch_decode( |
|
|
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False |
|
|
)[0] |
|
|
|
|
|
return result |
|
|
except Exception as e: |
|
|
return f"Vision Expert Error: {e}" |
|
|
finally: |
|
|
if model is not None: |
|
|
del model |
|
|
self._clean_memory() |
|
|
|
|
|
def run_video_expert(self, query): |
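        """Generate a short clip with a text-to-video diffusion pipeline."""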
|
|
print("Loading Video Expert...") |
|
|
pipe = None |
|
|
try: |
|
|
|
|
|
model_id = "damo-vilab/text-to-video-ms-1.7b" |
|
|
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16") |
|
|
pipe.enable_model_cpu_offload() |
|
|
|
|
|
|
|
|
result = pipe(query, num_inference_steps=20) |
|
|
video_frames = result.frames[0] |
|
|
|
|
|
output_path = "generated_video.mp4" |
|
|
export_to_video(video_frames, output_path, fps=8) |
|
|
|
|
|
return f"Video generated at {output_path}" |
|
|
except Exception as e: |
|
|
return f"Video Expert Error: {e}" |
|
|
finally: |
|
|
if pipe is not None: |
|
|
del pipe |
|
|
self._clean_memory() |
|
|
|
|
|
def run_3d_expert(self, query): |
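        """Generate a 3D object with Shap-E and save its rendered frames."""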
|
|
print("Loading 3D Expert...") |
|
|
pipe = None |
|
|
try: |
|
|
model_id = "openai/shap-e" |
|
|
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) |
|
|
pipe.to("cuda") |
|
|
|
|
|
_ = pipe(query, num_inference_steps=20) |
|
|
|
|
|
return "3D Object generated (check output directory)" |
|
|
except Exception as e: |
|
|
return f"3D Expert Error: {e}" |
|
|
finally: |
|
|
if pipe is not None: |
|
|
del pipe |
|
|
self._clean_memory() |
|
|
|
|
|
def process_query(self, text, image_path=None): |
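        """Classify the query, dispatch it to an expert, and return the reply."""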
|
|
|
|
|
print(f"\n[Input]: {text}") |
|
|
intent = self.determine_intent(text) |
|
|
print(f"[Intent Detected]: {intent}") |
|
|
|
|
|
|
|
|
response = "" |
|
|
if intent == "CODE": |
|
|
response = self.run_code_expert(text) |
|
|
elif intent == "MATH": |
|
|
response = self.run_math_expert(text) |
|
|
elif intent == "VISION": |
|
|
response = self.run_vision_expert(text, image_path) |
|
|
elif intent == "VIDEO": |
|
|
response = self.run_video_expert(text) |
|
|
elif intent == "3D": |
|
|
response = self.run_3d_expert(text) |
|
|
else: |
|
|
response = self.run_general_expert(text) |
|
|
|
|
|
return response |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
bus = BrainBus() |
|
|
print("Brain Bus ready. Run 'process_query' to interact.") |
|
|
|