import json
from typing import Dict, List

from fastapi import WebSocket, Depends, WebSocketDisconnect
from pydantic import parse_obj_as

from app.users.chat.dao import ChatDAO
from app.users.auth import validate_user_access_to_chat
from app.users.chat.shemas import SMessageSchema, SMessage
from app.users.models import Users
from app.users.chat.router import router


class ConnectionManager:
    """Registry of open WebSocket connections, grouped by chat id.

    This is a plain registry object; it is not itself a WebSocket, so it no
    longer subclasses ``WebSocket`` (the original did so without calling
    ``super().__init__``).
    """

    def __init__(self) -> None:
        # chat_id -> sockets currently connected to that chat.
        self.active_connections: dict[int, list[WebSocket]] = {}

    async def connect(self, chat_id: int, websocket: WebSocket) -> None:
        """Accept *websocket* and register it under *chat_id*."""
        await websocket.accept()
        self.active_connections.setdefault(chat_id, []).append(websocket)

    def disconnect(self, chat_id: int, websocket: WebSocket) -> None:
        """Unregister *websocket* from *chat_id*; a no-op if it is not registered.

        Drops the chat's entry once its last socket is removed so the dict
        does not accumulate empty lists.
        """
        connections = self.active_connections.get(chat_id)
        if connections is None:
            return
        try:
            connections.remove(websocket)
        except ValueError:
            # Already removed, e.g. pruned by broadcast() after a failed send.
            pass
        if not connections:
            del self.active_connections[chat_id]

    async def broadcast(self, user_id: int, chat_id: int, message: str) -> None:
        """Persist *message* from *user_id* and push it to every socket in *chat_id*.

        The stored message is converted with ``dict(...)`` and its
        ``created_at`` value serialized via ``isoformat()`` before sending.
        Sockets whose send fails are unregistered instead of letting the
        error escape into the sender's receive loop (where it would be
        mistaken for the *sender* disconnecting).
        """
        new_message = await self.add_message_to_database(user_id=user_id, chat_id=chat_id, message=message)
        # NOTE(review): assumes get_message_by_id returns a mapping-like row
        # with a datetime-like 'created_at' — confirm against ChatDAO.
        new_message = dict(new_message)
        new_message['created_at'] = new_message['created_at'].isoformat()

        # Snapshot the list: connect()/disconnect() may mutate it while we
        # are suspended in send_json().
        recipients = list(self.active_connections.get(chat_id, ()))
        stale: list[WebSocket] = []
        for websocket in recipients:
            try:
                await websocket.send_json(new_message)
            except (WebSocketDisconnect, RuntimeError):
                # Starlette raises these when sending on a closed socket.
                stale.append(websocket)
        for websocket in stale:
            self.disconnect(chat_id, websocket)

    @staticmethod
    async def add_message_to_database(user_id: int, chat_id: int, message: str):
        """Store *message* via ChatDAO and return the stored message record."""
        # presumably send_message returns the new message's id — it is passed
        # straight to get_message_by_id below.
        result = await ChatDAO.send_message(user_id=user_id, chat_id=chat_id, message=message)
        new_message = await ChatDAO.get_message_by_id(message_id=result)
        return new_message


manager = ConnectionManager()


@router.websocket("/ws/{chat_id}")
async def websocket_endpoint(chat_id: int, user_id: int, websocket: WebSocket):
    """Chat WebSocket: relay every received text frame to all members of *chat_id*.

    The connection is always unregistered when the loop ends, whether the
    client disconnected or an unexpected error occurred.
    """
    # NOTE(review): user_id arrives as a request parameter chosen by the
    # client; unless it is authenticated elsewhere, a client can act as any
    # user. Consider deriving it from a verified token.
    await validate_user_access_to_chat(user_id=user_id, chat_id=chat_id)
    await manager.connect(chat_id, websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(user_id=user_id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        pass  # normal client disconnect; cleanup happens in finally
    finally:
        manager.disconnect(chat_id, websocket)