| import gradio as gr | |
| import json | |
| from pathlib import Path | |
| import yaml | |
| import re | |
| import logging | |
| import io | |
| import sys | |
| import os | |
| from datetime import datetime, timezone, timedelta | |
| import requests | |
| from tools import FileUploader, ResultExtractor, audio_to_str, image_to_str, azure_speech_to_text #gege的多模态 | |
| import numpy as np | |
| from scipy.io.wavfile import write as write_wav | |
| from PIL import Image | |
| # 指定保存文件的相对路径 | |
| SAVE_DIR = 'download' # 相对路径 | |
| os.makedirs(SAVE_DIR, exist_ok=True) # 确保目录存在 | |
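| # Note: on Hugging Face Spaces this directory sits on ephemeral storage, so saved audio/images | |
| # are lost when the Space restarts unless persistent storage is enabled. | |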
| def save_audio(audio, filename): | |
| """保存音频为.wav文件""" | |
| sample_rate, audio_data = audio | |
| write_wav(filename, sample_rate, audio_data) | |
| def save_image(image, filename): | |
| """保存图片为.jpg文件""" | |
| img = Image.fromarray(image.astype('uint8')) | |
| img.save(filename) | |
| # --- IP获取功能 (从 se_app.py 迁移) --- | |
| def get_client_ip(request: gr.Request, debug_mode=False): | |
| """获取客户端真实IP地址""" | |
| if request: | |
| # 从请求头中获取真实IP(考虑代理情况) | |
| x_forwarded_for = request.headers.get("x-forwarded-for", "") | |
| if x_forwarded_for: | |
| client_ip = x_forwarded_for.split(",")[0] | |
| else: | |
| client_ip = request.client.host | |
| if debug_mode: | |
| print(f"Debug: Client IP detected as {client_ip}") | |
| return client_ip | |
| return "unknown" | |
| # --- 配置加载 (从 config_loader.py 迁移并简化) --- | |
| CONFIG = None | |
| HF_CONFIG_PATH = Path(__file__).parent / "todogen_LLM_config.yaml" | |
| def load_hf_config(): | |
| global CONFIG | |
| if CONFIG is None: | |
| try: | |
| with open(HF_CONFIG_PATH, 'r', encoding='utf-8') as f: | |
| CONFIG = yaml.safe_load(f) | |
| print(f"✅ 配置已加载: {HF_CONFIG_PATH}") | |
| except FileNotFoundError: | |
| print(f"❌ 错误: 配置文件 {HF_CONFIG_PATH} 未找到。请确保它在 hf 目录下。") | |
| CONFIG = {} # 提供一个空配置以避免后续错误 | |
| except Exception as e: | |
| print(f"❌ 加载配置文件 {HF_CONFIG_PATH} 时出错: {e}") | |
| CONFIG = {} | |
| return CONFIG | |
| def get_hf_openai_config(): | |
| config = load_hf_config() | |
| return config.get('openai', {}) | |
| def get_hf_openai_filter_config(): | |
| config = load_hf_config() | |
| return config.get('openai_filter', {}) | |
| def get_hf_xunfei_config(): | |
| config = load_hf_config() | |
| return config.get('xunfei', {}) | |
| def get_hf_azure_speech_config(): | |
| config = load_hf_config() | |
| return config.get('azure_speech', {}) | |
| def get_hf_paths_config(): | |
| config = load_hf_config() | |
| # 在hf环境下,路径相对于hf目录 | |
| base = Path(__file__).resolve().parent | |
| paths_cfg = config.get('paths', {}) | |
| return { | |
| 'base_dir': base, | |
| 'prompt_template': base / paths_cfg.get('prompt_template', 'prompt_template.txt'), | |
| 'true_positive_examples': base / paths_cfg.get('true_positive_examples', 'TruePositive_few_shot.txt'), | |
| 'false_positive_examples': base / paths_cfg.get('false_positive_examples', 'FalsePositive_few_shot.txt'), | |
| # data_dir 和 logging_dir 在 app.py 中可能用途不大,除非需要保存 LLM 输出 | |
| } | |
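| # Rough layout of todogen_LLM_config.yaml as implied by the getters above (a sketch; actual key | |
| # names must match your file): | |
| #   openai:        base_url / api_key / model | |
| #   openai_filter: base_url_filter / api_key_filter / model_filter | |
| #   azure_speech:  key / region | |
| #   xunfei:        (keys as used by the caller) | |
| #   paths:         prompt_template / true_positive_examples / false_positive_examples | |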
| # --- LLM Client 初始化 (使用 NVIDIA API) --- | |
| # 从配置加载 NVIDIA API 的 base_url, api_key 和 model | |
| llm_config = get_hf_openai_config() | |
| NVIDIA_API_BASE_URL = llm_config.get('base_url') | |
| NVIDIA_API_KEY = llm_config.get('api_key') | |
| NVIDIA_MODEL_NAME = llm_config.get('model') | |
| # 从配置加载 Filter API 的 base_url, api_key 和 model | |
| filter_config = get_hf_openai_filter_config() | |
| Filter_API_BASE_URL = filter_config.get('base_url_filter') | |
| Filter_API_KEY = filter_config.get('api_key_filter') | |
| Filter_MODEL_NAME = filter_config.get('model_filter') | |
| if not NVIDIA_API_BASE_URL or not NVIDIA_API_KEY or not NVIDIA_MODEL_NAME: | |
| print("❌ 错误: NVIDIA API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai 部分。") | |
| # 提供默认值或退出,以便程序可以继续运行,但LLM调用会失败 | |
| NVIDIA_API_BASE_URL = "" | |
| NVIDIA_API_KEY = "" | |
| NVIDIA_MODEL_NAME = "" | |
| if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME: | |
| print("❌ 错误: Filter API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai_filter 部分。") | |
| # 提供默认值或退出,以便程序可以继续运行,但Filter LLM调用会失败 | |
| Filter_API_BASE_URL = "" | |
| Filter_API_KEY = "" | |
| Filter_MODEL_NAME = "" | |
| # --- 日志配置 (简化版) --- | |
| # 修正后的标准流编码设置 (如果需要,但 Gradio 通常处理自己的输出) | |
| # sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') | |
| # sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True) | |
| # sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', write_through=True) | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
| logger = logging.getLogger(__name__) | |
| # --- Prompt 和 Few-Shot 加载 (从 todogen_llm.py 迁移并适配) --- | |
| def load_single_few_shot_file_hf(file_path: Path) -> str: | |
| """加载单个 few-shot 文件并转义 { 和 }""" | |
| try: | |
| with open(file_path, 'r', encoding='utf-8') as f: | |
| content = f.read() | |
| escaped_content = content.replace('{', '{{').replace('}', '}}') | |
| logger.info(f"✅ 成功加载并转义文件: {file_path}") | |
| return escaped_content | |
| except FileNotFoundError: | |
| logger.warning(f"⚠️ 警告:找不到文件 {file_path}。") | |
| return "" | |
| except Exception as e: | |
| logger.error(f"❌ 加载文件 {file_path} 时出错: {e}", exc_info=True) | |
| return "" | |
| PROMPT_TEMPLATE_CONTENT = "" | |
| TRUE_POSITIVE_EXAMPLES_CONTENT = "" | |
| FALSE_POSITIVE_EXAMPLES_CONTENT = "" | |
| def load_prompt_data_hf(): | |
| global PROMPT_TEMPLATE_CONTENT, TRUE_POSITIVE_EXAMPLES_CONTENT, FALSE_POSITIVE_EXAMPLES_CONTENT | |
| paths = get_hf_paths_config() | |
| try: | |
| with open(paths['prompt_template'], 'r', encoding='utf-8') as f: | |
| PROMPT_TEMPLATE_CONTENT = f.read() | |
| logger.info(f"✅ 成功加载 Prompt 模板文件: {paths['prompt_template']}") | |
| except FileNotFoundError: | |
| logger.error(f"❌ 错误:找不到 Prompt 模板文件:{paths['prompt_template']}") | |
| PROMPT_TEMPLATE_CONTENT = "Error: Prompt template not found." | |
| TRUE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['true_positive_examples']) | |
| FALSE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['false_positive_examples']) | |
| # 应用启动时加载 prompts | |
| load_prompt_data_hf() | |
| # --- JSON 解析器 (从 todogen_llm.py 迁移) --- | |
| def json_parser(text: str) -> dict: | |
| # 改进的JSON解析器,更健壮地处理各种格式 | |
| logger.info(f"Attempting to parse: {text[:200]}...") | |
| try: | |
| # 1. 尝试直接将整个文本作为JSON解析 | |
| try: | |
| parsed_data = json.loads(text) | |
| # 使用_process_parsed_json处理解析结果 | |
| return _process_parsed_json(parsed_data) | |
| except json.JSONDecodeError: | |
| pass # 如果直接解析失败,继续尝试提取代码块 | |
| # 2. 尝试从 ```json ... ``` 代码块中提取和解析 | |
| match = re.search(r'```(?:json)?\n(.*?)```', text, re.DOTALL) | |
| if match: | |
| json_str = match.group(1).strip() | |
| # 修复常见的JSON格式问题 | |
| json_str = re.sub(r',\s*]', ']', json_str) | |
| json_str = re.sub(r',\s*}', '}', json_str) | |
| try: | |
| parsed_data = json.loads(json_str) | |
| # 使用_process_parsed_json处理解析结果 | |
| return _process_parsed_json(parsed_data) | |
| except json.JSONDecodeError as e_block: | |
| logger.warning(f"JSONDecodeError from code block: {e_block} while parsing: {json_str[:200]}") | |
| # 如果从代码块解析也失败,则继续 | |
| # 3. 尝试查找最外层的 '{...}' 或 '[...]' 作为JSON | |
| # 先尝试查找数组格式 [...] | |
| array_match = re.search(r'\[\s*\{.*?\}\s*(?:,\s*\{.*?\}\s*)*\]', text, re.DOTALL) | |
| if array_match: | |
| potential_json = array_match.group(0).strip() | |
| try: | |
| parsed_data = json.loads(potential_json) | |
| # 使用_process_parsed_json处理解析结果 | |
| return _process_parsed_json(parsed_data) | |
| except json.JSONDecodeError: | |
| logger.warning(f"Could not parse potential JSON array: {potential_json[:200]}") | |
| pass | |
| # 再尝试查找单个对象格式 {...} | |
| object_match = re.search(r'\{.*?\}', text, re.DOTALL) | |
| if object_match: | |
| potential_json = object_match.group(0).strip() | |
| try: | |
| parsed_data = json.loads(potential_json) | |
| # 使用_process_parsed_json处理解析结果 | |
| return _process_parsed_json(parsed_data) | |
| except json.JSONDecodeError: | |
| logger.warning(f"Could not parse potential JSON object: {potential_json[:200]}") | |
| pass | |
| # 4. 如果所有尝试都失败,返回错误信息 | |
| logger.error(f"Failed to find or parse JSON block in text: {text[:500]}") # 增加日志长度 | |
| return {"error": "No valid JSON block found or failed to parse", "raw_text": text} | |
| except Exception as e: # 捕获所有其他意外错误 | |
| logger.error(f"Unexpected error in json_parser: {e} for text: {text[:200]}", exc_info=True) | |
| return {"error": f"Unexpected error in json_parser: {e}", "raw_text": text} | |
| def _process_parsed_json(parsed_data): | |
| """处理解析后的JSON数据,确保返回有效的数据结构""" | |
| try: | |
| # 如果解析结果是空列表,返回包含空字典的列表 | |
| if isinstance(parsed_data, list): | |
| if not parsed_data: | |
| logger.warning("JSON解析结果为空列表,返回包含空字典的列表") | |
| return [{}] | |
| # 确保列表中的每个元素都是字典 | |
| processed_list = [] | |
| for item in parsed_data: | |
| if isinstance(item, dict): | |
| processed_list.append(item) | |
| else: | |
| # 如果不是字典,将其转换为字典 | |
| try: | |
| processed_list.append({"content": str(item)}) | |
| except: | |
| processed_list.append({"content": "无法转换的项目"}) | |
| # 如果处理后的列表为空,返回包含空字典的列表 | |
| if not processed_list: | |
| logger.warning("处理后的JSON列表为空,返回包含空字典的列表") | |
| return [{}] | |
| return processed_list | |
| # 如果是字典,直接返回 | |
| elif isinstance(parsed_data, dict): | |
| return parsed_data | |
| # 如果是其他类型,转换为字典 | |
| else: | |
| logger.warning(f"JSON解析结果不是列表或字典,而是{type(parsed_data)},转换为字典") | |
| return {"content": str(parsed_data)} | |
| except Exception as e: | |
| logger.error(f"处理解析后的JSON数据时出错: {e}") | |
| return {"error": f"Error processing parsed JSON: {e}"} | |
| # --- Filter 模块的 System Prompt (从 filter_message/libs.py 迁移) --- | |
| FILTER_SYSTEM_PROMPT = """ | |
| # 角色 | |
| 你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。 | |
| # 任务 | |
| 对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。 | |
| 主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略 | |
| # 要求 | |
| 1. 以json格式输出 | |
| 2. content简洁提炼关键词,字符数<20以内 | |
| 3. 输入条数和输出条数完全一样 | |
| # 输出示例 | |
| ``` | |
| [ | |
| {"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"}, | |
| {"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议"} | |
| ] | |
| ``` | |
| """ | |
| # --- Filter 核心逻辑 (从ToDoAgent集成) --- | |
| def filter_message_with_llm(text_input: str, message_id: str = "user_input_001"): | |
| logger.info(f"调用 filter_message_with_llm 处理输入: {text_input} (msg_id: {message_id})") | |
| # 构造发送给 LLM 的消息 | |
| # filter 模块的 send_llm_with_prompt 接收的是 tuple[tuple] 格式的数据 | |
| # 这里我们只有一个文本输入,需要模拟成那种格式 | |
| mock_data = [(text_input, message_id)] | |
| # 使用与ToDoAgent相同的system prompt | |
| system_prompt = """ | |
| # 角色 | |
| 你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。 | |
| # 任务 | |
| 对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。 | |
| 主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略 | |
| # 要求 | |
| 1. 以json格式输出 | |
| 2. content简洁提炼关键词,字符数<20以内 | |
| 3. 输入条数和输出条数完全一样 | |
| # 输出示例 | |
| ``` | |
| [ | |
| {"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"}, | |
| {"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议邀约"} | |
| ] | |
| ``` | |
| """ | |
| llm_messages = [ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": str(mock_data)} | |
| ] | |
| try: | |
| if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME: | |
| logger.error("Filter API 配置不完整,无法调用 Filter LLM。") | |
| return [{"error": "Filter API configuration incomplete", "-": "-"}] | |
| headers = { | |
| "Authorization": f"Bearer {Filter_API_KEY}", | |
| "Accept": "application/json" | |
| } | |
| payload = { | |
| "model": Filter_MODEL_NAME, | |
| "messages": llm_messages, | |
| "temperature": 0.0, # 为提高准确率,温度为0(与ToDoAgent一致) | |
| "top_p": 0.95, | |
| "max_tokens": 1024, | |
| "stream": False | |
| } | |
| api_url = f"{Filter_API_BASE_URL}/chat/completions" | |
| try: | |
| response = requests.post(api_url, headers=headers, json=payload) | |
| response.raise_for_status() # 检查 HTTP 错误 | |
| raw_llm_response = response.json()["choices"][0]["message"]["content"] | |
| logger.info(f"LLM 原始回复 (部分): {raw_llm_response[:200]}...") | |
| except requests.exceptions.RequestException as e: | |
| logger.error(f"调用 Filter API 失败: {e}") | |
| return [{"error": f"Filter API call failed: {e}", "-": "-"}] | |
| logger.info(f"Filter LLM 原始回复 (部分): {raw_llm_response[:200]}...") | |
| # 解析 LLM 响应 | |
| # 移除可能的代码块标记 | |
| raw_llm_response = raw_llm_response.replace("```json", "").replace("```", "") | |
| parsed_filter_data = json_parser(raw_llm_response) | |
| if "error" in parsed_filter_data: | |
| logger.error(f"解析 Filter LLM 响应失败: {parsed_filter_data['error']}") | |
| return [{"error": f"Filter LLM response parsing error: {parsed_filter_data['error']}"}] | |
| # 返回解析后的数据 | |
| if isinstance(parsed_filter_data, list) and parsed_filter_data: | |
| # 应用规则:如果分类是欠费缴纳且内容包含"缴费支出",归类为"其他" | |
| for item in parsed_filter_data: | |
| if isinstance(item, dict) and item.get("分类") == "欠费缴纳" and "缴费支出" in item.get("content", ""): | |
| item["分类"] = "其他" | |
| # 检查是否有遗漏的消息ID(ToDoAgent的补充逻辑) | |
| request_id_list = {message_id} | |
| response_id_list = {item.get('message_id') for item in parsed_filter_data if isinstance(item, dict)} | |
| diff = request_id_list - response_id_list | |
| if diff: | |
| logger.warning(f"Filter LLM 响应中有遗漏的消息ID: {diff}") | |
| # 对于遗漏的消息,添加一个默认分类为"其他"的项 | |
| for missed_id in diff: | |
| parsed_filter_data.append({ | |
| "message_id": missed_id, | |
| "content": text_input[:20], # 截取前20个字符作为content | |
| "物流取件": 0, | |
| "欠费缴纳": 0, | |
| "待付(还)款": 0, | |
| "会议邀约": 0, | |
| "其他": 100, | |
| "分类": "其他" | |
| }) | |
| return parsed_filter_data | |
| else: | |
| logger.warning(f"Filter LLM 返回空列表或非预期格式: {parsed_filter_data}") | |
| # 返回默认分类为"其他"的项 | |
| return [{ | |
| "message_id": message_id, | |
| "content": text_input[:20], # 截取前20个字符作为content | |
| "物流取件": 0, | |
| "欠费缴纳": 0, | |
| "待付(还)款": 0, | |
| "会议邀约": 0, | |
| "其他": 100, | |
| "分类": "其他", | |
| "error": "Filter LLM returned empty or unexpected format" | |
| }] | |
| except Exception as e: | |
| logger.exception(f"调用 Filter LLM 或解析时发生错误 (filter_message_with_llm)") | |
| return [{ | |
| "message_id": message_id, | |
| "content": text_input[:20], # 截取前20个字符作为content | |
| "物流取件": 0, | |
| "欠费缴纳": 0, | |
| "待付(还)款": 0, | |
| "会议邀约": 0, | |
| "其他": 100, | |
| "分类": "其他", | |
| "error": f"Filter LLM call/parse error: {str(e)}" | |
| }] | |
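| # Return shape: on success, a list of per-message dicts following the 输出示例 above (message_id, | |
| # content, per-category percentages, 分类); on failure, a single-element list that carries an "error" key. | |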
| # --- ToDo List 生成核心逻辑 (使用迁移的代码) --- | |
| def generate_todolist_from_text(text_input: str, message_id: str = "user_input_001"): | |
| """根据输入文本生成 ToDoList (使用迁移的逻辑)""" | |
| logger.info(f"调用 generate_todolist_from_text 处理输入: {text_input} (msg_id: {message_id})") | |
| if not PROMPT_TEMPLATE_CONTENT or "Error:" in PROMPT_TEMPLATE_CONTENT: | |
| logger.error("Prompt 模板未正确加载,无法生成 ToDoList。") | |
| return [["error", "Prompt template not loaded", "-"]] | |
| current_time_iso = datetime.now(timezone.utc).isoformat() | |
| # 转义输入内容中的 { 和 } | |
| content_escaped = text_input.replace('{', '{{').replace('}', '}}') | |
| # 构造 prompt | |
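| # The template is filled via str.format(), so it may reference {true_positive_examples}, | |
| # {false_positive_examples}, {current_time}, {message_id} and {content_escaped}; any other literal | |
| # braces in the template file must be doubled ('{{' / '}}') or .format() raises KeyError. | |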
| formatted_prompt = PROMPT_TEMPLATE_CONTENT.format( | |
| true_positive_examples=TRUE_POSITIVE_EXAMPLES_CONTENT, | |
| false_positive_examples=FALSE_POSITIVE_EXAMPLES_CONTENT, | |
| current_time=current_time_iso, | |
| message_id=message_id, | |
| content_escaped=content_escaped | |
| ) | |
| # 添加明确的JSON输出指令 | |
| enhanced_prompt = formatted_prompt + """ | |
| # 重要提示 | |
| 请确保你的回复是有效的JSON格式,并且只包含JSON内容。不要添加任何额外的解释或文本。 | |
| 你的回复应该严格按照上面的输出示例格式,只包含JSON对象,不要有任何其他文本。 | |
| """ | |
| # 构造发送给 LLM 的消息 | |
| llm_messages = [ | |
| {"role": "user", "content": enhanced_prompt} | |
| ] | |
| logger.info(f"发送给 LLM 的消息 (部分): {str(llm_messages)[:300]}...") | |
| try: | |
| # 根据输入文本智能生成 ToDo List | |
| # 如果是移动话费充值提醒类消息 | |
| if ("充值" in text_input or "缴费" in text_input) and ("移动" in text_input or "话费" in text_input or "余额" in text_input): | |
| # 直接生成待办事项,不调用API | |
| todo_item = { | |
| message_id: { | |
| "is_todo": True, | |
| "end_time": (datetime.now(timezone.utc) + timedelta(days=3)).isoformat(), | |
| "location": "线上:中国移动APP", | |
| "todo_content": "缴纳话费", | |
| "urgency": "important" | |
| } | |
| } | |
| # 转换为表格显示格式 - 合并为一行 | |
| todo_content = "缴纳话费" | |
| end_time = todo_item[message_id]["end_time"].split("T")[0] | |
| location = todo_item[message_id]["location"] | |
| # 合并所有信息到任务内容中 | |
| combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})" | |
| output_for_df = [] | |
| output_for_df.append([1, combined_content, "重要"]) | |
| return output_for_df | |
| # 如果是会议邀约类消息 | |
| elif "会议" in text_input and ("邀请" in text_input or "参加" in text_input): | |
| # 提取可能的会议时间 | |
| meeting_time = None | |
| meeting_pattern = r'(\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2}|\d{4}[年/-]\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2})' | |
| meeting_match = re.search(meeting_pattern, text_input) | |
| if meeting_match: | |
| # 简单处理,实际应用中应该更精确地解析日期时间 | |
| meeting_time = (datetime.now(timezone.utc) + timedelta(days=1, hours=2)).isoformat() | |
| else: | |
| meeting_time = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat() | |
| todo_item = { | |
| message_id: { | |
| "is_todo": True, | |
| "end_time": meeting_time, | |
| "location": "线上:会议软件", | |
| "todo_content": "参加会议", | |
| "urgency": "important" | |
| } | |
| } | |
| # 转换为表格显示格式 - 合并为一行 | |
| todo_content = "参加会议" | |
| end_time = todo_item[message_id]["end_time"].split("T")[0] | |
| location = todo_item[message_id]["location"] | |
| # 合并所有信息到任务内容中 | |
| combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})" | |
| output_for_df = [] | |
| output_for_df.append([1, combined_content, "重要"]) | |
| return output_for_df | |
| # 如果是物流取件类消息 | |
| elif ("快递" in text_input or "物流" in text_input or "取件" in text_input) and ("到达" in text_input or "取件码" in text_input or "柜" in text_input): | |
| # 提取可能的取件码 | |
| pickup_code = None | |
| code_pattern = r'取件码[是为:]?\s*(\d{4,6})' | |
| code_match = re.search(code_pattern, text_input) | |
| todo_content = "取快递" | |
| if code_match: | |
| pickup_code = code_match.group(1) | |
| todo_content = f"取快递(取件码:{pickup_code})" | |
| todo_item = { | |
| message_id: { | |
| "is_todo": True, | |
| "end_time": (datetime.now(timezone.utc) + timedelta(days=2)).isoformat(), | |
| "location": "线下:快递柜", | |
| "todo_content": todo_content, | |
| "urgency": "important" | |
| } | |
| } | |
| # 转换为表格显示格式 - 合并为一行 | |
| end_time = todo_item[message_id]["end_time"].split("T")[0] | |
| location = todo_item[message_id]["location"] | |
| # 合并所有信息到任务内容中 | |
| combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})" | |
| output_for_df = [] | |
| output_for_df.append([1, combined_content, "重要"]) | |
| return output_for_df | |
| # 对于其他类型的消息,调用LLM API进行处理 | |
| if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME: | |
| logger.error("Filter API 配置不完整,无法调用 Filter LLM。") | |
| return [["error", "Filter API configuration incomplete", "-"]] | |
| headers = { | |
| "Authorization": f"Bearer {Filter_API_KEY}", | |
| "Accept": "application/json" | |
| } | |
| payload = { | |
| "model": Filter_MODEL_NAME, | |
| "messages": llm_messages, | |
| "temperature": 0.2, # 降低温度以提高一致性 | |
| "top_p": 0.95, | |
| "max_tokens": 1024, | |
| "stream": False | |
| } | |
| api_url = f"{Filter_API_BASE_URL}/chat/completions" | |
| try: | |
| response = requests.post(api_url, headers=headers, json=payload) | |
| response.raise_for_status() # 检查 HTTP 错误 | |
| raw_llm_response = response.json()['choices'][0]['message']['content'] | |
| logger.info(f"LLM 原始回复 (部分): {raw_llm_response[:200]}...") | |
| except requests.exceptions.RequestException as e: | |
| logger.error(f"调用 Filter API 失败: {e}") | |
| return [["error", f"Filter API call failed: {e}", "-"]] | |
| # 解析 LLM 响应 | |
| parsed_todos_data = json_parser(raw_llm_response) | |
| if "error" in parsed_todos_data: | |
| logger.error(f"解析 LLM 响应失败: {parsed_todos_data['error']}") | |
| return [["error", f"LLM response parsing error: {parsed_todos_data['error']}", parsed_todos_data.get('raw_text', '')[:50] + "..."]] | |
| # 处理解析后的数据 | |
| output_for_df = [] | |
| # 如果是字典格式(符合prompt模板输出格式) | |
| if isinstance(parsed_todos_data, dict): | |
| # 获取消息ID对应的待办信息 | |
| todo_info = None | |
| for key, value in parsed_todos_data.items(): | |
| if key == message_id or key == str(message_id): | |
| todo_info = value | |
| break | |
| if todo_info and isinstance(todo_info, dict) and todo_info.get("is_todo", False): | |
| # 提取待办信息 | |
| todo_content = todo_info.get("todo_content", "未指定待办内容") | |
| end_time = todo_info.get("end_time") | |
| location = todo_info.get("location") | |
| urgency = todo_info.get("urgency", "unimportant") | |
| # 准备合并显示的内容 | |
| combined_content = todo_content | |
| # 添加截止时间 | |
| if end_time and end_time != "null": | |
| try: | |
| date_part = end_time.split("T")[0] if "T" in end_time else end_time | |
| combined_content += f" (截止时间: {date_part}" | |
| except: | |
| combined_content += f" (截止时间: {end_time}" | |
| else: | |
| combined_content += " (" | |
| # 添加地点 | |
| if location and location != "null": | |
| combined_content += f", 地点: {location})" | |
| else: | |
| combined_content += ")" | |
| # 添加紧急程度 | |
| urgency_display = "一般" | |
| if urgency == "urgent": | |
| urgency_display = "紧急" | |
| elif urgency == "important": | |
| urgency_display = "重要" | |
| # 创建单行输出 | |
| output_for_df = [] | |
| output_for_df.append([1, combined_content, urgency_display]) | |
| else: | |
| # 不是待办事项 | |
| output_for_df = [] | |
| output_for_df.append([1, "此消息不包含待办事项", "-"]) | |
| # 如果是旧格式(列表格式) | |
| elif isinstance(parsed_todos_data, list): | |
| output_for_df = [] | |
| # 检查列表是否为空 | |
| if not parsed_todos_data: | |
| logger.warning("LLM 返回了空列表,无法生成 ToDo 项目") | |
| return [[1, "未能生成待办事项", "-"]] | |
| for i, item in enumerate(parsed_todos_data): | |
| if isinstance(item, dict): | |
| todo_content = item.get('todo_content', item.get('content', 'N/A')) | |
| status = item.get('status', '未完成') | |
| urgency = item.get('urgency', 'normal') | |
| # 合并所有信息到一行 | |
| combined_content = todo_content | |
| # 添加截止时间 | |
| if 'end_time' in item and item['end_time']: | |
| try: | |
| if isinstance(item['end_time'], str): | |
| date_part = item['end_time'].split("T")[0] if "T" in item['end_time'] else item['end_time'] | |
| combined_content += f" (截止时间: {date_part}" | |
| else: | |
| combined_content += f" (截止时间: {str(item['end_time'])}" | |
| except Exception as e: | |
| logger.warning(f"处理end_time时出错: {e}") | |
| combined_content += " (" | |
| else: | |
| combined_content += " (" | |
| # 添加地点 | |
| if 'location' in item and item['location']: | |
| combined_content += f", 地点: {item['location']})" | |
| else: | |
| combined_content += ")" | |
| # 设置重要等级 | |
| importance = "一般" | |
| if urgency == "urgent": | |
| importance = "紧急" | |
| elif urgency == "important": | |
| importance = "重要" | |
| output_for_df.append([i + 1, combined_content, importance]) | |
| else: | |
| # 如果不是字典,转换为字符串并添加到列表 | |
| try: | |
| item_str = str(item) if item is not None else "未知项目" | |
| output_for_df.append([i + 1, item_str, "一般"]) | |
| except Exception as e: | |
| logger.warning(f"处理非字典项目时出错: {e}") | |
| output_for_df.append([i + 1, "处理错误的项目", "一般"]) | |
| if not output_for_df: | |
| logger.info("LLM 解析结果为空或无法转换为DataFrame格式。") | |
| return [["info", "未发现待办事项", "-"]] | |
| return output_for_df | |
| except Exception as e: | |
| logger.exception(f"调用 LLM 或解析时发生错误 (generate_todolist_from_text)") | |
| return [["error", f"LLM call/parse error: {str(e)}", "-"]] | |
| # --- Gradio 界面 --- | |
| def process(audio, image, request: gr.Request): | |
| """处理语音和图片的示例函数""" | |
| # 获取并记录客户端IP | |
| client_ip = get_client_ip(request, True) | |
| print(f"Processing audio/image request from IP: {client_ip}") | |
| if audio is not None: | |
| sample_rate, audio_data = audio | |
| audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}" | |
| else: | |
| audio_info = "未收到音频" | |
| if image is not None: | |
| image_info = f"图片尺寸: {image.shape}" | |
| else: | |
| image_info = "未收到图片" | |
| return audio_info, image_info | |
| def respond( | |
| message, | |
| history: list[tuple[str, str]], | |
| system_message, | |
| max_tokens, | |
| temperature, | |
| top_p, | |
| audio, # 多模态输入:音频 | |
| image # 多模态输入:图片 | |
| ): | |
| # ... (聊天回复逻辑基本保持不变, 但确保 client 使用的是配置好的 HF client) | |
| # 1. 多模态处理接口 (其他人负责) | |
| # processed_text_from_multimodal = multimodal_placeholder_function(audio, image) | |
| # 多模态处理:调用讯飞API进行语音和图像识别 | |
| multimodal_content = "" | |
| # 多模态处理配置已移至具体处理部分 | |
| if audio is not None: | |
| try: | |
| audio_sample_rate, audio_data = audio | |
| multimodal_content += f"\n[音频信息: 采样率 {audio_sample_rate}Hz, 时长 {len(audio_data)/audio_sample_rate:.2f}秒]" | |
| # 调用Azure Speech语音识别 | |
| azure_speech_config = get_hf_azure_speech_config() | |
| azure_speech_key = azure_speech_config.get('key') | |
| azure_speech_region = azure_speech_config.get('region') | |
| if azure_speech_key and azure_speech_region: | |
| import tempfile | |
| import soundfile as sf | |
| import os | |
| with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio: | |
| sf.write(temp_audio.name, audio_data, audio_sample_rate) | |
| temp_audio_path = temp_audio.name | |
| audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, temp_audio_path) | |
| if audio_text: | |
| multimodal_content += f"\n[语音识别结果: {audio_text}]" | |
| else: | |
| multimodal_content += "\n[语音识别失败]" | |
| os.unlink(temp_audio_path) | |
| else: | |
| multimodal_content += "\n[Azure Speech API配置不完整,无法进行语音识别]" | |
| except Exception as e: | |
| multimodal_content += f"\n[音频处理错误: {str(e)}]" | |
| if image is not None: | |
| try: | |
| multimodal_content += f"\n[图片信息: 尺寸 {image.shape}]" | |
| # Image recognition (the call below actually goes to Azure Computer Vision via image_to_str) | |
| # Assumed key names for the 'xunfei' section of todogen_LLM_config.yaml; adjust to match the actual file | |
| xunfei_config = get_hf_xunfei_config() | |
| xunfei_appid = xunfei_config.get('appid') | |
| xunfei_apikey = xunfei_config.get('api_key') | |
| xunfei_apisecret = xunfei_config.get('api_secret') | |
| if xunfei_appid and xunfei_apikey and xunfei_apisecret: | |
| import tempfile | |
| from PIL import Image | |
| import os | |
| with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_image: | |
| if len(image.shape) == 3: # RGB图像 | |
| pil_image = Image.fromarray(image.astype('uint8'), 'RGB') | |
| else: # 灰度图像 | |
| pil_image = Image.fromarray(image.astype('uint8'), 'L') | |
| pil_image.save(temp_image.name, 'JPEG') | |
| temp_image_path = temp_image.name | |
| image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=temp_image_path) | |
| if image_text: | |
| multimodal_content += f"\n[图像识别结果: {image_text}]" | |
| else: | |
| multimodal_content += "\n[图像识别失败]" | |
| os.unlink(temp_image_path) | |
| else: | |
| multimodal_content += "\n[讯飞API配置不完整,无法进行图像识别]" | |
| except Exception as e: | |
| multimodal_content += f"\n[图像处理错误: {str(e)}]" | |
| # 将多模态内容(或其处理结果)与用户文本消息结合 | |
| # combined_message = message | |
| # if multimodal_content: # 如果有多模态内容,则附加 | |
| # combined_message += "\n" + multimodal_content | |
| # 为了聊天模型的连贯性,聊天部分可能只使用原始 message | |
| # 而 ToDoList 生成则使用 combined_message | |
| # 聊天回复生成 | |
| chat_messages = [{"role": "system", "content": system_message}] | |
| for val in history: | |
| if val[0]: | |
| chat_messages.append({"role": "user", "content": val[0]}) | |
| if val[1]: | |
| chat_messages.append({"role": "assistant", "content": val[1]}) | |
| chat_messages.append({"role": "user", "content": message}) # 聊天机器人使用原始消息 | |
| chat_response_stream = "" | |
| if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME: | |
| logger.error("Filter API 配置不完整,无法调用 LLM333。") | |
| yield "Filter API 配置不完整,无法提供聊天回复。", [] | |
| return | |
| headers = { | |
| "Authorization": f"Bearer {Filter_API_KEY}", | |
| "Accept": "application/json" | |
| } | |
| payload = { | |
| "model": Filter_MODEL_NAME, | |
| "messages": chat_messages, | |
| "temperature": temperature, | |
| "top_p": top_p, | |
| "max_tokens": max_tokens, | |
| "stream": True # 聊天通常需要流式输出 | |
| } | |
| api_url = f"{Filter_API_BASE_URL}/chat/completions" | |
| try: | |
| response = requests.post(api_url, headers=headers, json=payload, stream=True) | |
| response.raise_for_status() # 检查 HTTP 错误 | |
| for chunk in response.iter_content(chunk_size=None): | |
| if chunk: | |
| try: | |
| # Filter API(OpenAI 兼容接口)的流式输出是 SSE 格式,需要解析 | |
| # 每一行以 'data: ' 开头,后面是 JSON | |
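| # A typical stream line looks roughly like: data: {"choices":[{"delta":{"content":"你好"},"index":0}]} | |
| # and the stream terminates with: data: [DONE] | |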
| for line in chunk.decode('utf-8').splitlines(): | |
| if line.startswith('data: '): | |
| json_data = line[len('data: '):] | |
| if json_data.strip() == '[DONE]': | |
| break | |
| data = json.loads(json_data) | |
| # 检查 choices 列表是否存在且不为空 | |
| if 'choices' in data and len(data['choices']) > 0: | |
| token = data['choices'][0]['delta'].get('content', '') | |
| if token: | |
| chat_response_stream += token | |
| yield chat_response_stream, [] | |
| except json.JSONDecodeError: | |
| logger.warning(f"无法解析流式响应块: {chunk.decode('utf-8')}") | |
| except Exception as e: | |
| logger.error(f"处理流式响应时发生错误: {e}") | |
| yield chat_response_stream + f"\n\n错误: {e}", [] | |
| except requests.exceptions.RequestException as e: | |
| logger.error(f"调用 NVIDIA API 失败: {e}") | |
| yield f"调用 NVIDIA API 失败: {e}", [] | |
| # 全局变量存储所有待办事项 | |
| all_todos_global = [] | |
| # 创建自定义的聊天界面 | |
| with gr.Blocks() as app: | |
| gr.Markdown("# ToDoAgent Multi-Modal Interface with ToDo List") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| gr.Markdown("## Chat Interface") | |
| chatbot = gr.Chatbot(height=450, label="聊天记录", type="messages") # 推荐使用 type="messages" | |
| msg = gr.Textbox(label="输入消息", placeholder="输入您的问题或待办事项...") | |
| with gr.Row(): | |
| audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"]) | |
| image_input = gr.Image(label="上传图片", type="numpy") | |
| with gr.Accordion("高级设置", open=False): | |
| system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示") | |
| max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="最大生成长度(聊天)") # 增加聊天模型参数范围 | |
| temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="温度(聊天)") | |
| top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p(聊天)") | |
| with gr.Row(): | |
| submit_btn = gr.Button("发送", variant="primary") | |
| clear_btn = gr.Button("清除聊天和ToDo") | |
| with gr.Column(scale=1): | |
| gr.Markdown("## Generated ToDo List") | |
| todolist_df = gr.DataFrame(headers=["ID", "任务内容", "状态"], | |
| datatype=["number", "str", "str"], | |
| row_count=(0, "dynamic"), | |
| col_count=(3, "fixed"), | |
| label="待办事项列表") | |
| def user(user_message, chat_history): | |
| # 将用户消息添加到聊天记录 (Gradio type="messages" 格式) | |
| if not chat_history: chat_history = [] | |
| chat_history.append({"role": "user", "content": user_message}) | |
| return "", chat_history | |
| def bot_interaction(chat_history, system_message, max_tokens, temperature, top_p, audio, image): | |
| user_message_for_chat = "" | |
| if chat_history and chat_history[-1]["role"] == "user": | |
| user_message_for_chat = chat_history[-1]["content"] | |
| # 准备用于 ToDoList 生成的输入文本 (多模态部分由其他人负责) | |
| text_for_todolist = user_message_for_chat | |
| # 可以在这里添加从 audio/image 提取文本的逻辑,并附加到 text_for_todolist | |
| # multimodal_text = process_multimodal_inputs(audio, image) # 假设的函数 | |
| # if multimodal_text: | |
| # text_for_todolist += "\n" + multimodal_text | |
| # 1. 生成聊天回复 (流式) | |
| # 转换 chat_history 从 [{'role':'user', 'content':'...'}, ...] 到 [('user_msg', 'bot_msg'), ...] | |
| # respond 函数期望的是 history: list[tuple[str, str]] | |
| # 但 Gradio type="messages" 的 chatbot.value 是 [{'role': ..., 'content': ...}, ...] | |
| # 需要转换 | |
| formatted_history_for_respond = [] | |
| temp_user_msg = None | |
| for item in chat_history[:-1]: #排除最后一条用户消息,因为它会作为当前message传入respond | |
| if item["role"] == "user": | |
| temp_user_msg = item["content"] | |
| elif item["role"] == "assistant" and temp_user_msg is not None: | |
| formatted_history_for_respond.append((temp_user_msg, item["content"])) | |
| temp_user_msg = None | |
| elif item["role"] == "assistant" and temp_user_msg is None: # Bot 先说话的情况 | |
| formatted_history_for_respond.append(("", item["content"])) | |
| chat_stream_generator = respond( | |
| user_message_for_chat, | |
| formatted_history_for_respond, # 传递转换后的历史 | |
| system_message, | |
| max_tokens, | |
| temperature, | |
| top_p, | |
| audio, | |
| image | |
| ) | |
| full_chat_response = "" | |
| current_todos = [] | |
| for chat_response_part, _ in chat_stream_generator: | |
| full_chat_response = chat_response_part | |
| # 更新 chat_history (Gradio type="messages" 格式) | |
| if chat_history and chat_history[-1]["role"] == "user": | |
| # 如果最后一条是用户消息,添加机器人回复 | |
| # 但由于是流式,我们可能需要先添加一个空的 assistant 消息,然后更新它 | |
| # 或者,等待流结束后一次性添加 | |
| # 为了简化,我们先假设 respond 返回的是完整回复,或者在循环外更新 | |
| pass # 流式更新 chatbot 在 submit_btn.click 中处理 | |
| yield chat_history + [[None, full_chat_response]], current_todos # 临时做法,需要适配Gradio的流式更新 | |
| # 流式结束后,更新 chat_history 中的最后一条 assistant 消息 | |
| if chat_history and full_chat_response: | |
| # 查找最后一条用户消息,在其后添加或更新机器人回复 | |
| # 这种方式对于 type="messages" 更友好 | |
| # 实际上,Gradio 的 chatbot 更新应该在 .then() 中处理,这里先模拟 | |
| # chat_history.append({"role": "assistant", "content": full_chat_response}) | |
| # 这个 yield 应该在 submit_btn.click 的 .then() 中处理 chatbot 的更新 | |
| # 这里我们先专注于 ToDo 生成 | |
| pass # chatbot 更新由 Gradio 机制处理 | |
| # 2. 聊天回复完成后,生成/更新 ToDoList | |
| if text_for_todolist: | |
| # 使用一个唯一的 ID,例如基于时间戳或随机数,如果需要区分不同输入的 ToDo | |
| message_id_for_todo = f"hf_app_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" | |
| new_todo_items = generate_todolist_from_text(text_for_todolist, message_id_for_todo) | |
| current_todos = new_todo_items | |
| # bot_interaction 应该返回 chatbot 的最终状态和 todolist_df 的数据 | |
| # chatbot 的最终状态是 chat_history + assistant 的回复 | |
| final_chat_history = list(chat_history) # 复制 | |
| if full_chat_response: | |
| final_chat_history.append({"role": "assistant", "content": full_chat_response}) | |
| yield final_chat_history, current_todos | |
| # 连接事件 (适配 type="messages") | |
| # Gradio 的流式更新通常是: | |
| # 1. user 函数准备输入,返回 (空输入框, 更新后的聊天记录) | |
| # 2. bot_interaction 函数是一个生成器,yield (部分聊天记录, 部分ToDo) | |
| # msg.submit 和 submit_btn.click 的 outputs 需要对应 bot_interaction 的 yield | |
| # 简化版,非流式更新 chatbot,流式更新由 respond 内部的 yield 控制 | |
| # 但 respond 的 yield 格式 (str, list) 与 bot_interaction (list, list) 不同 | |
| # 需要调整 respond 的 yield 或 bot_interaction 的处理 | |
| # 调整后的事件处理,以更好地支持流式聊天和ToDo更新 | |
| def process_filtered_result_for_todo(filtered_result, content, source_type): | |
| """处理过滤结果并生成todolist的辅助函数""" | |
| todos = [] | |
| if isinstance(filtered_result, dict) and "error" in filtered_result: | |
| logger.error(f"{source_type} Filter 模块处理失败: {filtered_result['error']}") | |
| todos = [["Error", f"{source_type}: {filtered_result['error']}", "Filter Failed"]] | |
| elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他": | |
| logger.info(f"{source_type}消息被 Filter 模块归类为 '其他',不生成 ToDo List。") | |
| todos = [["Info", f"{source_type}: 消息被归类为 '其他',无需生成 ToDo。", "Filtered"]] | |
| elif isinstance(filtered_result, list): | |
| # 处理列表类型的过滤结果 | |
| category = None | |
| if filtered_result: | |
| for item in filtered_result: | |
| if isinstance(item, dict) and "分类" in item: | |
| category = item["分类"] | |
| break | |
| if category == "其他": | |
| logger.info(f"{source_type}消息被 Filter 模块归类为 '其他',不生成 ToDo List。") | |
| todos = [["Info", f"{source_type}: 消息被归类为 '其他',无需生成 ToDo。", "Filtered"]] | |
| else: | |
| logger.info(f"{source_type}消息被 Filter 模块归类为 '{category if category else '未知'}',继续生成 ToDo List。") | |
| if content: | |
| msg_id_todo = f"hf_app_todo_{source_type}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" | |
| todos = generate_todolist_from_text(content, msg_id_todo) | |
| # 为每个todo添加来源标识 | |
| for todo in todos: | |
| if len(todo) > 1: | |
| todo[1] = f"[{source_type}] {todo[1]}" | |
| else: | |
| # 如果是字典但不是"其他"分类 | |
| logger.info(f"{source_type}消息被 Filter 模块归类为 '{filtered_result.get('分类') if isinstance(filtered_result, dict) else '未知'}',继续生成 ToDo List。") | |
| if content: | |
| msg_id_todo = f"hf_app_todo_{source_type}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" | |
| todos = generate_todolist_from_text(content, msg_id_todo) | |
| # 为每个todo添加来源标识 | |
| for todo in todos: | |
| if len(todo) > 1: | |
| todo[1] = f"[{source_type}] {todo[1]}" | |
| return todos | |
| def handle_submit(user_msg_content, ch_history, sys_msg, max_t, temp, t_p, audio_f, image_f, request: gr.Request): | |
| global all_todos_global | |
| # 获取并记录客户端IP | |
| client_ip = get_client_ip(request, True) | |
| print(f"Processing request from IP: {client_ip}") | |
| # 首先处理多模态输入,获取多模态内容 | |
| multimodal_text_content = "" | |
| # 添加调试日志 | |
| logger.info(f"开始多模态处理 - 音频: {audio_f is not None}, 图像: {image_f is not None}") | |
| # 获取Azure Speech配置 | |
| azure_speech_config = get_hf_azure_speech_config() | |
| azure_speech_key = azure_speech_config.get('key') | |
| azure_speech_region = azure_speech_config.get('region') | |
| # 添加调试日志 | |
| logger.info(f"Azure Speech配置状态 - key: {bool(azure_speech_key)}, region: {bool(azure_speech_region)}") | |
| # 处理音频输入(使用Azure Speech服务) | |
| if audio_f is not None and azure_speech_key and azure_speech_region: | |
| logger.info("开始处理音频输入...") | |
| try: | |
| audio_sample_rate, audio_data = audio_f | |
| logger.info(f"音频信息: 采样率 {audio_sample_rate}Hz, 数据长度 {len(audio_data)}") | |
| # 保存音频为.wav文件 | |
| audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav") | |
| save_audio(audio_f, audio_filename) | |
| logger.info(f"音频已保存: {audio_filename}") | |
| # 调用Azure Speech服务处理音频 | |
| audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, audio_filename) | |
| logger.info(f"音频识别结果: {audio_text}") | |
| if audio_text: | |
| multimodal_text_content += f"音频内容: {audio_text}" | |
| logger.info("音频处理完成") | |
| else: | |
| logger.warning("音频处理失败") | |
| except Exception as e: | |
| logger.error(f"音频处理错误: {str(e)}") | |
| elif audio_f is not None: | |
| logger.warning("音频文件存在但Azure Speech配置不完整,跳过音频处理") | |
| # 处理图像输入(使用Azure Computer Vision服务) | |
| if image_f is not None: | |
| logger.info("开始处理图像输入...") | |
| try: | |
| logger.info(f"图像信息: 形状 {image_f.shape}, 数据类型 {image_f.dtype}") | |
| # 保存图片为.jpg文件 | |
| image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg") | |
| save_image(image_f, image_filename) | |
| logger.info(f"图像已保存: {image_filename}") | |
| # 调用tools.py中的image_to_str方法处理图片 | |
| image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename) | |
| logger.info(f"图像识别结果: {image_text}") | |
| if image_text: | |
| if multimodal_text_content: # 如果已有音频内容,添加分隔符 | |
| multimodal_text_content += "\n" | |
| multimodal_text_content += f"图像内容: {image_text}" | |
| logger.info("图像处理完成") | |
| else: | |
| logger.warning("图像处理失败") | |
| except Exception as e: | |
| logger.error(f"图像处理错误: {str(e)}") | |
| # 确定最终的用户输入内容:如果用户没有输入文本,使用多模态识别的内容 | |
| final_user_content = user_msg_content.strip() if user_msg_content else "" | |
| if not final_user_content and multimodal_text_content: | |
| final_user_content = multimodal_text_content | |
| logger.info(f"用户无文本输入,使用多模态内容作为用户输入: {final_user_content}") | |
| elif final_user_content and multimodal_text_content: | |
| # 用户有文本输入,多模态内容作为补充 | |
| final_user_content = f"{final_user_content}\n{multimodal_text_content}" | |
| logger.info(f"用户有文本输入,多模态内容作为补充") | |
| # 如果最终还是没有任何内容,提供默认提示 | |
| if not final_user_content: | |
| final_user_content = "[无输入内容]" | |
| logger.warning("用户没有提供任何输入内容(文本、音频或图像)") | |
| logger.info(f"最终用户输入内容: {final_user_content}") | |
| # 1. 更新聊天记录 (用户部分) - 使用最终确定的用户内容 | |
| if not ch_history: ch_history = [] | |
| ch_history.append({"role": "user", "content": final_user_content}) | |
| yield ch_history, [] # 更新聊天,ToDo 列表暂时不变 | |
| # 2. 流式生成机器人回复并更新聊天记录 | |
| # 转换 chat_history 为 respond 函数期望的格式 | |
| formatted_hist_for_respond = [] | |
| temp_user_msg_for_hist = None | |
| # 使用 ch_history[:-1] 因为当前用户消息已在 ch_history 中 | |
| for item_hist in ch_history[:-1]: | |
| if item_hist["role"] == "user": | |
| temp_user_msg_for_hist = item_hist["content"] | |
| elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is not None: | |
| formatted_hist_for_respond.append((temp_user_msg_for_hist, item_hist["content"])) | |
| temp_user_msg_for_hist = None | |
| elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is None: | |
| formatted_hist_for_respond.append(("", item_hist["content"])) | |
| # 准备一个 assistant 消息的槽位 | |
| ch_history.append({"role": "assistant", "content": ""}) | |
| full_bot_response = "" | |
| # 使用最终确定的用户内容进行对话 | |
| for bot_response_token, _ in respond(final_user_content, formatted_hist_for_respond, sys_msg, max_t, temp, t_p, audio_f, image_f): | |
| full_bot_response = bot_response_token | |
| ch_history[-1]["content"] = full_bot_response # 更新最后一条 assistant 消息 | |
| yield ch_history, [] # 流式更新聊天,ToDo 列表不变 | |
| # 3. 生成 ToDoList - 分别处理音频、图片和文字输入 | |
| new_todos_list = [] | |
| # 分别处理文字输入 | |
| if user_msg_content and user_msg_content.strip(): | |
| logger.info(f"处理文字输入生成ToDo: {user_msg_content.strip()}") | |
| text_filtered_result = filter_message_with_llm(user_msg_content.strip()) | |
| text_todos = process_filtered_result_for_todo(text_filtered_result, user_msg_content.strip(), "文字") | |
| new_todos_list.extend(text_todos) | |
| # 分别处理音频输入 | |
| if audio_f is not None and azure_speech_key and azure_speech_region: | |
| try: | |
| audio_sample_rate, audio_data = audio_f | |
| audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav") | |
| save_audio(audio_f, audio_filename) | |
| audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, audio_filename) | |
| if audio_text: | |
| logger.info(f"处理音频输入生成ToDo: {audio_text}") | |
| audio_filtered_result = filter_message_with_llm(audio_text) | |
| audio_todos = process_filtered_result_for_todo(audio_filtered_result, audio_text, "音频") | |
| new_todos_list.extend(audio_todos) | |
| except Exception as e: | |
| logger.error(f"音频处理错误: {str(e)}") | |
| # 分别处理图片输入 | |
| if image_f is not None: | |
| try: | |
| image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg") | |
| save_image(image_f, image_filename) | |
| image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename) | |
| if image_text: | |
| logger.info(f"处理图片输入生成ToDo: {image_text}") | |
| image_filtered_result = filter_message_with_llm(image_text) | |
| image_todos = process_filtered_result_for_todo(image_filtered_result, image_text, "图片") | |
| new_todos_list.extend(image_todos) | |
| except Exception as e: | |
| logger.error(f"图片处理错误: {str(e)}") | |
| # 如果没有任何有效输入,使用原有逻辑 | |
| if not new_todos_list and final_user_content: | |
| logger.info(f"使用整合内容生成ToDo: {final_user_content}") | |
| filtered_result = filter_message_with_llm(final_user_content) | |
| if isinstance(filtered_result, dict) and "error" in filtered_result: | |
| logger.error(f"Filter 模块处理失败: {filtered_result['error']}") | |
| # 可以选择在这里显示错误信息给用户 | |
| new_todos_list = [["Error", filtered_result['error'], "Filter Failed"]] | |
| elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他": | |
| logger.info(f"消息被 Filter 模块归类为 '其他',不生成 ToDo List。") | |
| new_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]] | |
| elif isinstance(filtered_result, list): | |
| # 如果返回的是列表,尝试从列表中获取分类信息 | |
| category = None | |
| # 检查列表是否为空 | |
| if not filtered_result: | |
| logger.warning("Filter 模块返回了空列表,将继续生成 ToDo List。") | |
| if final_user_content: | |
| msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" | |
| new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo) | |
| # 将新的待办事项添加到全局列表中 | |
| if new_todos_list and not (len(new_todos_list) == 1 and "Info" in str(new_todos_list[0])): | |
| # 重新分配ID以确保连续性 | |
| for i, todo in enumerate(new_todos_list): | |
| todo[0] = len(all_todos_global) + i + 1 | |
| all_todos_global.extend(new_todos_list) | |
| yield ch_history, all_todos_global | |
| return | |
| # 确保列表中至少有一个元素且是字典类型 | |
| valid_item = None | |
| for item in filtered_result: | |
| if isinstance(item, dict): | |
| valid_item = item | |
| if "分类" in item: | |
| category = item["分类"] | |
| break | |
| # 如果没有找到有效的字典元素,记录警告并继续生成ToDo | |
| if valid_item is None: | |
| logger.warning(f"Filter 模块返回的列表中没有有效的字典元素: {filtered_result}") | |
| if final_user_content: | |
| msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" | |
| new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo) | |
| # 将新的待办事项添加到全局列表中 | |
| if new_todos_list and not (len(new_todos_list) == 1 and "Info" in str(new_todos_list[0])): | |
| # 重新分配ID以确保连续性 | |
| for i, todo in enumerate(new_todos_list): | |
| todo[0] = len(all_todos_global) + i + 1 | |
| all_todos_global.extend(new_todos_list) | |
| yield ch_history, all_todos_global | |
| return | |
| if category == "其他": | |
| logger.info(f"消息被 Filter 模块归类为 '其他',不生成 ToDo List。") | |
| new_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]] | |
| else: | |
| logger.info(f"消息被 Filter 模块归类为 '{category if category else '未知'}',继续生成 ToDo List。") | |
| # 如果 Filter 结果不是"其他",则继续生成 ToDoList | |
| if final_user_content: | |
| msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" | |
| new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo) | |
| else: | |
| # 如果是字典但不是"其他"分类 | |
| logger.info(f"消息被 Filter 模块归类为 '{filtered_result.get('分类')}',继续生成 ToDo List。") | |
| # 如果 Filter 结果不是"其他",则继续生成 ToDoList | |
| if final_user_content: | |
| msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}" | |
| new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo) | |
| # 将新的待办事项添加到全局列表中(排除信息性消息) | |
| if new_todos_list and not (len(new_todos_list) == 1 and ("Info" in str(new_todos_list[0]) or "Error" in str(new_todos_list[0]))): | |
| # 重新分配ID以确保连续性 | |
| for i, todo in enumerate(new_todos_list): | |
| todo[0] = len(all_todos_global) + i + 1 | |
| all_todos_global.extend(new_todos_list) | |
| yield ch_history, all_todos_global # 最终更新聊天和完整的ToDo列表 | |
| submit_btn.click( | |
| handle_submit, | |
| [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input], | |
| [chatbot, todolist_df] | |
| ) | |
| msg.submit( | |
| handle_submit, | |
| [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input], | |
| [chatbot, todolist_df] | |
| ) | |
| def clear_all(): | |
| global all_todos_global | |
| all_todos_global = [] # 清除全局待办事项列表 | |
| return None, None, "" # 清除 chatbot, todolist_df, 和 msg 输入框 | |
| clear_btn.click(clear_all, None, [chatbot, todolist_df, msg], queue=False) | |
| # 旧的 Audio/Image Processing Tab (保持不变或按需修改) | |
| with gr.Tab("Audio/Image Processing (Original)"): | |
| gr.Markdown("## 处理音频和图片") | |
| audio_processor = gr.Audio(label="上传音频", type="numpy") | |
| image_processor = gr.Image(label="上传图片", type="numpy") | |
| process_btn = gr.Button("处理", variant="primary") | |
| audio_output = gr.Textbox(label="音频信息") | |
| image_output = gr.Textbox(label="图片信息") | |
| process_btn.click( | |
| process, | |
| inputs=[audio_processor, image_processor], | |
| outputs=[audio_output, image_output] | |
| ) | |
| if __name__ == "__main__": | |
| app.launch(debug=True) | |