81 lines
3.5 KiB
Python
81 lines
3.5 KiB
Python
from fastapi import WebSocket, WebSocketDisconnect
|
|
|
|
from app.exceptions import IncorrectDataException, UserDontHavePermissionException
|
|
from app.users.chat.dao import ChatDAO
|
|
from app.users.auth import validate_user_access_to_chat, check_verificated_user_with_exc
|
|
from app.users.chat.router import router
|
|
|
|
|
|
class ConnectionManager:
    """Track open chat websockets and fan chat events out to them.

    Sockets are grouped by chat id. ``broadcast`` applies one client event
    (send / delete / edit) through ``ChatDAO`` and pushes the resulting JSON
    payload to every socket currently registered for that chat.

    This is a plain registry object; it does not itself represent a
    connection, so it deliberately does not subclass ``WebSocket``.
    """

    def __init__(self):
        # chat_id -> sockets currently open for that chat. A key is removed
        # when its last socket leaves so the dict does not grow without bound.
        self.active_connections: dict[int, list[WebSocket]] = {}

    async def connect(self, chat_id: int, websocket: WebSocket):
        """Accept the websocket handshake and register the socket under ``chat_id``."""
        await websocket.accept()
        self.active_connections.setdefault(chat_id, []).append(websocket)

    def disconnect(self, chat_id: int, websocket: WebSocket):
        """Unregister ``websocket`` from ``chat_id``.

        Safe to call more than once for the same socket (``broadcast`` may
        already have pruned it). The chat entry is dropped once it is empty.
        """
        connections = self.active_connections.get(chat_id)
        if connections is None:
            return
        try:
            connections.remove(websocket)
        except ValueError:
            pass  # already removed, e.g. pruned during a broadcast
        if not connections:
            del self.active_connections[chat_id]

    async def broadcast(self, user_id: int, chat_id: int, message: dict):
        """Apply a client chat event and push the result to everyone in the chat.

        Args:
            user_id: id of the user whose socket sent ``message``.
            chat_id: chat the event belongs to.
            message: decoded JSON from the client. ``message["flag"]`` selects
                the action: ``"send"`` (needs ``message``, ``image_url``),
                ``"delete"`` (needs ``user_id``, ``id``) or ``"edit"`` (needs
                ``user_id``, ``id``, ``new_message``, ``new_image_url``).

        Raises:
            IncorrectDataException: ``message`` is not a dict, has a missing or
                unknown ``flag``, or lacks a field its flag requires.
            UserDontHavePermissionException: a delete/edit whose
                ``message["user_id"]`` differs from ``user_id``.
        """
        payload = await self._build_payload(user_id, chat_id, message)
        await self._send_to_chat(chat_id, payload)

    async def _build_payload(self, user_id: int, chat_id: int, message: dict) -> dict:
        """Persist the event described by ``message`` and return the payload to fan out."""
        if not isinstance(message, dict) or "flag" not in message:
            raise IncorrectDataException

        flag = message["flag"]

        if flag == "send":
            text, image_url = self._require(message, "message", "image_url")
            new_message = await self.add_message_to_database(
                user_id=user_id, chat_id=chat_id, message=text, image_url=image_url
            )
            new_message = dict(new_message)
            # datetime is not JSON-serialisable; send it as an ISO-8601 string.
            new_message["created_at"] = new_message["created_at"].isoformat()
            new_message["flag"] = "send"
            return new_message

        if flag == "delete":
            author_id, message_id = self._require(message, "user_id", "id")
            # NOTE(review): author_id is supplied by the client, so this check
            # cannot stop a client that lies about it. The stored message's
            # real author should be looked up before deleting.
            if author_id != user_id:
                raise UserDontHavePermissionException
            deleted_message = await self.delete_message(message_id)
            return {"deleted_message": deleted_message, "id": message_id, "flag": "delete"}

        if flag == "edit":
            author_id, message_id, new_text, new_image_url = self._require(
                message, "user_id", "id", "new_message", "new_image_url"
            )
            # NOTE(review): same client-trusted author check as for "delete".
            if author_id != user_id:
                raise UserDontHavePermissionException
            edited_message = await self.edit_message(message_id, new_text, new_image_url)
            return {
                "edited_message": edited_message,
                "flag": "edit",
                "message_id": message_id,
                "new_message": new_text,
                "new_image_url": new_image_url,
            }

        raise IncorrectDataException

    async def _send_to_chat(self, chat_id: int, payload: dict):
        """Send ``payload`` to every socket in ``chat_id``, dropping sockets that are gone."""
        # Iterate over a snapshot: each ``await`` yields to other handlers,
        # which may call ``disconnect`` and mutate the live list.
        for websocket in list(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(payload)
            except (WebSocketDisconnect, RuntimeError):
                # Sending on an already-closed socket raises. Without this, the
                # error would propagate into the *sender's* receive loop and
                # stop delivery to the rest of the chat.
                self.disconnect(chat_id, websocket)

    @staticmethod
    def _require(message: dict, *keys: str) -> tuple:
        """Return ``message[key]`` for each of ``keys``, in order.

        Raises:
            IncorrectDataException: if any key is missing.
        """
        try:
            return tuple(message[key] for key in keys)
        except KeyError as err:
            raise IncorrectDataException from err

    @staticmethod
    async def add_message_to_database(user_id: int, chat_id: int, message: str, image_url: str) -> dict:
        """Store a chat message via ``ChatDAO.send_message`` and return the first returned row.

        NOTE(review): presumably that row is the newly inserted message;
        confirm against ``ChatDAO.send_message``.
        """
        rows = await ChatDAO.send_message(user_id=user_id, chat_id=chat_id, message=message, image_url=image_url)
        return rows[0]

    @staticmethod
    async def delete_message(message_id: int) -> bool:
        """Delete a message via ``ChatDAO.delete_message`` and return its result."""
        result = await ChatDAO.delete_message(message_id)
        return result

    @staticmethod
    async def edit_message(message_id: int, new_message: str, image_url: str) -> bool:
        """Update a message's text and image via ``ChatDAO.edit_message`` and return its result."""
        result = await ChatDAO.edit_message(message_id=message_id, new_message=new_message, new_image_url=image_url)
        return result
|
|
|
|
|
|
# Single module-level registry shared by every websocket this process serves.
# NOTE(review): the registry lives in this process's memory, so with several
# workers a broadcast only reaches clients on the same worker — confirm the
# deployment runs a single worker (or move fan-out to a shared broker).
manager: ConnectionManager = ConnectionManager()
|
|
|
|
|
|
@router.websocket("/ws/{chat_id}")
async def websocket_endpoint(chat_id: int, user_id: int, websocket: WebSocket):
    """Serve the chat websocket for ``chat_id`` on behalf of ``user_id``.

    The user must pass the verification and chat-access checks before the
    socket is accepted. Every JSON message received afterwards is handed to
    ``manager.broadcast``. Any exception other than a client disconnect
    still propagates, but the socket is always unregistered first.

    NOTE(review): ``user_id`` is not part of the path, so FastAPI takes it
    from the query string. It is not tied to an authenticated session, so a
    client can claim any id — confirm this is intended.
    """
    await check_verificated_user_with_exc(user_id=user_id)
    await validate_user_access_to_chat(user_id=user_id, chat_id=chat_id)

    await manager.connect(chat_id, websocket)
    try:
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(user_id=user_id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        pass  # normal client-initiated close
    finally:
        # Runs on every exit path (bad payload, permission error, invalid
        # JSON, ...), not only on WebSocketDisconnect. Otherwise a dead socket
        # stays registered and breaks later broadcasts to this chat.
        manager.disconnect(chat_id, websocket)
|
|
|