import websockets
from fastapi import WebSocket, WebSocketDisconnect, Depends

from app.exceptions import IncorrectDataException, UserDontHavePermissionException
from app.services.message_service import MessageService
from app.unit_of_work import UnitOfWork
from app.utils.auth import AuthService
from app.chat.router import router
from app.chat.shemas import SSendMessage, SMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage
from app.users.dependencies import get_current_user_ws, get_token
from app.users.schemas import SUser


class ConnectionManager:
    """Track open chat websockets and fan out message events to every member of a chat.

    A client event is a dict whose ``"flag"`` key selects a handler from
    ``message_methods`` (``send`` / ``delete`` / ``edit`` / ``pin`` / ``unpin``).
    The handler persists the change through ``MessageService`` and returns a
    JSON-serialisable dict, which ``broadcast`` then sends to every websocket
    registered for the chat.
    """

    def __init__(self):
        # chat_id -> websockets currently connected to that chat.
        self.active_connections: dict[int, list[WebSocket]] = {}
        # Dispatch table: value of the payload's "flag" key -> handler coroutine.
        self.message_methods = {
            "send": self.send,
            "delete": self.delete,
            "edit": self.edit,
            "pin": self.pin,
            "unpin": self.unpin,
        }

    async def connect(self, chat_id: int, websocket: WebSocket):
        """Accept the websocket handshake and register ``websocket`` under ``chat_id``."""
        await websocket.accept()
        self.active_connections.setdefault(chat_id, []).append(websocket)

    def disconnect(self, chat_id: int, websocket: WebSocket):
        """Unregister ``websocket`` from ``chat_id``.

        Idempotent: the socket may already have been dropped by ``broadcast``
        after a failed send, so an unknown chat or socket is ignored. A chat
        whose last socket leaves is removed from ``active_connections`` so the
        mapping does not keep an entry for every chat ever opened.
        """
        connections = self.active_connections.get(chat_id)
        if connections is None:
            return
        # Compare by identity: a given socket object is only registered once.
        remaining = [ws for ws in connections if ws is not websocket]
        if remaining:
            self.active_connections[chat_id] = remaining
        else:
            del self.active_connections[chat_id]

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict):
        """Apply one client event and push the resulting payload to everyone in the chat.

        Args:
            uow: unit of work passed through to the selected handler.
            user_id: id of the authenticated user who sent the event.
            chat_id: chat the event belongs to.
            message: decoded client JSON; must be a dict with a ``"flag"`` key.

        Raises:
            IncorrectDataException: ``message`` is not a dict, or its ``"flag"``
                is missing, unhashable or not one of ``message_methods``.
            UserDontHavePermissionException: propagated from ``delete`` / ``edit``.
        """
        if not isinstance(message, dict):
            raise IncorrectDataException
        # Guard only the dispatch lookup: a KeyError raised deeper (handler,
        # service layer) is a server bug and must not be reported as bad input.
        try:
            handler = self.message_methods[message["flag"]]
        except (KeyError, TypeError) as err:
            raise IncorrectDataException from err

        new_message = await handler(uow, user_id, chat_id, message)

        # Iterate over a snapshot: other connections may call disconnect() while
        # this coroutine is suspended inside send_json().
        for websocket in list(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(new_message)
            except (WebSocketDisconnect, RuntimeError):
                # That peer is gone. Drop it instead of failing the sender's
                # request, or disconnecting the sender by mistake.
                self.disconnect(chat_id, websocket)

    @staticmethod
    def _serialize_message(message: SMessage, flag: str) -> dict:
        """Dump ``message`` to a JSON-ready dict tagged with ``flag``.

        ``created_at`` is converted with ``isoformat()`` because ``send_json``
        cannot encode datetime objects.
        """
        payload = message.model_dump()
        payload["created_at"] = payload["created_at"].isoformat()
        payload["flag"] = flag
        return payload

    async def send(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
        """Persist a new message (optionally as an answer) and return its payload with ``flag="send"``."""
        message = SSendMessage.model_validate(message)
        new_message = await self.add_message_to_database(uow=uow, user_id=user_id, chat_id=chat_id, message=message)
        return self._serialize_message(new_message, "send")

    async def delete(self, uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Delete a message and return ``{"deleted_message", "id", "flag": "delete"}``.

        Raises:
            UserDontHavePermissionException: the payload's ``user_id`` is not the caller.
        """
        message = SDeleteMessage.model_validate(message)
        # NOTE(review): message.user_id comes from the client payload, so a client can
        # set it to its own id and pass this check for any message; chat_id is also
        # ignored. The owner/chat should be loaded from the DB instead. TODO confirm
        # the MessageService API for that.
        if message.user_id != user_id:
            raise UserDontHavePermissionException
        deleted_message = await self.delete_message(uow=uow, message_id=message.id)
        return {"deleted_message": deleted_message, "id": message.id, "flag": "delete"}

    async def edit(self, uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Edit a message's text/image and return the edit event payload with ``flag="edit"``.

        Raises:
            UserDontHavePermissionException: the payload's ``user_id`` is not the caller.
        """
        message = SEditMessage.model_validate(message)
        # NOTE(review): same client-trusted ownership check as delete(); TODO verify
        # ownership and chat membership server-side.
        if message.user_id != user_id:
            raise UserDontHavePermissionException
        edited_message = await self.edit_message(
            uow=uow, message_id=message.id, new_message=message.new_message, image_url=message.new_image_url
        )
        return {
            "flag": "edit",
            "id": message.id,
            "edited_message": edited_message,
            "new_message": message.new_message,
            "new_image_url": message.new_image_url,
        }

    async def pin(self, uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Pin a message in ``chat_id`` and return the pinned message payload with ``flag="pin"``."""
        message = SPinMessage.model_validate(message)
        # NOTE(review): user_id is taken from the client payload, not from the
        # authenticated user; confirm whether that is intended.
        pinned_message = await self.pin_message(
            uow=uow, chat_id=chat_id, user_id=message.user_id, message_id=message.id
        )
        return self._serialize_message(pinned_message, "pin")

    async def unpin(self, uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Unpin a message in ``chat_id`` and return ``{"flag": "unpin", "id": <unpinned id>}``."""
        message = SUnpinMessage.model_validate(message)
        unpinned_message_id = await self.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id)
        return {"flag": "unpin", "id": unpinned_message_id}

    @staticmethod
    async def add_message_to_database(uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage) -> SMessage:
        """Store the message; if ``message.answer`` is set, link it and return the updated message."""
        new_message = await MessageService.send_message(
            uow=uow, user_id=user_id, chat_id=chat_id, message=message.message, image_url=message.image_url
        )
        if message.answer:
            new_message = await MessageService.add_answer(uow=uow, self_id=new_message.id, answer_id=message.answer)
        return new_message

    @staticmethod
    async def delete_message(uow: UnitOfWork, message_id: int) -> bool:
        """Delegate deletion to ``MessageService.delete_message`` and return its result."""
        return await MessageService.delete_message(uow=uow, message_id=message_id)

    @staticmethod
    async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, image_url: str) -> bool:
        """Delegate editing to ``MessageService.edit_message`` and return its result."""
        return await MessageService.edit_message(
            uow=uow, message_id=message_id, new_message=new_message, new_image_url=image_url
        )

    @staticmethod
    async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage:
        """Delegate pinning to ``MessageService.pin_message`` and return the pinned message."""
        return await MessageService.pin_message(uow=uow, chat_id=chat_id, user_id=user_id, message_id=message_id)

    @staticmethod
    async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> int:
        """Delegate unpinning to ``MessageService.unpin_message`` and return the unpinned message id."""
        return await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message_id)


# Single process-wide registry shared by every websocket connection handled here.
manager = ConnectionManager()


@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
    chat_id: int, websocket: WebSocket, user: SUser = Depends(get_current_user_ws), uow=Depends(UnitOfWork)
):
    """Chat websocket: authorize the user, then relay each received JSON event to the chat.

    The user must be verified and have access to ``chat_id`` (checked before the
    handshake is accepted). Each incoming JSON frame is passed to
    ``manager.broadcast``; exceptions it raises propagate and end the connection.
    """
    await AuthService.check_verificated_user_with_exc(uow=uow, user_id=user.id)
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    await manager.connect(chat_id, websocket)
    # The previous debug frame echoing scope/cookies/headers was removed: it exposed
    # cookies (including HttpOnly auth cookies) and headers to client-side script.
    try:
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in `finally`.
        pass
    finally:
        # Always unregister, including when broadcast raises (bad payload, permission
        # error, DB failure); otherwise the dead socket stays in the chat's list.
        manager.disconnect(chat_id, websocket)


# NOTE: disabled manual-test client kept for reference; in this file `websockets`
# and `get_token` are imported only for it.
#
# @router.post(
#     "/ws/{chat_id}"
# )
# async def websocket_endpoint(
#         chat_id: int,
#         token: str = Depends(get_token),
# ):
#     url = f"ws://localhost:8000/api/chat/ws/{chat_id}"
#     async with websockets.connect(url, extra_headers={"Authorization": f"Bearer {token}"}) as websocket:
#         print(await websocket.recv())