prelington committed on
Commit
773f5ca
·
verified ·
1 Parent(s): adf2a16

Update ProTalk_ModelBuilder.py

Browse files
Files changed (1) hide show
  1. ProTalk_ModelBuilder.py +47 -8
ProTalk_ModelBuilder.py CHANGED
@@ -1,11 +1,50 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import torch
 
3
 
4
- base_model = "microsoft/phi-2"
5
- tokenizer = AutoTokenizer.from_pretrained(base_model)
6
- model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True)
7
 
8
- prompt = "User: Hello! Who are you?\nAI:"
9
- inputs = tokenizer(prompt, return_tensors="pt")
10
- outputs = model.generate(**inputs, max_new_tokens=60)
11
- print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import threading

# Runtime configuration: prefer the GPU when one is visible, else fall back to CPU.
model_name = "microsoft/phi-2"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only safe on CUDA; CPU inference stays in full float32.
if device == "cuda":
    _dtype = torch.float16
else:
    _dtype = torch.float32

# Load tokenizer and weights once at import time; low_cpu_mem_usage streams
# the checkpoint in to keep peak host memory down during loading.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=_dtype, low_cpu_mem_usage=True
).to(device)

# Persona text prepended to every prompt the chat loop sends to the model.
system_prompt = (
    "You are ProTalk, a professional and intelligent AI. "
    "You answer clearly, politely, and with insight. "
    "Be professional, witty, and helpful in all responses."
)
def chat_loop():
    """Interactive REPL: stream ProTalk replies until the user types 'exit'.

    Maintains the full transcript in ``history`` so every new prompt carries
    the whole conversation so far. Generation runs on a worker thread so
    tokens can be printed here as the ``TextIteratorStreamer`` yields them.

    Returns: None. Side effects: reads stdin, writes stdout, runs the model.
    """
    history = []
    print("ProTalk Online — type 'exit' to quit.\n")
    while True:
        try:
            user_input = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl+D / Ctrl+C end the session cleanly instead of a traceback.
            print()
            break
        if user_input.strip().lower() == "exit":
            break

        # Assemble the prompt; skip the transcript section on the first turn
        # so the model doesn't see a stray blank line after the system prompt.
        parts = [system_prompt]
        transcript = "\n".join(history)
        if transcript:
            parts.append(transcript)
        parts.append(f"User: {user_input}\nProTalk:")
        prompt = "\n".join(parts)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(
            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        # Run generate() on a worker thread; the streamer hands tokens back
        # to this thread as they are produced.
        thread = threading.Thread(target=model.generate, kwargs={
            "input_ids": inputs["input_ids"],
            # Passing the attention mask explicitly silences the transformers
            # warning and makes pad/EOS handling deterministic.
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": 200,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "streamer": streamer,
        })
        thread.start()
        output_text = ""
        for token in streamer:
            print(token, end="", flush=True)
            output_text += token
        thread.join()
        print()
        history.append(f"User: {user_input}")
        history.append(f"ProTalk: {output_text}")
48
+
49
+ if __name__ == "__main__":
50
+ chat_loop()