Spaces:
Sleeping
Sleeping
| """ | |
| SAAP WebSocket Manager Service - Real-time Communication | |
| Production-ready WebSocket connection management for live agent updates | |
| """ | |
import asyncio
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set

from fastapi import WebSocket, WebSocketDisconnect

from models.agent import SaapAgent
| logger = logging.getLogger(__name__) | |
class WebSocketConnection:
    """Wrap one WebSocket together with its client id and liveness flag.

    Sending is best-effort: transport errors are caught, logged and recorded
    by clearing ``is_alive``; they are never raised to the caller.
    """

    def __init__(self, websocket: WebSocket, client_id: Optional[str] = None):
        """Store the socket and derive a client id when none is supplied.

        Args:
            websocket: The socket this wrapper sends on.
            client_id: Optional caller-chosen identifier. When falsy, an id
                of the form ``client_<id(websocket)>`` is generated.
        """
        self.websocket = websocket
        self.client_id = client_id or f"client_{id(websocket)}"
        # Naive UTC timestamp taken when the wrapper is created.
        self.connected_at = datetime.utcnow()
        # Cleared on the first failed send; once False, sends become no-ops.
        self.is_alive = True

    async def send_message(self, message: Dict[str, Any]) -> bool:
        """Send ``message`` as JSON text over this connection.

        Args:
            message: JSON-serializable payload.

        Returns:
            True if the payload was handed to the socket; False if the
            connection is already dead, the payload could not be serialized,
            or the socket write failed.
        """
        if not self.is_alive:
            return False

        # Serialize outside the transport try-block: a payload that cannot be
        # encoded is a caller error and must not mark a healthy socket dead.
        try:
            payload = json.dumps(message)
        except (TypeError, ValueError) as e:
            logger.error(f"❌ Could not serialize message for {self.client_id}: {e}")
            return False

        try:
            await self.websocket.send_text(payload)
        except Exception as e:
            logger.warning(f"⚠️ Failed to send to {self.client_id}: {e}")
            self.is_alive = False
            return False
        return True

    async def send_ping(self) -> bool:
        """Send a ``{"type": "ping", "timestamp": ...}`` keep-alive frame.

        Returns:
            True if the ping was written; False if the connection is already
            dead or the write failed (in which case ``is_alive`` is cleared).
        """
        if not self.is_alive:
            return False

        ping = json.dumps({
            "type": "ping",
            "timestamp": datetime.utcnow().isoformat(),
        })
        try:
            await self.websocket.send_text(ping)
        except Exception as e:
            logger.debug(f"Ping failed for {self.client_id}: {e}")
            self.is_alive = False
            return False
        return True
class WebSocketManager:
    """
    Production-ready WebSocket Manager

    Features:
    - Multi-client connection management
    - Real-time agent status broadcasts
    - Message history and statistics
    - Connection health monitoring
    - Automatic cleanup of dead connections

    ``connect``, ``broadcast_message``, ``send_to_client`` and ``shutdown``
    log failures instead of raising them.
    """

    def __init__(self, cleanup_interval: float = 30.0):
        """Create an empty manager.

        Args:
            cleanup_interval: Seconds between passes of the background
                dead-connection sweep. Defaults to 30.
        """
        self.active_connections: Set[WebSocketConnection] = set()
        # client_id -> connection; a reused id maps to the newest connection.
        self.client_connections: Dict[str, WebSocketConnection] = {}
        # Most recent broadcasts, capped at ``max_history_size`` entries.
        self.message_history: List[Dict[str, Any]] = []
        self.max_history_size = 100
        self.cleanup_interval = cleanup_interval
        self.stats: Dict[str, int] = {
            "total_connections": 0,
            "current_connections": 0,
            "messages_sent": 0,
            "broadcasts_sent": 0,
        }
        # Created lazily by start_cleanup_task(); needs a running event loop.
        self._cleanup_task = None

    async def start_cleanup_task(self):
        """Start the periodic cleanup task unless one is already running."""
        # A finished (cancelled or crashed) task is replaced so the sweep can
        # be restarted after shutdown() or an unexpected failure.
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._periodic_cleanup())

    async def _periodic_cleanup(self):
        """Run the dead-connection sweep every ``cleanup_interval`` seconds."""
        try:
            while True:
                await asyncio.sleep(self.cleanup_interval)
                await self._cleanup_dead_connections()
        except asyncio.CancelledError:
            logger.info("🔧 WebSocket cleanup task cancelled")

    async def _cleanup_dead_connections(self):
        """Ping every live connection and drop those that are dead."""
        # Iterate a snapshot: awaiting the ping lets other coroutines
        # connect/disconnect, which would otherwise change the set mid-loop.
        for connection in list(self.active_connections):
            if connection.is_alive:
                # send_ping() reports failure by clearing is_alive.
                await connection.send_ping()

        # No awaits from here on, so the set cannot change underneath us.
        dead_connections = [c for c in self.active_connections if not c.is_alive]
        for connection in dead_connections:
            self.disconnect(connection.websocket, log=False)

        if dead_connections:
            logger.info(f"🧹 Cleaned up {len(dead_connections)} dead WebSocket connections")

    async def connect(self, websocket: WebSocket, client_id: Optional[str] = None):
        """Accept a WebSocket, register it and send the initial messages.

        After accepting, the client receives a ``connection_established``
        message followed by a ``stats_update`` message, and the cleanup task
        is started if needed. Any exception is logged, not raised.

        Args:
            websocket: The incoming socket to accept.
            client_id: Optional identifier; generated when omitted.
        """
        try:
            await websocket.accept()

            connection = WebSocketConnection(websocket, client_id)
            self.active_connections.add(connection)
            self.client_connections[connection.client_id] = connection

            self.stats["total_connections"] += 1
            self.stats["current_connections"] = len(self.active_connections)

            logger.info(
                f"✅ WebSocket connected: {connection.client_id} "
                f"(Total: {len(self.active_connections)})"
            )

            await connection.send_message({
                "type": "connection_established",
                "client_id": connection.client_id,
                "server_time": datetime.utcnow().isoformat(),
                "message": "Connected to SAAP WebSocket server",
            })
            await connection.send_message({
                "type": "stats_update",
                "data": await self.get_connection_stats(),
            })

            await self.start_cleanup_task()
        except Exception as e:
            logger.error(f"❌ WebSocket connection failed: {e}")

    def disconnect(self, websocket: WebSocket, log: bool = True):
        """Unregister the connection that wraps ``websocket``, if any.

        Args:
            websocket: Socket whose wrapper should be removed.
            log: When True, log the disconnect at info level.
        """
        connection = next(
            (c for c in self.active_connections if c.websocket == websocket),
            None,
        )
        if connection is None:
            return

        self.active_connections.discard(connection)
        # Only drop the id mapping if it still points at this connection; a
        # newer connection may have reused the same client_id.
        if self.client_connections.get(connection.client_id) is connection:
            del self.client_connections[connection.client_id]
        connection.is_alive = False
        self.stats["current_connections"] = len(self.active_connections)

        if log:
            logger.info(
                f"🔌 WebSocket disconnected: {connection.client_id} "
                f"(Remaining: {len(self.active_connections)})"
            )

    async def broadcast_message(self, message: Dict[str, Any]):
        """Send ``message`` to every connected client.

        The message is copied with ``timestamp`` and ``broadcast: True``
        added, sent to each connection, recorded in the history, and the
        statistics are updated. Connections whose send failed are removed.
        Does nothing when no client is connected. Errors are logged.

        Args:
            message: JSON-serializable payload.
        """
        if not self.active_connections:
            return

        try:
            envelope = {
                **message,
                "timestamp": datetime.utcnow().isoformat(),
                "broadcast": True,
            }
            # Fail once, up front, on a non-serializable payload instead of
            # letting every connection fail on it individually.
            json.dumps(envelope)

            delivered = 0
            # Snapshot: the set may change while we await each send.
            for connection in list(self.active_connections):
                await connection.send_message(envelope)
                # send_message() never raises; a failed send clears is_alive.
                if connection.is_alive:
                    delivered += 1

            for connection in [c for c in self.active_connections if not c.is_alive]:
                self.disconnect(connection.websocket, log=False)

            self.stats["messages_sent"] += delivered
            self.stats["broadcasts_sent"] += 1
            self._add_to_history(envelope)

            logger.debug(f"📢 Broadcast sent to {delivered} clients: {message.get('type', 'unknown')}")
        except Exception as e:
            logger.error(f"❌ Broadcast failed: {e}")

    async def send_to_client(self, client_id: str, message: Dict[str, Any]):
        """Send ``message`` to the client registered under ``client_id``.

        The message is copied with ``timestamp``, ``targeted: True`` and
        ``client_id`` added before sending.

        Args:
            client_id: Identifier of the target client.
            message: JSON-serializable payload.

        Returns:
            True if the message was sent; False if the client is unknown,
            the payload is not serializable, or the send failed (in which
            case the connection is removed).
        """
        connection = self.client_connections.get(client_id)
        if not connection:
            logger.warning(f"⚠️ Client not found: {client_id}")
            return False

        message_with_timestamp = {
            **message,
            "timestamp": datetime.utcnow().isoformat(),
            "targeted": True,
            "client_id": client_id,
        }

        # A bad payload is the caller's problem; keep the connection.
        try:
            json.dumps(message_with_timestamp)
        except (TypeError, ValueError) as e:
            logger.error(f"❌ Failed to send to {client_id}: {e}")
            return False

        try:
            await connection.send_message(message_with_timestamp)
        except Exception as e:
            logger.error(f"❌ Failed to send to {client_id}: {e}")
            connection.is_alive = False

        # send_message() signals a transport failure by clearing is_alive.
        if not connection.is_alive:
            self.disconnect(connection.websocket)
            return False

        self.stats["messages_sent"] += 1
        logger.debug(f"📤 Message sent to {client_id}: {message.get('type', 'unknown')}")
        return True

    def _add_to_history(self, message: Dict[str, Any]):
        """Append ``message`` to the history, keeping the newest entries only."""
        self.message_history.append(message)
        if len(self.message_history) > self.max_history_size:
            self.message_history = self.message_history[-self.max_history_size:]

    async def get_connection_stats(self) -> Dict[str, Any]:
        """Return the counters plus a live snapshot of the connections.

        Returns:
            The ``stats`` counters with ``current_connections`` recomputed,
            and ``client_ids``, ``message_history_size`` and ``server_time``.
        """
        return {
            **self.stats,
            "current_connections": len(self.active_connections),
            "client_ids": [conn.client_id for conn in self.active_connections],
            "message_history_size": len(self.message_history),
            "server_time": datetime.utcnow().isoformat(),
        }

    # SAAP-specific broadcast methods

    async def broadcast_agent_update(self, agent: SaapAgent):
        """Broadcast an ``agent_update`` message describing ``agent``.

        Args:
            agent: Agent whose id, name, status, type, capabilities and
                (when present) ``metrics.last_active`` are published.
        """
        # status/type may be Enum members or plain values.
        status_value = agent.status.value if hasattr(agent.status, "value") else str(agent.status)
        type_value = agent.type.value if hasattr(agent.type, "value") else str(agent.type)

        # metrics and metrics.last_active are both optional.
        metrics = getattr(agent, "metrics", None)
        last_seen = getattr(metrics, "last_active", None) if metrics else None
        last_active = last_seen.isoformat() if last_seen else None

        await self.broadcast_message({
            "type": "agent_update",
            "data": {
                "agent_id": agent.id,
                "name": agent.name,
                "status": status_value,
                "type": type_value,
                "last_active": last_active,
                "capabilities": agent.capabilities,
            },
        })

    async def broadcast_agent_deleted(self, agent_id: str):
        """Broadcast an ``agent_deleted`` message for ``agent_id``."""
        await self.broadcast_message({
            "type": "agent_deleted",
            "data": {"agent_id": agent_id},
        })

    async def broadcast_message_update(self, message_data: Dict[str, Any]):
        """Broadcast ``message_data`` as an ``agent_message`` message."""
        await self.broadcast_message({
            "type": "agent_message",
            "data": message_data,
        })

    async def broadcast_system_status(self, status_data: Dict[str, Any]):
        """Broadcast ``status_data`` as a ``system_status`` message."""
        await self.broadcast_message({
            "type": "system_status",
            "data": status_data,
        })

    async def broadcast_error(self, error_message: str, error_code: Optional[str] = None):
        """Broadcast a ``system_error`` message with severity ``error``.

        Args:
            error_message: Human-readable error text.
            error_code: Optional error code; sent as ``None`` when omitted.
        """
        await self.broadcast_message({
            "type": "system_error",
            "data": {
                "message": error_message,
                "code": error_code,
                "severity": "error",
            },
        })

    async def shutdown(self):
        """Stop the cleanup task, notify clients and close every connection.

        Errors are logged rather than raised.
        """
        try:
            logger.info("🔧 Shutting down WebSocket manager...")

            if self._cleanup_task:
                self._cleanup_task.cancel()
                try:
                    await self._cleanup_task
                except asyncio.CancelledError:
                    pass
                except Exception as e:
                    # A task that crashed earlier re-raises here; that must
                    # not prevent the connections below from being closed.
                    logger.warning(f"⚠️ WebSocket cleanup task ended with error: {e}")
                # Cleared so start_cleanup_task() can create a fresh task.
                self._cleanup_task = None

            await self.broadcast_message({
                "type": "server_shutdown",
                "data": {
                    "message": "SAAP server is shutting down",
                    "timestamp": datetime.utcnow().isoformat(),
                },
            })

            for connection in list(self.active_connections):
                try:
                    await connection.websocket.close()
                except Exception as e:
                    # Best-effort: continue with the remaining connections.
                    logger.debug(f"Close failed for {connection.client_id}: {e}")
                self.disconnect(connection.websocket, log=False)

            logger.info("✅ WebSocket manager shutdown complete")
        except Exception as e:
            logger.error(f"❌ WebSocket shutdown error: {e}")
if __name__ == "__main__":
    async def test_websocket_manager():
        """Smoke-test the manager with no clients connected.

        Prints the connection statistics as JSON, then issues one broadcast
        to confirm that broadcasting without clients does not raise.
        """
        manager = WebSocketManager()

        print("📊 WebSocket Manager Stats:")
        stats = await manager.get_connection_stats()
        print(json.dumps(stats, indent=2))

        # No clients are registered, so this takes the early-return path.
        await manager.broadcast_message({
            "type": "test",
            "message": "Test broadcast",
        })
        print("✅ WebSocket manager test completed")

    asyncio.run(test_websocket_manager())