import asyncio
import logging
from collections import defaultdict

import websockets
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException, status

from app.chat.exceptions import UseWSException, MessageNotFoundException, MessageAlreadyPinnedException
from app.exceptions import IncorrectDataException
from app.chat.exceptions import UserDontHavePermissionException
from app.services.message_service import MessageService
from app.utils.unit_of_work import UnitOfWork
from app.services.auth_service import AuthService
from app.chat.router import router
from app.chat.shemas import SSendMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage, Responses
from app.dependencies import get_current_user_ws, get_token, get_subprotocol_ws, get_verificated_user
from app.users.schemas import SUser

# Error responses advertised by every HTTP route in this module. The three
# routes previously repeated this mapping verbatim; it is declared once here.
_AUTH_ERROR_RESPONSES = {
    status.HTTP_401_UNAUTHORIZED: {
        "model": Responses.STokenMissingException,
        "description": "Токен отсутствует",
    },
    status.HTTP_403_FORBIDDEN: {
        "model": Responses.SNotAuthenticated,
        "description": "Not authenticated",
    },
    status.HTTP_404_NOT_FOUND: {
        "model": Responses.SUserNotFoundException,
        "description": "Юзер не найден",
    },
    status.HTTP_409_CONFLICT: {
        "model": Responses.SUserMustConfirmEmailException,
        "description": "Сначала подтвердите почту",
    },
    status.HTTP_500_INTERNAL_SERVER_ERROR: {
        "model": Responses.SBlackPhoenixException,
        "description": "Внутренняя ошибка сервера",
    },
}


class ConnectionManager:
    """Keep track of open WebSocket connections per chat and fan events out to them."""

    def __init__(self):
        # chat_id -> sockets currently registered for that chat.
        self.active_connections: dict[int, list[WebSocket]] = defaultdict(list)
        # Dispatch table keyed by the "flag" field of an incoming payload.
        self.message_methods = {
            "send": self._send,
            "delete": self._delete,
            "edit": self._edit,
            "pin": self._pin,
            "unpin": self._unpin,
        }

    async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None) -> None:
        """Accept the handshake and register ``websocket`` under ``chat_id``."""
        await websocket.accept(subprotocol=subprotocol)
        self.active_connections[chat_id].append(websocket)

    async def disconnect(
        self, chat_id: int, websocket: WebSocket, code_and_reason: tuple[int, str] | None = None
    ) -> None:
        """Unregister ``websocket`` and optionally close it.

        Args:
            chat_id: Chat the socket was registered under.
            websocket: Socket to unregister.
            code_and_reason: When given, the socket is closed with this
                ``(code, reason)`` pair; when omitted the socket is only
                unregistered.
        """
        # The endpoint calls this from its error handlers, which can run before
        # ``connect`` has registered the socket (e.g. the access checks raised).
        # An unconditional ``list.remove`` would raise ValueError there and the
        # close frame below would never be sent.
        registered = self.active_connections.get(chat_id)
        if registered and websocket in registered:
            registered.remove(websocket)
        if code_and_reason:
            close_code, close_reason = code_and_reason[0], code_and_reason[1]
            await websocket.close(code=close_code, reason=close_reason)

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> None:
        """Apply the action described by ``message`` and push the result to all listeners.

        The handler is chosen by ``message["flag"]``. Its result is sent as JSON
        to every socket registered for ``chat_id`` and then handed to
        ``polling_manager`` for long-polling clients.

        Raises:
            IncorrectDataException: if a ``KeyError`` occurs anywhere in the
                processing (this includes a missing or unknown ``"flag"``).

        ``MessageNotFoundException`` and ``MessageAlreadyPinnedException`` are
        swallowed: in that case nothing is delivered to anyone.
        """
        try:
            handler = self.message_methods[message["flag"]]
            new_message = await handler(uow, user_id, chat_id, message)
            # Iterate over a snapshot: ``send_json`` yields to the event loop,
            # and a concurrent ``disconnect`` mutating the live list would make
            # the loop skip recipients.
            for websocket in list(self.active_connections[chat_id]):
                await websocket.send_json(new_message)
            await polling_manager.send(chat_id, new_message)
        except (MessageNotFoundException, MessageAlreadyPinnedException):
            pass
        except KeyError:
            raise IncorrectDataException

    @staticmethod
    def _serialize(stored_message, flag: str) -> dict:
        """Dump a message model to a JSON-friendly dict tagged with ``flag``.

        ``id`` and a truthy ``answer_id`` become strings, a falsy ``answer_id``
        becomes ``None`` and ``created_at`` is rendered with ``isoformat()``.
        """
        data = stored_message.model_dump()
        data["id"] = str(data["id"])
        data["answer_id"] = str(data["answer_id"]) if data["answer_id"] else None
        data["created_at"] = data["created_at"].isoformat()
        data["flag"] = flag
        return data

    @staticmethod
    async def _send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
        """Validate a "send" payload, store the message and return its wire form."""
        payload = SSendMessage.model_validate(message)
        stored = await MessageService.send_message(
            uow=uow, user_id=user_id, chat_id=chat_id, message=payload, image_url=payload.image_url
        )
        return ConnectionManager._serialize(stored, "send")

    @staticmethod
    async def _delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Validate a "delete" payload, delete the message and return the event.

        Raises:
            UserDontHavePermissionException: if the payload's ``user_id`` is not
                the acting user.
        """
        payload = SDeleteMessage.model_validate(message)
        # NOTE(review): ``payload.user_id`` comes from the client; this compares
        # it to the authenticated user but does not look up the stored
        # message's author here. Confirm ownership is enforced elsewhere.
        if payload.user_id != user_id:
            raise UserDontHavePermissionException
        await MessageService.delete_message(uow=uow, message_id=payload.id)
        return {"id": str(payload.id), "flag": "delete"}

    @staticmethod
    async def _edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Validate an "edit" payload, apply the edit and return the event.

        Raises:
            UserDontHavePermissionException: if the payload's ``user_id`` is not
                the acting user.
        """
        payload = SEditMessage.model_validate(message)
        # NOTE(review): same client-supplied ``user_id`` caveat as in ``_delete``.
        if payload.user_id != user_id:
            raise UserDontHavePermissionException
        await MessageService.edit_message(
            uow=uow, message_id=payload.id, new_message=payload.new_message, new_image_url=payload.new_image_url
        )
        return {
            "flag": "edit",
            "id": str(payload.id),
            "new_message": payload.new_message,
            "new_image_url": payload.new_image_url,
        }

    @staticmethod
    async def _pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Validate a "pin" payload, pin the message and return its wire form."""
        payload = SPinMessage.model_validate(message)
        pinned_message = await MessageService.pin_message(
            uow=uow, chat_id=chat_id, user_id=payload.user_id, message_id=payload.id
        )
        new_message = ConnectionManager._serialize(pinned_message, "pin")
        logging.warning(new_message)
        return new_message

    @staticmethod
    async def _unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Validate an "unpin" payload, unpin the message and return the event."""
        payload = SUnpinMessage.model_validate(message)
        await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=payload.id)
        return {"flag": "unpin", "id": str(payload.id)}


manager = ConnectionManager()


# WebSocket entry point for a chat: runs the access checks, registers the
# socket, then relays every received JSON payload through ``manager.broadcast``
# until the client disconnects or an error occurs.
@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
    chat_id: int,
    websocket: WebSocket,
    user: SUser = Depends(get_current_user_ws),
    subprotocol: str = Depends(get_subprotocol_ws),
    uow=Depends(UnitOfWork),
):
    try:
        await AuthService.check_verificated_user(uow=uow, user_id=user.id)
        await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
        await manager.connect(chat_id, websocket, subprotocol)
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        await manager.disconnect(chat_id, websocket)
    except HTTPException as e:
        # Map HTTP-style errors onto WebSocket close codes: 500 -> 1011,
        # anything else -> 1003.
        code = status.WS_1011_INTERNAL_ERROR if e.status_code == 500 else status.WS_1003_UNSUPPORTED_DATA
        reason = e.detail
        code_and_reason = (code, reason)
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=code_and_reason)
    except Exception as e:
        logging.warning(e)
        code_and_reason = (status.WS_1011_INTERNAL_ERROR, "Internal Server Error")
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=code_and_reason)


# POST twin of the WebSocket route. It always raises ``UseWSException``.
# NOTE(review): presumably it exists so the path shows up in the generated API
# docs - confirm.
@router.post(
    "/ws/{chat_id}",
    responses=_AUTH_ERROR_RESPONSES,
)
async def chat_ws(
    chat_id: int,
    token: str = Depends(get_token),
):
    raise UseWSException  # noqa
    # NOTE(review): everything below is unreachable because of the raise above.
    url = f"ws://localhost:8000/api/chat/ws/{chat_id}"
    async with websockets.connect(url, extra_headers={"Authorization": f"Bearer {token}"}) as websocket:
        print(await websocket.recv())


class PollingManager:
    """Hand chat events to long-polling HTTP clients."""

    def __init__(self):
        # chat_id -> futures of requests currently parked in ``poll``.
        self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list)
        # chat_id -> events currently being handed to woken pollers.
        self.messages: dict[int, list[dict]] = defaultdict(list)

    async def poll(self, chat_id: int) -> dict:
        """Wait until ``send`` publishes an event for ``chat_id`` and return it.

        Raises:
            HTTPException: 408 when the waiting task is cancelled.
        """
        future = asyncio.Future()
        self.waiters[chat_id].append(future)
        try:
            await future
            message = self.messages[chat_id][0]
            return message
        except asyncio.CancelledError:
            # ``send`` may already have popped this future before the
            # cancellation was delivered; an unconditional ``remove`` would then
            # raise ValueError instead of the intended HTTP error.
            if future in self.waiters[chat_id]:
                self.waiters[chat_id].remove(future)
            raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Client disconnected")

    async def prepare(
        self,
        uow: UnitOfWork,
        user_id: int,
        chat_id: int,
        message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    ) -> None:
        """Process an event submitted over HTTP and publish it to all listeners."""
        message = message.model_dump()
        await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=message)
        # NOTE(review): ``broadcast`` already forwards its processed result to
        # ``self.send``; this second call publishes the raw request payload, and
        # it also runs when ``broadcast`` swallowed a not-found/already-pinned
        # error. Confirm whether this call is still intended.
        await self.send(chat_id=chat_id, message=message)

    async def send(
        self,
        chat_id: int,
        message: dict,
    ) -> None:
        """Publish ``message`` to every request currently parked in ``poll``."""
        self.messages[chat_id].append(message)
        while self.waiters[chat_id]:
            waiter = self.waiters[chat_id].pop(0)
            if not waiter.done():
                waiter.set_result(None)
        # Yield once so the woken ``poll`` calls can read the head of the queue
        # before it is dropped below.
        await asyncio.sleep(0)
        self.messages[chat_id].pop(0)


polling_manager = PollingManager()


# Long-polling endpoint: checks chat access, then blocks until the next event
# for the chat is published and returns it.
@router.get(
    "/poll/{chat_id}",
    status_code=status.HTTP_200_OK,
    response_model=None,
    responses=_AUTH_ERROR_RESPONSES,
)
async def poll(chat_id: int, user: SUser = Depends(get_verificated_user), uow=Depends(UnitOfWork)):
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    return await polling_manager.poll(chat_id)


# HTTP counterpart of sending a payload over the WebSocket: checks chat access
# and hands the validated body to ``polling_manager.prepare``.
@router.post(
    "/send/{chat_id}",
    status_code=status.HTTP_201_CREATED,
    response_model=None,
    responses=_AUTH_ERROR_RESPONSES,
)
async def send(
    chat_id: int,
    message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    user: SUser = Depends(get_verificated_user),
    uow=Depends(UnitOfWork),
):
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    await polling_manager.prepare(uow=uow, user_id=user.id, chat_id=chat_id, message=message)