import asyncio
import logging
from collections import defaultdict

import websockets
from fastapi import Depends, HTTPException, WebSocket, WebSocketDisconnect, status

from app.chat.exceptions import (
    MessageAlreadyPinnedException,
    MessageNotFoundException,
    UserDontHavePermissionException,
    UseWSException,
)
from app.chat.router import router
from app.chat.shemas import (
    Responses,
    SDeleteMessage,
    SEditMessage,
    SPinMessage,
    SSendMessage,
    SUnpinMessage,
)
from app.dependencies import (
    get_current_user_ws,
    get_subprotocol_ws,
    get_token,
    get_verificated_user,
)
from app.exceptions import IncorrectDataException
from app.services import auth_service, message_service
from app.users.schemas import SUser
from app.utils.unit_of_work import UnitOfWork


def _common_error_responses() -> dict:
    """Build the error-response table shared by the HTTP chat routes.

    A fresh dict is returned on every call so each route gets its own copy.
    """
    return {
        status.HTTP_401_UNAUTHORIZED: {
            "model": Responses.STokenMissingException,
            "description": "Токен отсутствует",
        },
        status.HTTP_403_FORBIDDEN: {
            "model": Responses.SNotAuthenticated,
            "description": "Not authenticated",
        },
        status.HTTP_404_NOT_FOUND: {
            "model": Responses.SUserNotFoundException,
            "description": "Юзер не найден",
        },
        status.HTTP_409_CONFLICT: {
            "model": Responses.SUserMustConfirmEmailException,
            "description": "Сначала подтвердите почту",
        },
        status.HTTP_500_INTERNAL_SERVER_ERROR: {
            "model": Responses.SBlackPhoenixException,
            "description": "Внутренняя ошибка сервера",
        },
    }


class ConnectionManager:
    """Track open chat WebSockets and fan message events out to them."""

    def __init__(self):
        # chat_id -> sockets currently registered for that chat.
        self.active_connections: dict[int, list[WebSocket]] = defaultdict(list)
        # Value of the incoming "flag" key -> coroutine that applies the action
        # and returns the JSON-ready event to broadcast.
        self.message_methods = {
            "send": self._send,
            "delete": self._delete,
            "edit": self._edit,
            "pin": self._pin,
            "unpin": self._unpin,
        }

    async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None) -> None:
        """Accept the handshake and register the socket under ``chat_id``."""
        await websocket.accept(subprotocol=subprotocol)
        self.active_connections[chat_id].append(websocket)

    async def disconnect(
        self,
        chat_id: int,
        websocket: WebSocket,
        code_and_reason: tuple[int, str] | None = None,
    ) -> None:
        """Unregister ``websocket`` and optionally close it.

        Args:
            chat_id: Chat the socket was registered under.
            websocket: Socket to forget.
            code_and_reason: When given, the socket is closed with this
                ``(code, reason)`` pair; otherwise it is only unregistered.
        """
        # The socket may never have been registered: ``websocket_endpoint``
        # calls this from its error handlers even when a check failed before
        # ``connect``. An unconditional ``list.remove`` would raise ValueError.
        connections = self.active_connections.get(chat_id)
        if connections is not None:
            if websocket in connections:
                connections.remove(websocket)
            if not connections:
                # Do not keep empty per-chat lists around forever.
                del self.active_connections[chat_id]

        if code_and_reason:
            code, reason = code_and_reason
            await websocket.close(code=code, reason=reason)

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> None:
        """Apply a client action and push the resulting event to the chat.

        The handler is selected by ``message["flag"]``. The event it returns is
        sent to every registered socket of the chat and then handed to
        ``polling_manager`` for long-polling clients.

        Raises:
            IncorrectDataException: A ``KeyError`` occurred while selecting or
                running the handler (e.g. missing or unknown ``"flag"``).
        """
        try:
            handler = self.message_methods[message["flag"]]
            new_message = await handler(uow, user_id, chat_id, message)
        except (MessageNotFoundException, MessageAlreadyPinnedException):
            # Deliberate best-effort: nothing changed, so nothing to broadcast.
            return
        except KeyError:
            raise IncorrectDataException

        # Iterate over a snapshot: the list can change while we await a send.
        for websocket in list(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(new_message)
            except (WebSocketDisconnect, RuntimeError):
                # A dead peer must not abort delivery to the others or bubble
                # up into the sender's endpoint; just forget that peer.
                await self.disconnect(chat_id, websocket)

        await polling_manager.send(chat_id, new_message)

    @staticmethod
    def _serialize_message(model, flag: str) -> dict:
        """Dump a message model to a JSON-ready dict tagged with ``flag``."""
        payload = model.model_dump()
        payload["id"] = str(payload["id"])
        payload["answer_id"] = str(payload["answer_id"]) if payload["answer_id"] else None
        payload["created_at"] = payload["created_at"].isoformat()
        payload["flag"] = flag
        return payload

    @staticmethod
    async def _send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
        """Validate and store a new message; return its ``"send"`` event."""
        validated = SSendMessage.model_validate(message)
        stored = await message_service.send_message(
            uow=uow,
            user_id=user_id,
            chat_id=chat_id,
            message=validated,
        )
        return ConnectionManager._serialize_message(stored, "send")

    @staticmethod
    async def _delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Delete a message and return its ``"delete"`` event.

        Raises:
            UserDontHavePermissionException: ``user_id`` in the payload differs
                from the requesting user.
        """
        validated = SDeleteMessage.model_validate(message)
        if validated.user_id != user_id:
            raise UserDontHavePermissionException
        await message_service.delete_message(uow=uow, message_id=validated.id)
        return {"id": str(validated.id), "flag": "delete"}

    @staticmethod
    async def _edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Edit a message and return its ``"edit"`` event.

        Raises:
            UserDontHavePermissionException: ``user_id`` in the payload differs
                from the requesting user.
        """
        validated = SEditMessage.model_validate(message)
        if validated.user_id != user_id:
            raise UserDontHavePermissionException
        await message_service.edit_message(
            uow=uow,
            message_id=validated.id,
            new_message=validated.new_message,
            new_image_url=validated.new_image_url,
        )
        return {
            "flag": "edit",
            "id": str(validated.id),
            "new_message": validated.new_message,
            "new_image_url": validated.new_image_url,
        }

    @staticmethod
    async def _pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Pin a message in the chat and return its ``"pin"`` event."""
        validated = SPinMessage.model_validate(message)
        # NOTE(review): the user id comes from the payload, not from the
        # authenticated user (unlike _delete/_edit) - confirm this is intended.
        pinned = await message_service.pin_message(
            uow=uow,
            chat_id=chat_id,
            user_id=validated.user_id,
            message_id=validated.id,
        )
        return ConnectionManager._serialize_message(pinned, "pin")

    @staticmethod
    async def _unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Unpin a message in the chat and return its ``"unpin"`` event."""
        validated = SUnpinMessage.model_validate(message)
        await message_service.unpin_message(uow=uow, chat_id=chat_id, message_id=validated.id)
        return {"flag": "unpin", "id": str(validated.id)}


manager = ConnectionManager()


@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
    chat_id: int,
    websocket: WebSocket,
    user: SUser = Depends(get_current_user_ws),
    subprotocol: str = Depends(get_subprotocol_ws),
    uow=Depends(UnitOfWork),
):
    """Serve one chat WebSocket: authorize, register, then relay client actions.

    Every received JSON payload is passed to ``manager.broadcast``. An
    ``HTTPException`` closes the socket with 1011 for status 500 and 1003
    otherwise; any other exception is logged and closes it with 1011.
    """
    try:
        await auth_service.check_verificated_user(uow=uow, user_id=user.id)
        await auth_service.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
        await manager.connect(chat_id, websocket, subprotocol)
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        await manager.disconnect(chat_id, websocket)
    except HTTPException as e:
        if e.status_code == 500:
            code = status.WS_1011_INTERNAL_ERROR
        else:
            code = status.WS_1003_UNSUPPORTED_DATA
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=(code, e.detail))
    except Exception as e:
        # Keep the traceback; the bare message alone is rarely enough to debug.
        logging.warning(e, exc_info=True)
        code_and_reason = (status.WS_1011_INTERNAL_ERROR, "Internal Server Error")
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=code_and_reason)


# HTTP twin of the WebSocket route: it always raises UseWSException.
# (Plain comments are used instead of a docstring so the generated API
# description of the route is not changed.)
@router.post(
    "/ws/{chat_id}",
    responses=_common_error_responses(),
)
async def chat_ws(
    chat_id: int,
    token: str = Depends(get_token),
):
    raise UseWSException  # noqa
    # Unreachable: the lines below never run because of the raise above. They
    # are kept as they were in the original file.
    url = f"ws://localhost:8000/api/chat/ws/{chat_id}"
    async with websockets.connect(url, extra_headers={"Authorization": f"Bearer {token}"}) as websocket:
        print(await websocket.recv())


class PollingManager:
    """Deliver chat events to long-polling HTTP clients."""

    def __init__(self):
        # chat_id -> futures of requests currently blocked in ``poll``.
        self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list)
        # Retained for backward compatibility only: events now travel as the
        # result of each waiter future instead of through this buffer.
        self.messages: dict[int, list[dict]] = defaultdict(list)

    async def poll(self, chat_id: int) -> dict:
        """Block until the next event for ``chat_id`` and return it.

        Raises:
            HTTPException: 408 when the waiting request is cancelled.
        """
        future = asyncio.get_running_loop().create_future()
        self.waiters[chat_id].append(future)
        try:
            # ``send`` stores the event as the future's result, so the value
            # cannot be swapped or removed between wake-up and read.
            return await future
        except asyncio.CancelledError:
            # ``send`` may already have taken this future out of the list, so
            # only remove it when it is still registered.
            pending = self.waiters.get(chat_id)
            if pending is not None and future in pending:
                pending.remove(future)
                if not pending:
                    del self.waiters[chat_id]
            raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Client disconnected")

    async def prepare(
        self,
        uow: UnitOfWork,
        user_id: int,
        chat_id: int,
        message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    ) -> None:
        """Apply an action received over HTTP and publish the resulting event.

        ``manager.broadcast`` already forwards the processed event to this
        manager's ``send``; the raw request payload is therefore not sent a
        second time, which previously could reach pollers as a duplicate.
        """
        payload = message.model_dump()
        await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=payload)

    async def send(
        self,
        chat_id: int,
        message: dict,
    ) -> None:
        """Wake every request currently polling ``chat_id`` with ``message``.

        If nobody is polling, the event is dropped.
        """
        # Detach the whole waiter list at once; pollers arriving later wait
        # for the next event.
        pending = self.waiters.pop(chat_id, [])
        for waiter in pending:
            if not waiter.done():  # skip futures that were cancelled meanwhile
                waiter.set_result(message)
        # Yield once so the woken pollers can run before the caller continues.
        await asyncio.sleep(0)


polling_manager = PollingManager()


# Long-poll route: checks chat access, then waits for the next chat event.
@router.get(
    "/poll/{chat_id}",
    status_code=status.HTTP_200_OK,
    response_model=None,
    responses=_common_error_responses(),
)
async def poll(chat_id: int, user: SUser = Depends(get_verificated_user), uow=Depends(UnitOfWork)):
    await auth_service.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    return await polling_manager.poll(chat_id)


# HTTP route for chat actions: checks chat access, then applies and publishes
# the action through the polling manager.
@router.post(
    "/send/{chat_id}",
    status_code=status.HTTP_201_CREATED,
    response_model=None,
    responses=_common_error_responses(),
)
async def send(
    chat_id: int,
    message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    user: SUser = Depends(get_verificated_user),
    uow=Depends(UnitOfWork),
):
    await auth_service.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    await polling_manager.prepare(uow=uow, user_id=user.id, chat_id=chat_id, message=message)