Spaces:
Paused
Paused
| import aiohttp | |
| import json | |
| import urllib | |
| from networks import ( | |
| ChathubRequestPayloadConstructor, | |
| ConversationRequestHeadersConstructor, | |
| MessageParser, | |
| OpenaiStreamOutputer, | |
| ) | |
| from conversations import ConversationStyle | |
| from utils.logger import logger | |
| from utils.enver import enver | |
class ConversationConnector:
    """
    Stream one chat exchange with the Sydney ChatHub websocket
    (`wss://sydney.bing.com/sydney/ChatHub`).

    Input params:
        - `conversation_style` (str):
            - Case-insensitive; must match a `ConversationStyle` value.
              Anything else (including None) falls back to
              `ConversationStyle.PRECISE`.
        - `sec_access_token`, `client_id`, `conversation_id`
            - Generated by `ConversationCreator`
        - `invocation_id` (int):
            - For 1st request, this value must be `0`.
            - For all requests after, any integer is valid.
            - To make it simple, use `1` for all requests after the 1st one.
            - Incremented by one whenever a stream receives its
              end-of-conversation (type 3) message.
        - `cookies` (dict):
            - Passed to the `aiohttp.ClientSession`; None means no cookies.
    """

    def __init__(
        self,
        conversation_style: str = ConversationStyle.PRECISE.value,
        sec_access_token: str = "",
        client_id: str = "",
        conversation_id: str = "",
        invocation_id: int = 0,
        cookies: dict = None,
    ):
        conversation_style_enum_values = [
            style.value for style in ConversationStyle.__members__.values()
        ]
        # Treat None like any other unknown style instead of crashing on .lower().
        normalized_style = (conversation_style or "").lower()
        if normalized_style in conversation_style_enum_values:
            self.conversation_style = normalized_style
        else:
            self.conversation_style = ConversationStyle.PRECISE.value
        print(f"Model: [{self.conversation_style}]")
        self.sec_access_token = sec_access_token
        self.client_id = client_id
        self.conversation_id = conversation_id
        self.invocation_id = invocation_id
        # A fresh dict per instance; a `{}` default would be shared by all calls.
        self.cookies = {} if cookies is None else cookies
        # Populated by `connect()`; initialised here so `_close()` is always safe.
        self.aiohttp_session = None
        self.wss = None

    async def wss_send(self, message):
        """Send `message` as JSON text terminated by the ASCII record
        separator (0x1E), the same delimiter `stream_chat` splits on."""
        serialized_websocket_message = json.dumps(message, ensure_ascii=False) + "\x1e"
        await self.wss.send_str(serialized_websocket_message)

    async def init_handshake(self):
        """Handshake on a freshly opened websocket.

        Sends the JSON protocol selection, reads and discards one text reply,
        then sends a `{"type": 6}` message (the type `stream_chat` treats as
        a heartbeat).
        """
        await self.wss_send({"protocol": "json", "version": 1})
        await self.wss.receive_str()
        await self.wss_send({"type": 6})

    async def connect(self):
        """Open an aiohttp session and the ChatHub websocket, then handshake.

        If anything fails after the session is created, the session (and the
        websocket, if it was opened) is closed before the exception propagates.
        """
        # Import the submodule explicitly: a bare `import urllib` does not
        # guarantee that `urllib.parse` has been loaded.
        from urllib.parse import quote

        self.quotelized_sec_access_token = quote(self.sec_access_token)
        self.ws_url = (
            f"wss://sydney.bing.com/sydney/ChatHub"
            f"?sec_access_token={self.quotelized_sec_access_token}"
        )
        self.aiohttp_session = aiohttp.ClientSession(cookies=self.cookies)
        try:
            headers_constructor = ConversationRequestHeadersConstructor()
            enver.set_envs(proxies=True)
            self.wss = await self.aiohttp_session.ws_connect(
                self.ws_url,
                headers=headers_constructor.request_headers,
                proxy=enver.proxy,
            )
            await self.init_handshake()
        except BaseException:
            # Also covers cancellation; never leave the session open.
            await self._close()
            raise

    async def _close(self):
        """Close the websocket and aiohttp session if open; safe to call repeatedly."""
        wss = getattr(self, "wss", None)
        if wss is not None and not wss.closed:
            await wss.close()
        session = getattr(self, "aiohttp_session", None)
        if session is not None and not session.closed:
            await session.close()

    async def send_chathub_request(self, prompt: str, system_prompt: str = None):
        """Build the chat request payload for `prompt` and send it over the websocket.

        The payload is also kept on `self.connect_request_payload`.
        """
        payload_constructor = ChathubRequestPayloadConstructor(
            prompt=prompt,
            conversation_style=self.conversation_style,
            client_id=self.client_id,
            conversation_id=self.conversation_id,
            invocation_id=self.invocation_id,
            system_prompt=system_prompt,
        )
        self.connect_request_payload = payload_constructor.request_payload
        await self.wss_send(self.connect_request_payload)

    async def stream_chat(
        self, prompt: str = "", system_prompt: str = None, yield_output=False
    ):
        """Connect, send `prompt`, and consume the response stream.

        This is an async generator. With `yield_output=False` it yields
        nothing and only feeds each type-1 message to `MessageParser.parse`.
        With `yield_output=True` it yields, in order:
            - a "Role" output with empty content,
            - the output(s) `MessageParser.parse` returns for type-1 messages
              (list results are yielded item by item; a falsy single result
              is skipped),
            - a "Finished" output once a type-3 message arrives.

        The websocket and aiohttp session are closed whenever the generator
        exits: after a type-3 message, on any exception, or when the consumer
        stops iterating early.
        """
        await self.connect()
        try:
            await self.send_chathub_request(prompt=prompt, system_prompt=system_prompt)
            message_parser = MessageParser(outputer=OpenaiStreamOutputer())
            if yield_output:
                yield message_parser.outputer.output(content="", content_type="Role")

            while not self.wss.closed:
                # receive_str() raises TypeError on a non-text frame (e.g. when
                # the server closes the socket); `finally` still cleans up.
                response_lines_str = await self.wss.receive_str()
                # One frame may carry several 0x1E-terminated JSON records.
                for line in response_lines_str.split("\x1e"):
                    if not line:
                        continue
                    data = json.loads(line)
                    message_type = data.get("type")

                    # Stream: Meaningful Messages
                    if message_type == 1:
                        if not yield_output:
                            message_parser.parse(data)
                            continue
                        output = message_parser.parse(data, return_output=True)
                        if isinstance(output, list):
                            for item in output:
                                yield item
                        elif output:
                            yield output

                    # Stream: End of Conversation
                    elif message_type == 3:
                        finished_str = "\n[Finished]"
                        logger.success(finished_str)
                        self.invocation_id += 1
                        await self._close()
                        if yield_output:
                            yield message_parser.outputer.output(
                                content=finished_str, content_type="Finished"
                            )
                        return

                    # Type 2 (list of all conversation messages), type 6
                    # (heartbeat) and unknown types are ignored.
        finally:
            await self._close()