Spaces:
Paused
Paused
| import spaces | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| """ | |
| LLM-JP(13B)モデルを使用したGradioチャットボット | |
| Hugging Face Transformersライブラリを使用してローカルでモデルを実行 | |
| """ | |
| # モデルとトークナイザーの初期化 | |
| # MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1" # Sarashina2 3B instruceted | |
| # MODEL_NAME = "sbintuitions/sarashina2-7b" # Sarashina2 7B | |
| # MODEL_NAME = "sbintuitions/sarashina2-13b" # Sarashina2 13B | |
| # MODEL_NAME = "sbintuitions/sarashina2-70b" # Sarashina2 70B | |
| # MODEL_NAME = "sbintuitions/sarashina1-65b" # Sarashina1 65B | |
| # MODEL_NAME = "elyza/Llama-3-ELYZA-JP-8B" # ELYZA-JP-8B | |
| # MODEL_NAME = "lightblue/ao-karasu-72B" # ao-karasu-72B | |
| MODEL_NAME = "llm-jp/llm-jp-3-13b-instruct" # llm-jp-3-13b-instruct | |
| # MODEL_NAME = "llm-jp/llm-jp-3-172b-instruct3" # llm-jp-3-172b-instruct3 | |
| # MODEL_NAME = "shisa-ai/shisa-v2-unphi4-14b" # shisa-v2-unphi4-14b | |
| # MODEL_NAME = "shisa-ai/shisa-v2-qwen2.5-32b" # shisa-v2-qwen2.5-32b | |
| print("モデルを読み込み中〜...") | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_NAME, | |
| torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
| device_map="auto" if torch.cuda.is_available() else None, | |
| trust_remote_code=True | |
| ) | |
| print("モデルの読み込みが完了しました〜。") | |
| print(f"Is CUDA available: {torch.cuda.is_available()}") | |
| # True | |
| print("あ") | |
| print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") | |
| print("い") | |
| # Tesla T4 | |
| def respond( | |
| message, | |
| history: list[tuple[str, str]], | |
| system_message, | |
| max_tokens, | |
| temperature, | |
| top_p, | |
| ): | |
| """ | |
| チャットボットの応答を生成する関数 | |
| Gradio ChatInterfaceの標準形式に対応 | |
| """ | |
| try: | |
| # システムメッセージと会話履歴を含むプロンプトを構築 | |
| conversation = "" | |
| if system_message.strip(): | |
| conversation += f"システム: {system_message}\n" | |
| # 会話履歴を追加 | |
| for user_msg, bot_msg in history: | |
| if user_msg: | |
| conversation += f"ユーザー: {user_msg}\n" | |
| if bot_msg: | |
| conversation += f"アシスタント: {bot_msg}\n" | |
| # 現在のメッセージを追加 | |
| conversation += f"ユーザー: {message}\nアシスタント: " | |
| # トークン化 | |
| inputs = tokenizer.encode(conversation, return_tensors="pt") | |
| # GPU使用時はCUDAに移動 | |
| if torch.cuda.is_available(): | |
| inputs = inputs.cuda() | |
| # 応答生成(ストリーミング対応) | |
| response = "" | |
| with torch.no_grad(): | |
| # 一度に生成してからストリーミング風に出力 | |
| outputs = model.generate( | |
| inputs, | |
| max_new_tokens=max_tokens, | |
| temperature=temperature, | |
| top_p=top_p, | |
| do_sample=True, | |
| pad_token_id=tokenizer.eos_token_id, | |
| eos_token_id=tokenizer.eos_token_id, | |
| repetition_penalty=1.1 | |
| ) | |
| # 生成されたテキストをデコード | |
| generated = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # 変換できるかテスト用!! | |
| # import json | |
| # # レスポンス用の辞書を作るときに | |
| # return json.dumps({"result": generated}, ensure_ascii=False) | |
| # 応答部分のみを抽出 | |
| full_response = generated[len(conversation):].strip() | |
| # 不要な部分を除去 | |
| if "ユーザー:" in full_response: | |
| full_response = full_response.split("ユーザー:")[0].strip() | |
| # ストリーミング風の出力 | |
| #for i in range(len(full_response)): | |
| # response = full_response[:i+1] | |
| # yield response | |
| #response = full_response[:len(full_response)] #追加 | |
| #yield response #追加 | |
| #yield full_response #追加 | |
| return full_response #追加 | |
| except Exception as e: | |
| #yield f"エラーが発生しました: {str(e)}" | |
| return f"エラーが発生しました: {str(e)}" #追加 | |
| """ | |
| Gradio ChatInterfaceを使用したシンプルなチャットボット | |
| カスタマイズ可能なパラメータを含む | |
| """ | |
| demo = gr.ChatInterface( | |
| respond, | |
| title="🤖 LLM-JP Chatbot", | |
| description="LLM-JP-13B モデルを使用した日本語チャットボットです。", | |
| additional_inputs=[ | |
| gr.Textbox( | |
| value="あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。", | |
| label="システムメッセージ", | |
| lines=3 | |
| ), | |
| gr.Slider( | |
| minimum=1, | |
| maximum=8192, | |
| value=4096, | |
| step=1, | |
| label="最大新規トークン数" | |
| ), | |
| gr.Slider( | |
| minimum=0.1, | |
| maximum=2.0, | |
| value=0.7, | |
| step=0.1, | |
| label="Temperature (創造性)" | |
| ), | |
| gr.Slider( | |
| minimum=0.1, | |
| maximum=1.0, | |
| value=0.95, | |
| step=0.05, | |
| label="Top-p (多様性制御)", | |
| ), | |
| ], | |
| theme=gr.themes.Soft(), | |
| examples=[ | |
| ["こんにちは!今日はどんなことを話しましょうか?"], | |
| ["日本の文化について教えてください。"], | |
| ["簡単なレシピを教えてもらえますか?"], | |
| ["プログラミングについて質問があります。"], | |
| ], | |
| cache_examples=False, | |
| #streaming=False, # 追加 ← これで return のみ受け付ける同期モードに | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False, | |
| show_api=True, # API documentation を表示 | |
| debug=True | |
| ) |