chat_back/app/chat/websocket.py

158 lines
5.6 KiB
Python

import websockets
from fastapi import WebSocket, WebSocketDisconnect, Depends
from app.chat.exceptions import UseWSException
from app.exceptions import IncorrectDataException
from app.chat.exceptions import UserDontHavePermissionException
from app.services.message_service import MessageService
from app.unit_of_work import UnitOfWork
from app.utils.auth import AuthService
from app.chat.router import router
from app.chat.shemas import SSendMessage, SMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage
from app.users.dependencies import get_current_user_ws, get_token, get_subprotocol_ws
from app.users.schemas import SUser
class ConnectionManager:
    """Track open chat websockets per chat and fan out message events to them.

    Incoming events are dicts whose ``"flag"`` key selects a handler
    (``"send"``, ``"delete"``, ``"edit"``, ``"pin"``, ``"unpin"``). Each handler
    persists the change through ``MessageService`` and returns a dict that is
    then sent with ``send_json`` to every websocket registered for the chat.
    """

    def __init__(self):
        # chat_id -> websockets currently connected to that chat. A chat's key
        # is removed once its last websocket disconnects.
        self.active_connections: dict[int, list[WebSocket]] = {}
        # Dispatch table: value of the incoming "flag" key -> handler coroutine.
        self.message_methods = {
            "send": self.send,
            "delete": self.delete,
            "edit": self.edit,
            "pin": self.pin,
            "unpin": self.unpin,
        }

    async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None):
        """Accept the handshake with *subprotocol* and register *websocket* for *chat_id*."""
        await websocket.accept(subprotocol=subprotocol)
        self.active_connections.setdefault(chat_id, []).append(websocket)

    def disconnect(self, chat_id: int, websocket: WebSocket):
        """Unregister *websocket* from *chat_id*.

        Safe to call repeatedly or for an unknown chat/socket. The chat's entry
        is dropped when no connections remain, so the registry does not keep
        accumulating empty lists.
        """
        connections = self.active_connections.get(chat_id)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            del self.active_connections[chat_id]

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict):
        """Apply the event in *message* and send the result to every socket in *chat_id*.

        Args:
            uow: unit of work passed through to the handler.
            user_id: id of the authenticated sender.
            chat_id: chat whose connected sockets receive the result.
            message: decoded JSON frame; must be a dict with a known ``"flag"``.

        Raises:
            IncorrectDataException: *message* is not subscriptable by ``"flag"``,
                lacks it, or names an unknown flag.
        """
        # Keep the try narrow: only a malformed/unknown "flag" is bad client
        # input. KeyErrors raised inside handlers or the service layer are
        # real errors and must not be disguised as IncorrectDataException.
        try:
            handler = self.message_methods[message["flag"]]
        except (KeyError, TypeError) as err:  # TypeError: non-dict frame or unhashable flag
            raise IncorrectDataException from err
        new_message = await handler(uow, user_id, chat_id, message)
        # Iterate a snapshot: each send awaits, and another connection may call
        # disconnect() meanwhile, which would otherwise mutate the list under
        # the loop and skip recipients.
        for websocket in list(self.active_connections.get(chat_id, ())):
            await websocket.send_json(new_message)

    async def send(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
        """Validate a new message, persist it, and return it as a ``"send"`` event."""
        message = SSendMessage.model_validate(message)
        new_message = await self.add_message_to_database(uow=uow, user_id=user_id, chat_id=chat_id, message=message)
        new_message = new_message.model_dump()
        # datetime is not JSON-serialisable by send_json; emit ISO-8601 text.
        new_message["created_at"] = new_message["created_at"].isoformat()
        new_message["flag"] = "send"
        return new_message

    async def delete(self, uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Delete a message and return a ``"delete"`` event carrying its id.

        Raises:
            UserDontHavePermissionException: payload ``user_id`` differs from the sender.
        """
        message = SDeleteMessage.model_validate(message)
        # NOTE(review): message.user_id comes from the client payload, so a
        # client can claim any owner id. Ownership should be checked against
        # the stored message server-side — confirm what MessageService offers.
        if message.user_id != user_id:
            raise UserDontHavePermissionException
        await self.delete_message(uow=uow, message_id=message.id)
        new_message = {"id": message.id, "flag": "delete"}
        return new_message

    async def edit(self, uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Edit a message's text/image and return an ``"edit"`` event.

        Raises:
            UserDontHavePermissionException: payload ``user_id`` differs from the sender.
        """
        message = SEditMessage.model_validate(message)
        # NOTE(review): same client-supplied user_id concern as in delete().
        if message.user_id != user_id:
            raise UserDontHavePermissionException
        await self.edit_message(
            uow=uow, message_id=message.id, new_message=message.new_message, image_url=message.new_image_url
        )
        new_message = {
            "flag": "edit",
            "id": message.id,
            "new_message": message.new_message,
            "new_image_url": message.new_image_url,
        }
        return new_message

    async def pin(self, uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Pin a message in *chat_id* and return the pinned message as a ``"pin"`` event."""
        message = SPinMessage.model_validate(message)
        # NOTE(review): user_id here is taken from the payload, not from the
        # authenticated sender (the second argument is ignored) — confirm intended.
        pinned_message = await self.pin_message(
            uow=uow, chat_id=chat_id, user_id=message.user_id, message_id=message.id
        )
        new_message = pinned_message.model_dump()
        new_message["created_at"] = new_message["created_at"].isoformat()
        new_message["flag"] = "pin"
        return new_message

    async def unpin(self, uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Unpin a message in *chat_id* and return an ``"unpin"`` event carrying its id."""
        message = SUnpinMessage.model_validate(message)
        await self.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id)
        new_message = {"flag": "unpin", "id": message.id}
        return new_message

    @staticmethod
    async def add_message_to_database(uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage) -> SMessage:
        """Persist *message* via ``MessageService.send_message`` and return the stored message."""
        new_message = await MessageService.send_message(
            uow=uow, user_id=user_id, chat_id=chat_id, message=message, image_url=message.image_url
        )
        return new_message

    @staticmethod
    async def delete_message(uow: UnitOfWork, message_id: int) -> None:
        """Delete *message_id* via ``MessageService.delete_message``."""
        await MessageService.delete_message(uow=uow, message_id=message_id)

    @staticmethod
    async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, image_url: str) -> None:
        """Update text and image of *message_id* via ``MessageService.edit_message``."""
        await MessageService.edit_message(
            uow=uow, message_id=message_id, new_message=new_message, new_image_url=image_url
        )

    @staticmethod
    async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage:
        """Pin *message_id* in *chat_id* via ``MessageService.pin_message`` and return it."""
        pinned_message = await MessageService.pin_message(uow=uow, chat_id=chat_id, user_id=user_id, message_id=message_id)
        return pinned_message

    @staticmethod
    async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> None:
        """Unpin *message_id* in *chat_id* via ``MessageService.unpin_message``."""
        await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message_id)
# Module-level singleton shared by every websocket handled in this process.
# The registry lives in process memory, so broadcasts only reach sockets
# registered with this same instance.
manager: ConnectionManager = ConnectionManager()
@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
    chat_id: int,
    websocket: WebSocket,
    user: SUser = Depends(get_current_user_ws),
    subprotocol: str = Depends(get_subprotocol_ws),
    uow=Depends(UnitOfWork),
):
    """Serve the chat websocket for *chat_id*.

    The AuthService verification and chat-access checks run before the
    handshake is accepted. After that, every JSON frame received is dispatched
    through ``manager.broadcast`` until the socket closes or an error escapes.
    """
    await AuthService.check_verificated_user_with_exc(uow=uow, user_id=user.id)
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    await manager.connect(chat_id, websocket, subprotocol)
    try:
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        # Normal client hang-up; cleanup is handled in ``finally``.
        pass
    finally:
        # Unregister on every exit path, not only on WebSocketDisconnect.
        # Any other exception (bad JSON, failed validation, permission error,
        # unknown flag) would otherwise leave this closed socket registered,
        # and later broadcasts to the chat would likely fail sending to it.
        manager.disconnect(chat_id, websocket)
@router.post(
    "/ws/{chat_id}",
)
async def chat_ws(
    chat_id: int,
    token: str = Depends(get_token),
):
    """Reject plain HTTP requests to the chat websocket path.

    This route always raises ``UseWSException``; clients must open the
    websocket route at the same path instead. The ``get_token`` dependency is
    still resolved by FastAPI before this body runs.
    """
    # The previous body continued after this raise with a hard-coded
    # localhost websocket client and a print(); it was unreachable and has
    # been removed.
    raise UseWSException