130 lines
4.9 KiB
Python
130 lines
4.9 KiB
Python
from fastapi import WebSocket, WebSocketDisconnect, Depends
|
|
|
|
from app.exceptions import IncorrectDataException, UserDontHavePermissionException
|
|
from app.services.message_service import MessageService
|
|
from app.unit_of_work import UnitOfWork
|
|
from app.users.auth import AuthService
|
|
from app.chat.router import router
|
|
from app.chat.shemas import SSendMessage, SMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage
|
|
from app.users.dependencies import get_current_user_ws
|
|
from app.users.schemas import SUser
|
|
|
|
|
|
class ConnectionManager:
    """Keep track of the open WebSockets of every chat and fan events out to them.

    ``active_connections`` maps a chat id to the list of sockets currently
    registered for that chat. A chat id is present only while it has at least
    one registered socket.
    """

    def __init__(self):
        self.active_connections: dict[int, list[WebSocket]] = {}

    async def connect(self, chat_id: int, websocket: WebSocket):
        """Accept the WebSocket handshake and register the socket under ``chat_id``."""
        await websocket.accept()
        self.active_connections.setdefault(chat_id, []).append(websocket)

    def disconnect(self, chat_id: int, websocket: WebSocket):
        """Unregister ``websocket`` from ``chat_id``.

        Safe to call more than once and for a chat or socket that is not
        registered; in that case nothing happens. The chat entry is removed
        once its last socket is gone so the mapping does not grow forever.
        """
        connections = self.active_connections.get(chat_id)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            del self.active_connections[chat_id]

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict):
        """Execute the command carried by ``message`` and send the resulting event to the chat.

        ``message["flag"]`` selects the command: ``"send"``, ``"delete"``,
        ``"edit"``, ``"pin"`` or ``"unpin"``. The payload is validated with the
        matching schema, the corresponding ``MessageService`` operation is run,
        and the resulting event is sent as JSON to every socket registered for
        ``chat_id``.

        Args:
            uow: unit of work handed through to ``MessageService``.
            user_id: id of the user that owns the sending connection.
            chat_id: chat whose sockets receive the event.
            message: decoded JSON payload received from the client.

        Raises:
            IncorrectDataException: the payload is not a JSON object, has no
                ``"flag"`` key, or the flag is not one of the known commands.
            UserDontHavePermissionException: a delete/edit payload carries a
                ``user_id`` different from ``user_id``.
        """
        # receive_json() may yield any JSON value; only objects can carry a flag.
        if not isinstance(message, dict) or "flag" not in message:
            raise IncorrectDataException

        flag = message["flag"]

        if flag == "send":
            payload = SSendMessage.model_validate(message)
            stored = await self.add_message_to_database(uow=uow, user_id=user_id, chat_id=chat_id, message=payload)
            new_message = stored.model_dump()
            # datetime is not JSON serializable, send it as an ISO-8601 string.
            new_message["created_at"] = new_message["created_at"].isoformat()
            new_message["flag"] = "send"

        elif flag == "delete":
            payload = SDeleteMessage.model_validate(message)
            # NOTE(review): this compares the user_id supplied in the payload with
            # the connection's user; ownership of the stored message is not
            # checked here - confirm MessageService enforces it.
            if payload.user_id != user_id:
                raise UserDontHavePermissionException
            deleted_message = await self.delete_message(uow=uow, message_id=payload.id)
            new_message = {"deleted_message": deleted_message, "id": payload.id, "flag": "delete"}

        elif flag == "edit":
            payload = SEditMessage.model_validate(message)
            # NOTE(review): same payload-based permission check as for "delete".
            if payload.user_id != user_id:
                raise UserDontHavePermissionException
            edited_message = await self.edit_message(
                uow=uow, message_id=payload.id, new_message=payload.new_message, image_url=payload.new_image_url
            )
            new_message = {
                "flag": "edit",
                "id": payload.id,
                "edited_message": edited_message,
                "new_message": payload.new_message,
                "new_image_url": payload.new_image_url,
            }

        elif flag == "pin":
            payload = SPinMessage.model_validate(message)
            # NOTE(review): the user id passed on is the one from the payload, not
            # the connection's user_id, and it is not compared with it - confirm
            # this is intended.
            pinned_message = await self.pin_message(
                uow=uow, chat_id=chat_id, user_id=payload.user_id, message_id=payload.id
            )
            new_message = pinned_message.model_dump()
            new_message["created_at"] = new_message["created_at"].isoformat()
            new_message["flag"] = "pin"

        elif flag == "unpin":
            payload = SUnpinMessage.model_validate(message)
            unpinned_message = await self.unpin_message(uow=uow, chat_id=chat_id, message_id=payload.id)
            # NOTE(review): the unpin event is emitted with flag "pin". This looks
            # like a copy-paste slip, but it is part of the wire format that
            # clients read, so it is kept as is - confirm against the client.
            new_message = {"flag": "pin", "id": unpinned_message}

        else:
            raise IncorrectDataException

        # Iterate over a snapshot: connect()/disconnect() can run while this
        # coroutine is suspended in send_json() and must not mutate the list
        # that is being iterated.
        for websocket in list(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(new_message)
            except (WebSocketDisconnect, RuntimeError):
                # The receiving peer is gone (Starlette reports this as
                # WebSocketDisconnect or RuntimeError depending on the socket
                # state). Drop that socket instead of letting the error reach
                # the sender's handler and abort delivery to everyone else.
                self.disconnect(chat_id, websocket)

    @staticmethod
    async def add_message_to_database(uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage) -> SMessage:
        """Store a new message and, when ``message.answer`` is set, link it to the answered one."""
        new_message = await MessageService.send_message(
            uow=uow, user_id=user_id, chat_id=chat_id, message=message.message, image_url=message.image_url
        )
        if message.answer:
            new_message = await MessageService.add_answer(uow=uow, self_id=new_message.id, answer_id=message.answer)
        return new_message

    @staticmethod
    async def delete_message(uow: UnitOfWork, message_id: int) -> bool:
        """Delete the message via ``MessageService.delete_message`` and return its result."""
        return await MessageService.delete_message(uow=uow, message_id=message_id)

    @staticmethod
    async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, image_url: str) -> bool:
        """Update text and image of the message via ``MessageService.edit_message`` and return its result."""
        return await MessageService.edit_message(
            uow=uow, message_id=message_id, new_message=new_message, new_image_url=image_url
        )

    @staticmethod
    async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage:
        """Pin the message in the chat via ``MessageService.pin_message`` and return its result."""
        return await MessageService.pin_message(uow=uow, chat_id=chat_id, user_id=user_id, message_id=message_id)

    @staticmethod
    async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> int:
        """Unpin the message via ``MessageService.unpin_message`` and return its result."""
        return await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message_id)
|
|
|
|
|
|
# Module-level ConnectionManager instance used by the WebSocket endpoint in this module.
manager = ConnectionManager()
|
|
|
|
|
|
@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
    chat_id: int,
    websocket: WebSocket,
    user: SUser = Depends(get_current_user_ws),
    uow=Depends(UnitOfWork),
):
    """Serve one client's WebSocket connection to the chat ``chat_id``.

    Runs the ``AuthService`` verification and chat-access checks for the
    current user (presumably these raise on failure - confirm), registers the
    socket with ``manager`` and then forwards every received JSON payload to
    ``manager.broadcast`` until the connection ends.

    The socket is always unregistered from ``manager`` when the handler exits,
    whether the client disconnected or an error was raised while receiving or
    broadcasting. Errors other than ``WebSocketDisconnect`` still propagate.
    """
    await AuthService.check_verificated_user_with_exc(uow=uow, user_id=user.id)
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    await manager.connect(chat_id, websocket)
    try:
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        # The client closed the connection: normal end of the session.
        pass
    finally:
        # Unregister on every exit path; otherwise a validation or permission
        # error would leave a dead socket registered for this chat.
        manager.disconnect(chat_id, websocket)
|