File size: 13,462 Bytes
4343907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""
SAAP WebSocket Manager Service - Real-time Communication
Production-ready WebSocket connection management for live agent updates
"""

import asyncio
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set

from fastapi import WebSocket, WebSocketDisconnect

from models.agent import SaapAgent

logger = logging.getLogger(__name__)

class WebSocketConnection:
    """Individual WebSocket connection wrapper.

    Tracks the client identifier, the connection time and an ``is_alive``
    flag. The send helpers never raise for transport errors. Instead they
    clear ``is_alive`` and report the outcome through their boolean return
    value.
    """

    def __init__(self, websocket: WebSocket, client_id: Optional[str] = None):
        self.websocket = websocket
        # Fall back to an id derived from the socket object when the caller
        # supplies none (id() is only unique while that object is alive).
        self.client_id = client_id or f"client_{id(websocket)}"
        self.connected_at = datetime.utcnow()
        self.is_alive = True

    async def send_message(self, message: Dict[str, Any]) -> bool:
        """Serialize *message* as JSON and send it to this client.

        Returns:
            True if the text frame was handed to the socket. False if the
            connection is already dead, the payload is not JSON-serializable,
            or the send failed. Only a failed send marks the connection dead.
            An unserializable payload is a caller problem, not a broken socket.
        """
        if not self.is_alive:
            return False

        # Serialize outside the transport try-block so an encoding error is
        # not mistaken for a dead connection.
        try:
            payload = json.dumps(message)
        except (TypeError, ValueError) as e:
            logger.error(f"❌ Cannot serialize message for {self.client_id}: {e}")
            return False

        try:
            await self.websocket.send_text(payload)
        except Exception as e:
            logger.warning(f"⚠️ Failed to send to {self.client_id}: {e}")
            self.is_alive = False
            return False
        return True

    async def send_ping(self) -> bool:
        """Send a ``{"type": "ping"}`` frame to keep the connection alive.

        Returns:
            True if the ping was sent. False if the connection is dead or the
            send failed, in which case ``is_alive`` is cleared.
        """
        if not self.is_alive:
            return False
        try:
            await self.websocket.send_text(json.dumps({
                "type": "ping",
                "timestamp": datetime.utcnow().isoformat()
            }))
        except Exception:
            # Not logged; the failure is surfaced only through is_alive.
            self.is_alive = False
            return False
        return True

class WebSocketManager:
    """
    Production-ready WebSocket Manager
    Features:
    - Multi-client connection management
    - Real-time agent status broadcasts
    - Message history and statistics
    - Connection health monitoring
    - Automatic cleanup of dead connections

    Connections live in ``active_connections`` (every registered wrapper)
    and ``client_connections`` (client_id -> wrapper). Loops that await
    while walking the connection set iterate over a snapshot. Other
    coroutines may call ``connect``/``disconnect`` while a send is
    suspended, and mutating a set during iteration raises ``RuntimeError``.
    """

    def __init__(self):
        self.active_connections: Set[WebSocketConnection] = set()
        self.client_connections: Dict[str, WebSocketConnection] = {}
        self.message_history: List[Dict[str, Any]] = []
        self.max_history_size = 100
        # Seconds between dead-connection sweeps in _periodic_cleanup.
        self.cleanup_interval = 30
        self.stats = {
            "total_connections": 0,
            "current_connections": 0,
            "messages_sent": 0,
            "broadcasts_sent": 0
        }

        # Periodic cleanup task, created lazily by start_cleanup_task().
        # asyncio.create_task needs a running event loop, which __init__
        # cannot assume.
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start_cleanup_task(self):
        """Start the periodic dead-connection sweep unless one is running.

        A previous task that has finished (e.g. one cancelled by
        ``shutdown``) is replaced, so the sweep resumes on reconnect.
        """
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._periodic_cleanup())

    async def _periodic_cleanup(self):
        """Run ``_cleanup_dead_connections`` every ``cleanup_interval`` seconds."""
        try:
            while True:
                await asyncio.sleep(self.cleanup_interval)
                try:
                    await self._cleanup_dead_connections()
                except Exception as e:
                    # A single failed sweep must not end the background task.
                    logger.error(f"❌ WebSocket cleanup failed: {e}")
        except asyncio.CancelledError:
            logger.info("πŸ”§ WebSocket cleanup task cancelled")
            # Re-raise so anyone awaiting the task sees it as cancelled.
            raise

    async def _cleanup_dead_connections(self):
        """Ping every live connection and unregister those that are dead."""
        dead_connections = set()

        # Snapshot: send_ping awaits, and the set may change meanwhile.
        for connection in list(self.active_connections):
            if connection.is_alive:
                try:
                    await connection.send_ping()
                except Exception:
                    connection.is_alive = False
            # send_ping reports failure by clearing is_alive rather than
            # raising, so re-check to drop failures in this same pass.
            if not connection.is_alive:
                dead_connections.add(connection)

        for connection in dead_connections:
            self.disconnect(connection.websocket, log=False)

        if dead_connections:
            logger.info(f"🧹 Cleaned up {len(dead_connections)} dead WebSocket connections")

    async def connect(self, websocket: WebSocket, client_id: Optional[str] = None):
        """Accept *websocket*, register it and send the welcome and stats frames.

        Errors are logged, not raised. If *client_id* is already registered,
        the id is re-pointed at the new connection. The old wrapper stays in
        ``active_connections`` until it is disconnected.
        """
        try:
            await websocket.accept()

            connection = WebSocketConnection(websocket, client_id)
            self.active_connections.add(connection)
            self.client_connections[connection.client_id] = connection

            # Update statistics
            self.stats["total_connections"] += 1
            self.stats["current_connections"] = len(self.active_connections)

            logger.info(f"βœ… WebSocket connected: {connection.client_id} (Total: {len(self.active_connections)})")

            # Send welcome message with connection info
            await connection.send_message({
                "type": "connection_established",
                "client_id": connection.client_id,
                "server_time": datetime.utcnow().isoformat(),
                "message": "Connected to SAAP WebSocket server"
            })

            # Send current stats
            await connection.send_message({
                "type": "stats_update",
                "data": await self.get_connection_stats()
            })

            # Start cleanup task if not already running
            await self.start_cleanup_task()

        except Exception as e:
            logger.error(f"❌ WebSocket connection failed: {e}")

    def disconnect(self, websocket: WebSocket, log: bool = True):
        """Unregister the connection wrapping *websocket* (no-op if unknown).

        Marks the wrapper dead but does not close the underlying socket.
        """
        connection_to_remove = next(
            (conn for conn in self.active_connections if conn.websocket == websocket),
            None,
        )
        if connection_to_remove is None:
            return

        self.active_connections.discard(connection_to_remove)
        # Drop the id mapping only if it still points at this wrapper; a
        # newer connection may have re-registered the same client_id.
        client_id = connection_to_remove.client_id
        if self.client_connections.get(client_id) is connection_to_remove:
            del self.client_connections[client_id]
        connection_to_remove.is_alive = False

        self.stats["current_connections"] = len(self.active_connections)

        if log:
            logger.info(f"πŸ”Œ WebSocket disconnected: {connection_to_remove.client_id} (Remaining: {len(self.active_connections)})")

    @staticmethod
    def _is_json_serializable(message: Dict[str, Any]) -> bool:
        """Return True if ``json.dumps`` can encode *message*."""
        try:
            json.dumps(message)
        except (TypeError, ValueError):
            return False
        return True

    async def broadcast_message(self, message: Dict[str, Any]):
        """Broadcast *message* to all connected clients.

        Adds ``timestamp`` and ``broadcast`` keys. Clients whose send fails
        are unregistered. Only clients that actually received the frame count
        toward ``messages_sent``. A payload that cannot be JSON-encoded is
        logged and dropped without touching any client.
        """
        if not self.active_connections:
            return

        try:
            # Add timestamp and message type
            broadcast_message = {
                **message,
                "timestamp": datetime.utcnow().isoformat(),
                "broadcast": True
            }

            # Validate once up front so a bad payload is reported once and
            # never counted against the clients.
            if not self._is_json_serializable(broadcast_message):
                logger.error(f"❌ Broadcast failed: message is not JSON-serializable: {message.get('type', 'unknown')}")
                return

            dead_connections = set()
            successful_sends = 0

            # Snapshot: each send awaits, and the set may change meanwhile.
            for connection in list(self.active_connections):
                try:
                    await connection.send_message(broadcast_message)
                except Exception as e:
                    logger.warning(f"⚠️ Broadcast failed for {connection.client_id}: {e}")
                    connection.is_alive = False
                # send_message signals a failed send by clearing is_alive.
                if connection.is_alive:
                    successful_sends += 1
                else:
                    dead_connections.add(connection)

            # Remove dead connections
            for connection in dead_connections:
                self.disconnect(connection.websocket, log=False)

            # Update statistics
            self.stats["messages_sent"] += successful_sends
            self.stats["broadcasts_sent"] += 1

            # Add to message history
            self._add_to_history(broadcast_message)

            logger.debug(f"πŸ“’ Broadcast sent to {successful_sends} clients: {message.get('type', 'unknown')}")

        except Exception as e:
            logger.error(f"❌ Broadcast failed: {e}")

    async def send_to_client(self, client_id: str, message: Dict[str, Any]):
        """Send *message* to one client.

        Returns:
            True only if the frame was actually sent. An unknown client or an
            unserializable payload returns False and leaves the registry
            unchanged. A failed send returns False and unregisters the client.
        """
        connection = self.client_connections.get(client_id)
        if not connection:
            logger.warning(f"⚠️ Client not found: {client_id}")
            return False

        message_with_timestamp = {
            **message,
            "timestamp": datetime.utcnow().isoformat(),
            "targeted": True,
            "client_id": client_id
        }

        if not self._is_json_serializable(message_with_timestamp):
            logger.error(f"❌ Failed to send to {client_id}: message is not JSON-serializable")
            return False

        try:
            await connection.send_message(message_with_timestamp)
        except Exception as e:
            logger.error(f"❌ Failed to send to {client_id}: {e}")
            self.disconnect(connection.websocket)
            return False

        # send_message swallows transport errors and clears is_alive instead.
        if not connection.is_alive:
            logger.error(f"❌ Failed to send to {client_id}: connection closed")
            self.disconnect(connection.websocket)
            return False

        self.stats["messages_sent"] += 1
        logger.debug(f"πŸ“€ Message sent to {client_id}: {message.get('type', 'unknown')}")
        return True

    def _add_to_history(self, message: Dict[str, Any]):
        """Append *message* to history, keeping the newest ``max_history_size``."""
        self.message_history.append(message)

        # Computing the overflow (rather than slicing with [-size:]) also
        # handles max_history_size == 0, where [-0:] would keep everything.
        overflow = len(self.message_history) - self.max_history_size
        if overflow > 0:
            self.message_history = self.message_history[overflow:]

    async def get_connection_stats(self) -> Dict[str, Any]:
        """Get WebSocket connection statistics"""
        return {
            **self.stats,
            "current_connections": len(self.active_connections),
            "client_ids": [conn.client_id for conn in self.active_connections],
            "message_history_size": len(self.message_history),
            "server_time": datetime.utcnow().isoformat()
        }

    # SAAP-specific broadcast methods

    async def broadcast_agent_update(self, agent: SaapAgent):
        """Broadcast agent status update to all clients"""
        # Handle both Enum and string status/type values
        status_value = agent.status.value if hasattr(agent.status, 'value') else str(agent.status)
        type_value = agent.type.value if hasattr(agent.type, 'value') else str(agent.type)

        # Handle optional metrics attribute
        last_active = None
        if hasattr(agent, 'metrics') and agent.metrics:
            if hasattr(agent.metrics, 'last_active') and agent.metrics.last_active:
                last_active = agent.metrics.last_active.isoformat()

        await self.broadcast_message({
            "type": "agent_update",
            "data": {
                "agent_id": agent.id,
                "name": agent.name,
                "status": status_value,
                "type": type_value,
                "last_active": last_active,
                "capabilities": agent.capabilities
            }
        })

    async def broadcast_agent_deleted(self, agent_id: str):
        """Broadcast agent deletion to all clients"""
        await self.broadcast_message({
            "type": "agent_deleted",
            "data": {
                "agent_id": agent_id
            }
        })

    async def broadcast_message_update(self, message_data: Dict[str, Any]):
        """Broadcast new agent message/response to all clients"""
        await self.broadcast_message({
            "type": "agent_message",
            "data": message_data
        })

    async def broadcast_system_status(self, status_data: Dict[str, Any]):
        """Broadcast system status update"""
        await self.broadcast_message({
            "type": "system_status",
            "data": status_data
        })

    async def broadcast_error(self, error_message: str, error_code: Optional[str] = None):
        """Broadcast system error to all clients"""
        await self.broadcast_message({
            "type": "system_error",
            "data": {
                "message": error_message,
                "code": error_code,
                "severity": "error"
            }
        })

    async def shutdown(self):
        """Stop the sweep, notify clients, then close and unregister every socket.

        Errors are logged, not raised.
        """
        try:
            logger.info("πŸ”§ Shutting down WebSocket manager...")

            # Cancel cleanup task
            if self._cleanup_task:
                self._cleanup_task.cancel()
                try:
                    await self._cleanup_task
                except asyncio.CancelledError:
                    pass
                # Allow start_cleanup_task() to create a fresh task later.
                self._cleanup_task = None

            # Snapshot before broadcasting: the broadcast may unregister
            # failing clients, but their sockets should still be closed.
            connections = list(self.active_connections)

            # Send shutdown notification to all clients
            await self.broadcast_message({
                "type": "server_shutdown",
                "data": {
                    "message": "SAAP server is shutting down",
                    "timestamp": datetime.utcnow().isoformat()
                }
            })

            # Close all connections
            for connection in connections:
                try:
                    await connection.websocket.close()
                except Exception:
                    # Best effort: the peer may already be gone.
                    pass
                self.disconnect(connection.websocket, log=False)

            logger.info("βœ… WebSocket manager shutdown complete")

        except Exception as e:
            logger.error(f"❌ WebSocket shutdown error: {e}")

if __name__ == "__main__":

    async def _smoke_test():
        """Exercise a fresh manager with no clients attached."""
        ws_manager = WebSocketManager()

        # Dump the initial statistics snapshot
        print("πŸ“Š WebSocket Manager Stats:")
        initial_stats = await ws_manager.get_connection_stats()
        print(json.dumps(initial_stats, indent=2))

        # With nobody connected, a broadcast must complete without error
        test_payload = {"type": "test", "message": "Test broadcast"}
        await ws_manager.broadcast_message(test_payload)

        print("βœ… WebSocket manager test completed")

    asyncio.run(_smoke_test())