Spaces:
Runtime error
Runtime error
import os
import time
from threading import Thread

import gradio as gr
import numpy as np
import torch
from torch.nn import functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    pipeline,
)
# Hugging Face Hub id of the checkpoint served by this app.
model_path = "ayoolaolafenwa/ChatLM"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load the weights 8-bit quantized through bitsandbytes. Passing
# `load_in_8bit=True` straight to `from_pretrained` is deprecated in recent
# transformers releases; an explicit `BitsAndBytesConfig` is the supported,
# equivalent spelling. `device_map="auto"` lets accelerate choose where the
# weights live, so inputs should be moved to `model.device`.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation on any of a set of token ids.

    Only the first sequence of the batch (``input_ids[0]``) is inspected, so
    this is meant for single-prompt generation such as the chat handler here.
    """

    def __init__(self, stop_ids=(0,)):
        """Store the token ids that should end generation.

        Args:
            stop_ids: Iterable of integer token ids. Defaults to ``(0,)``,
                the single id the original implementation hard-coded.
        """
        super().__init__()
        # NOTE(review): what id 0 decodes to depends on the ChatLM tokenizer;
        # confirm it is the intended end-of-reply marker.
        # Built once here instead of re-creating a list on every decoding step.
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of sequence 0 is a stop id."""
        return int(input_ids[0][-1]) in self.stop_ids
def user(message, history):
    """Append the user's message to the chat as a new, unanswered turn.

    Args:
        message: Text submitted from the message box.
        history: Current Chatbot value, a list of ``[user, bot]`` pairs.
            ``None`` (e.g. after the Clear button outputs ``None``) is
            treated as an empty history instead of raising ``TypeError``.

    Returns:
        ``("", new_history)``: the empty string clears the message box, and
        ``new_history`` is a new list ending with ``[message, ""]``. The
        caller's list is not modified.
    """
    return "", (history or []) + [[message, ""]]
def chat(curr_system_message, history):
    """Stream the model's reply to the last user turn into ``history``.

    Gradio generator handler. The prompt is ``curr_system_message`` followed
    by ``"<user>: ...<chatbot>: ..."`` for every ``[user, bot]`` pair in
    ``history``. The last pair's bot text is normally ``""``, so the prompt
    ends with ``"<chatbot>: "``. Generation runs on a background thread.
    Each decoded chunk is appended to the bot half of the last pair, which
    is mutated in place.

    Args:
        curr_system_message: Text prepended verbatim to the prompt; ``None``
            is treated as ``""``.
        history: List of ``[user_message, bot_message]`` pairs.

    Yields:
        ``history`` after every streamed chunk (or once, unchanged, when it
        is empty).

    Returns:
        The full generated reply, as the generator's ``StopIteration`` value.

    Raises:
        Exception: whatever ``model.generate`` raised on the worker thread,
            re-raised here after the stream has been closed.
        queue.Empty: from the streamer if no text arrives within 10 seconds.
    """
    if not history:
        # No last turn to write a reply into; leave the Chatbot unchanged.
        yield history
        return ""

    stop = StopOnTokens()
    messages = (curr_system_message or "") + "".join(
        "<user>: " + turn[0] + "<chatbot>: " + turn[1] for turn in history
    )
    # Send the inputs wherever accelerate placed the model (device_map="auto")
    # rather than assuming a CUDA device exists.
    tokens = tokenizer([messages], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        streamer=streamer,
        max_length=2048,  # total length: prompt tokens plus generated tokens
        do_sample=True,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        temperature=0.5,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    errors = []

    def _generate():
        # Worker-thread body. Without this wrapper, a failure inside
        # `model.generate` is only printed by the thread. The loop below would
        # then wait out the streamer timeout and raise an unrelated
        # `queue.Empty`. Instead, close the stream so the loop ends at once,
        # and keep the exception to re-raise on the caller's side.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        # Yield the updated conversation so the Chatbot re-renders live.
        yield history

    # The stream is only closed once generation has finished or failed,
    # so this join returns promptly.
    worker.join()
    if errors:
        raise errors[0]
    return partial_text
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
ChatLM is a chat Large Language model finetuned with pretrained [Falcon-1B model](https://huggingface.co/tiiuae/falcon-rw-1b).
It was trained on a dataset containing normal day to day human conversations, due to limited data used in training it will not generalize well for tasks like coding, current affairs and hallucinations may occur.
"""
            )
            gr.Markdown(""" # Github Repo
https://github.com/ayoolaolafenwa/ChatLM/tree/main """)
    # Layout options are passed to the constructors. The chained `.style()`
    # method was deprecated in late Gradio 3.x and removed in Gradio 4, where
    # calling it raises AttributeError while the app starts.
    chatbot = gr.Chatbot(height=400)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
                container=False,
            )
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Run")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    # Hidden, non-editable textbox whose value is passed to `chat` as the
    # system message.
    system_msg = gr.Textbox(
        label="Response Message", interactive=False, visible=False
    )

    # Pressing Enter and clicking Run behave the same. `user` records the
    # message outside the queue, then `chat` streams the reply through it.
    submit_event = msg.submit(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True)
    # Stop cancels the streaming events started by either trigger.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, [chatbot], queue=False)

# Gradio 4 no longer accepts `queue(concurrency_count=...)` and uses
# `default_concurrency_limit` instead. Gradio 3.x does not know the new
# keyword and raises TypeError. Try the new spelling first, then fall back.
try:
    demo.queue(max_size=32, default_concurrency_limit=2)
except TypeError:
    demo.queue(max_size=32, concurrency_count=2)
demo.launch()