53 lines
1.8 KiB
Python
53 lines
1.8 KiB
Python
from typing import Dict, List
|
|
|
|
from fastapi import WebSocket, Depends, WebSocketDisconnect
|
|
|
|
from app.users.chat.dao import ChatDAO
|
|
from app.users.dependencies import validate_user_access_to_chat, get_current_user
|
|
from app.users.models import Users
|
|
from app.users.chat.router import router
|
|
|
|
|
|
class ConnectionManager:
    """Track open WebSocket connections per chat and fan messages out to them.

    Sockets are grouped by ``chat_id`` in ``active_connections``. Each message
    passed to :meth:`broadcast` is first persisted through
    ``ChatDAO.send_message``. It is then sent as text to every socket currently
    registered for that chat, including the sender's own socket.

    Note: this class is a registry of sockets, not a WebSocket itself, so it
    does not inherit from ``WebSocket``.
    """

    def __init__(self):
        # chat_id -> sockets currently connected to that chat.
        self.active_connections: dict[int, list[WebSocket]] = {}

    async def connect(self, chat_id: int, websocket: WebSocket):
        """Accept the WebSocket handshake and register the socket under ``chat_id``."""
        await websocket.accept()
        self.active_connections.setdefault(chat_id, []).append(websocket)

    def disconnect(self, chat_id: int, websocket: WebSocket):
        """Unregister ``websocket`` from ``chat_id``.

        Unknown chats and sockets are ignored, so calling this twice is safe.
        The chat's entry is removed once its last socket leaves, so the
        mapping does not grow with every chat ever opened.
        """
        connections = self.active_connections.get(chat_id)
        if connections is None:
            return
        # Match by identity, not ``==``: Starlette's WebSocket is a Mapping
        # whose equality compares connection scopes, so ``list.remove`` could
        # match a different socket with an equal scope.
        for index, registered in enumerate(connections):
            if registered is websocket:
                del connections[index]
                break
        if not connections:
            del self.active_connections[chat_id]

    async def broadcast(self, user_id: int, chat_id: int, message: str):
        """Persist ``message`` as sent by ``user_id`` and send it to every socket in ``chat_id``.

        The database write happens before any socket is contacted. An
        exception raised by ``send_text`` on any socket propagates to the caller.
        """
        await self.add_message_to_database(user_id=user_id, chat_id=chat_id, message=message)
        # Iterate over a snapshot. Each ``await`` yields to the event loop, and
        # a concurrent connect/disconnect mutating the live list would make
        # this loop skip or repeat sockets. ``.get`` tolerates a chat that has
        # no registered sockets (for example, after its last socket left).
        recipients = list(self.active_connections.get(chat_id, ()))
        for websocket in recipients:
            await websocket.send_text(message)

    @staticmethod
    async def add_message_to_database(user_id: int, chat_id: int, message: str):
        """Store ``message`` from ``user_id`` in ``chat_id`` via ``ChatDAO.send_message``.

        Returns whatever ``ChatDAO.send_message`` returns.
        """
        result = await ChatDAO.send_message(user_id=user_id, chat_id=chat_id, message=message)
        return result
|
|
|
|
|
|
# Module-level registry used by ``websocket_endpoint``. Its connection state
# lives only in this process's memory.
manager: ConnectionManager = ConnectionManager()
|
|
|
|
|
|
@router.websocket("/ws/{chat_id}")
async def websocket_endpoint(chat_id: int, user_id: int, websocket: WebSocket):
    """Join ``chat_id`` over a WebSocket and relay every text frame to the chat.

    Access is checked with ``validate_user_access_to_chat`` before the socket is
    accepted and registered. Each received text frame is then stored and
    broadcast by ``manager.broadcast`` on behalf of ``user_id``. The socket is
    always unregistered when the loop ends, whether the client closed normally
    or an error was raised.

    NOTE(review): ``user_id`` is not a path parameter, so FastAPI reads it from
    the query string. Any client can therefore claim any user id.
    ``get_current_user`` is imported in this module but unused here. Confirm
    whether the user should come from an authenticated session instead.
    """
    await validate_user_access_to_chat(user_id=user_id, chat_id=chat_id)
    await manager.connect(chat_id, websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(user_id=user_id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        # Client-initiated close is the normal way out of the loop.
        pass
    finally:
        # Unregister on every exit path, including when ``broadcast`` raises
        # (failed DB write, or a send to another peer fails). Otherwise the
        # dead socket would stay registered and every later broadcast to this
        # chat would try to send to it.
        manager.disconnect(chat_id, websocket)
|
|
|
|
|
|
@router.post("")
async def chroot():
    """Placeholder POST handler on the router's root path; does nothing yet.

    NOTE(review): the handler returns ``None``, so FastAPI answers with a JSON
    ``null`` body. Confirm whether this stub should be implemented or removed.
    """
    return None
|