"""Chat transport layer: WebSocket fan-out plus a long-polling fallback."""
import asyncio
import logging
from collections import defaultdict

import websockets
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException, status

from app.chat.exceptions import UseWSException, MessageNotFoundException, MessageAlreadyPinnedException
from app.exceptions import IncorrectDataException
from app.chat.exceptions import UserDontHavePermissionException
from app.services import message_service, auth_service
from app.utils.unit_of_work import UnitOfWork
from app.chat.router import router
from app.chat.shemas import SSendMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage, Responses
from app.dependencies import get_current_user_ws, get_token, get_subprotocol_ws, get_verificated_user
from app.users.schemas import SUser


def _common_responses() -> dict:
    """Return the error-response table shared by the HTTP routes of this module.

    A fresh dict is built on every call so that no two routes share a mutable
    object.
    """
    return {
        status.HTTP_401_UNAUTHORIZED: {
            "model": Responses.STokenMissingException,
            "description": "Токен отсутствует"
        },
        status.HTTP_403_FORBIDDEN: {
            "model": Responses.SNotAuthenticated,
            "description": "Not authenticated"
        },
        status.HTTP_404_NOT_FOUND: {
            "model": Responses.SUserNotFoundException,
            "description": "Юзер не найден"
        },
        status.HTTP_409_CONFLICT: {
            "model": Responses.SUserMustConfirmEmailException,
            "description": "Сначала подтвердите почту"
        },
        status.HTTP_500_INTERNAL_SERVER_ERROR: {
            "model": Responses.SBlackPhoenixException,
            "description": "Внутренняя ошибка сервера"
        },
    }


class ConnectionManager:
    """Track the open WebSockets of every chat and fan chat events out to them.

    An incoming event is a dict whose ``"flag"`` key selects one of the
    handlers in ``message_methods``; the handler's return value is what gets
    delivered to every socket of the chat and to the long-polling clients.
    """

    def __init__(self):
        # chat_id -> sockets currently registered for that chat.
        self.active_connections: dict[int, list[WebSocket]] = defaultdict(list)
        # "flag" value of an incoming event -> coroutine that applies it.
        self.message_methods = {
            "send": self._send,
            "delete": self._delete,
            "edit": self._edit,
            "pin": self._pin,
            "unpin": self._unpin,
        }

    async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None) -> None:
        """Accept the handshake and register ``websocket`` under ``chat_id``."""
        await websocket.accept(subprotocol=subprotocol)
        self.active_connections[chat_id].append(websocket)

    async def disconnect(
            self,
            chat_id: int,
            websocket: WebSocket,
            code_and_reason: tuple[int, str] | None = None,
    ) -> None:
        """Unregister ``websocket`` and optionally close it.

        Args:
            chat_id: Chat the socket was registered under.
            websocket: Socket to drop.
            code_and_reason: ``(close_code, reason)``; when given, the socket
                is closed with these values, otherwise it is only unregistered.
        """
        # The socket may never have been registered (the endpoint calls this
        # when a check fails before ``connect``), so a missing entry is fine.
        connections = self.active_connections.get(chat_id)
        if connections is not None:
            try:
                connections.remove(websocket)
            except ValueError:
                pass
            if not connections:
                # Do not keep empty per-chat lists around forever.
                self.active_connections.pop(chat_id, None)

        if code_and_reason:
            await websocket.close(code=code_and_reason[0], reason=code_and_reason[1])

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> None:
        """Apply an incoming event and deliver the result to the whole chat.

        Args:
            uow: Unit of work forwarded to the message service.
            user_id: Id of the authenticated sender.
            chat_id: Chat the event belongs to.
            message: Raw event; its ``"flag"`` key selects the handler.

        Raises:
            IncorrectDataException: ``message`` has no usable ``"flag"`` or the
                flag is unknown.
        """
        # Only the dispatch lookup is guarded, so that a KeyError raised
        # inside a handler is not misreported as bad client input.
        try:
            handler = self.message_methods[message["flag"]]
        except (KeyError, TypeError):
            raise IncorrectDataException from None

        try:
            new_message = await handler(uow, user_id, chat_id, message)
        except (MessageNotFoundException, MessageAlreadyPinnedException):
            # Nothing changed, so there is nothing to announce.
            return

        # Iterate over a snapshot: other coroutines may connect/disconnect
        # while this one is suspended in send_json.
        for websocket in list(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(new_message)
            except (WebSocketDisconnect, RuntimeError):
                # A peer that is already gone must not abort delivery to the
                # others nor be mistaken for the sender disconnecting; its own
                # endpoint is responsible for unregistering it.
                continue

        await polling_manager.send(chat_id, new_message)

    @staticmethod
    def _serialize_message(message_model, flag: str) -> dict:
        """Dump a message model into a JSON-friendly dict tagged with ``flag``."""
        payload = message_model.model_dump()
        payload["id"] = str(payload["id"])
        payload["answer_id"] = str(payload["answer_id"]) if payload["answer_id"] else None
        payload["created_at"] = payload["created_at"].isoformat()
        payload["flag"] = flag
        return payload

    @staticmethod
    async def _send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
        """Store a new message and return its serialized form (flag ``send``)."""
        message = SSendMessage.model_validate(message)
        new_message = await message_service.send_message(
            uow=uow,
            user_id=user_id,
            chat_id=chat_id,
            message=message,
        )
        return ConnectionManager._serialize_message(new_message, "send")

    @staticmethod
    async def _delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Delete a message; ``message.user_id`` must equal the sender's id.

        Raises:
            UserDontHavePermissionException: the ids differ.
        """
        message = SDeleteMessage.model_validate(message)
        if message.user_id != user_id:
            raise UserDontHavePermissionException
        await message_service.delete_message(uow=uow, message_id=message.id)
        return {"id": str(message.id), "flag": "delete"}

    @staticmethod
    async def _edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Edit a message; ``message.user_id`` must equal the sender's id.

        Raises:
            UserDontHavePermissionException: the ids differ.
        """
        message = SEditMessage.model_validate(message)
        if message.user_id != user_id:
            raise UserDontHavePermissionException
        await message_service.edit_message(
            uow=uow,
            message_id=message.id,
            new_message=message.new_message,
            new_image_url=message.new_image_url,
        )
        return {
            "flag": "edit",
            "id": str(message.id),
            "new_message": message.new_message,
            "new_image_url": message.new_image_url,
        }

    @staticmethod
    async def _pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Pin a message and return its serialized form (flag ``pin``)."""
        message = SPinMessage.model_validate(message)
        # NOTE(review): the user id handed to the service comes from the
        # payload, not from the authenticated user (unlike _delete/_edit) -
        # confirm this is intended.
        pinned_message = await message_service.pin_message(
            uow=uow,
            chat_id=chat_id,
            user_id=message.user_id,
            message_id=message.id,
        )
        return ConnectionManager._serialize_message(pinned_message, "pin")

    @staticmethod
    async def _unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Unpin a message and return ``{"flag": "unpin", "id": ...}``."""
        # NOTE(review): the authenticated user id is not used here - confirm
        # that any permission check happens elsewhere.
        message = SUnpinMessage.model_validate(message)
        await message_service.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id)
        return {"flag": "unpin", "id": str(message.id)}


manager = ConnectionManager()


# WebSocket entry point of a chat: checks the user, registers the socket and
# then feeds every received JSON event to ``manager.broadcast`` until the
# connection ends.
@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
        chat_id: int,
        websocket: WebSocket,
        user: SUser = Depends(get_current_user_ws),
        subprotocol: str = Depends(get_subprotocol_ws),
        uow=Depends(UnitOfWork),
):
    try:
        await auth_service.check_verificated_user(uow=uow, user_id=user.id)
        await auth_service.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
        await manager.connect(chat_id, websocket, subprotocol)
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        # The peer is already gone: just unregister, there is nothing to close.
        await manager.disconnect(chat_id, websocket)
    except HTTPException as e:
        # Translate an HTTP-style error into a WebSocket close code.
        code = status.WS_1011_INTERNAL_ERROR if e.status_code == 500 else status.WS_1003_UNSUPPORTED_DATA
        reason = e.detail
        code_and_reason = (code, reason)
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=code_and_reason)
    except Exception as e:
        logging.warning(e)
        code_and_reason = (status.WS_1011_INTERNAL_ERROR, "Internal Server Error")
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=code_and_reason)


# HTTP counterpart of the WebSocket route; it always raises UseWSException.
# (Documented with comments rather than a docstring because FastAPI would
# publish a docstring as the route description in the OpenAPI schema.)
@router.post(
    "/ws/{chat_id}",
    responses=_common_responses(),
)
async def chat_ws(
        chat_id: int,
        token: str = Depends(get_token),
):
    raise UseWSException  # noqa
    # NOTE(review): everything below is unreachable; it looks like a client
    # usage example that was kept on purpose - confirm or delete.
    url = f"ws://localhost:8000/api/chat/ws/{chat_id}"
    async with websockets.connect(url, extra_headers={"Authorization": f"Bearer {token}"}) as websocket:
        print(await websocket.recv())


class PollingManager:
    """Long-polling fallback: parks pollers on futures until an event arrives.

    An event is delivered only to the pollers waiting at the moment ``send``
    runs; it is not queued for pollers that arrive later.
    """

    def __init__(self):
        # chat_id -> futures of the requests currently waiting for an event.
        self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list)
        # Kept only so that code reading this attribute keeps working; events
        # are handed over through the futures' results, not through this dict.
        self.messages: dict[int, list[dict]] = defaultdict(list)

    def _discard_waiter(self, chat_id: int, future: asyncio.Future) -> None:
        """Forget ``future`` if it is still registered for ``chat_id``."""
        pending = self.waiters.get(chat_id)
        if pending is None:
            return
        try:
            pending.remove(future)
        except ValueError:
            # ``send`` already took it off the list.
            pass
        if not pending:
            self.waiters.pop(chat_id, None)

    async def poll(self, chat_id: int) -> dict:
        """Wait for the next event of ``chat_id`` and return it.

        Raises:
            HTTPException: 408 when the waiting task is cancelled.
        """
        future = asyncio.get_running_loop().create_future()
        self.waiters[chat_id].append(future)
        try:
            # The event itself is the future's result, so every waiter gets
            # exactly the event that woke it up.
            return await future
        except asyncio.CancelledError:
            self._discard_waiter(chat_id, future)
            raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Client disconnected")

    async def prepare(
            self,
            uow: UnitOfWork,
            user_id: int,
            chat_id: int,
            message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    ) -> None:
        """Apply an event received over HTTP and announce it to the chat.

        ``manager.broadcast`` already notifies the pollers with the processed
        event, so no additional ``send`` is issued here; sending the raw
        request body as well would hand pollers an unprocessed payload, even
        for events that ``broadcast`` decided not to announce.
        """
        payload = message.model_dump()
        await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=payload)

    async def send(
            self,
            chat_id: int,
            message: dict,
    ) -> None:
        """Wake every poller currently waiting on ``chat_id`` with ``message``."""
        # Take the whole list at once; pollers that register afterwards wait
        # for the next event.
        for waiter in self.waiters.pop(chat_id, []):
            if not waiter.done():
                waiter.set_result(message)


polling_manager = PollingManager()


# Long-polling read: blocks until the next event of the chat and returns it.
@router.get(
    "/poll/{chat_id}",
    status_code=status.HTTP_200_OK,
    response_model=None,
    responses=_common_responses(),
)
async def poll(chat_id: int, user: SUser = Depends(get_verificated_user), uow=Depends(UnitOfWork)):
    await auth_service.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    return await polling_manager.poll(chat_id)


# HTTP write: applies one chat event and announces it to sockets and pollers.
@router.post(
    "/send/{chat_id}",
    status_code=status.HTTP_201_CREATED,
    response_model=None,
    responses=_common_responses(),
)
async def send(
        chat_id: int,
        message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
        user: SUser = Depends(get_verificated_user),
        uow=Depends(UnitOfWork),
):
    await auth_service.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    await polling_manager.prepare(uow=uow, user_id=user.id, chat_id=chat_id, message=message)