158 lines
5.6 KiB
Python
158 lines
5.6 KiB
Python
import websockets
|
|
from fastapi import WebSocket, WebSocketDisconnect, Depends
|
|
|
|
from app.chat.exceptions import UseWSException
|
|
from app.exceptions import IncorrectDataException
|
|
from app.chat.exceptions import UserDontHavePermissionException
|
|
from app.services.message_service import MessageService
|
|
from app.unit_of_work import UnitOfWork
|
|
from app.utils.auth import AuthService
|
|
from app.chat.router import router
|
|
from app.chat.shemas import SSendMessage, SMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage
|
|
from app.users.dependencies import get_current_user_ws, get_token, get_subprotocol_ws
|
|
from app.users.schemas import SUser
|
|
|
|
|
|
class ConnectionManager:
    """Track chat WebSocket connections and dispatch incoming chat events.

    Connections are grouped per ``chat_id`` in an in-memory dict. Each incoming
    event carries a ``"flag"`` key selecting one of the handlers ``send``,
    ``delete``, ``edit``, ``pin`` or ``unpin``. The handler persists the change
    through ``MessageService`` and returns a dict, which is then pushed with
    ``send_json`` to every connection registered for the same chat.
    """

    def __init__(self):
        # chat_id -> sockets currently connected to that chat
        self.active_connections: dict[int, list[WebSocket]] = {}
        # value of an event's "flag" key -> coroutine handling that event
        self.message_methods = {
            "send": self.send,
            "delete": self.delete,
            "edit": self.edit,
            "pin": self.pin,
            "unpin": self.unpin,
        }

    async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None):
        """Accept ``websocket`` (echoing ``subprotocol``) and register it under ``chat_id``."""
        await websocket.accept(subprotocol=subprotocol)
        self.active_connections.setdefault(chat_id, []).append(websocket)

    def disconnect(self, chat_id: int, websocket: WebSocket):
        """Unregister ``websocket`` from ``chat_id``.

        Safe to call for a socket that is no longer registered. The chat's entry
        is removed once its last connection leaves, so the mapping does not keep
        an empty list for every chat that was ever opened.
        """
        connections = self.active_connections.get(chat_id)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            del self.active_connections[chat_id]

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict):
        """Apply one client event and deliver the resulting payload to the whole chat.

        :param uow: unit of work passed through to the selected handler.
        :param user_id: id of the authenticated sender.
        :param chat_id: chat the sender is connected to.
        :param message: decoded JSON event; its ``"flag"`` selects the handler.
        :raises IncorrectDataException: if ``message`` is not a JSON object or its
            ``"flag"`` is missing or unknown.
        """
        try:
            handler = self.message_methods[message["flag"]]
        except (KeyError, TypeError) as err:
            # KeyError: no/unknown flag. TypeError: payload is not an object
            # (list, str, null, ...) or the flag is unhashable.
            # Only this lookup is guarded, so genuine server-side KeyErrors raised
            # inside a handler are no longer disguised as bad client data.
            raise IncorrectDataException from err

        new_message = await handler(uow, user_id, chat_id, message)

        # Iterate over a snapshot: the list can change while we await each send.
        for websocket in list(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(new_message)
            except (WebSocketDisconnect, RuntimeError):
                # A peer that has gone away must not abort delivery to the rest of
                # the chat, nor propagate into the sender's receive loop (where it
                # would be mistaken for the sender disconnecting). Forget it.
                self.disconnect(chat_id, websocket)

    async def send(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
        """Store a new message from ``user_id`` in ``chat_id`` and return its broadcast payload."""
        payload = SSendMessage.model_validate(message)
        stored = await self.add_message_to_database(uow=uow, user_id=user_id, chat_id=chat_id, message=payload)
        new_message = stored.model_dump()
        # send_json cannot encode datetime objects; ship the timestamp as ISO-8601 text.
        new_message["created_at"] = new_message["created_at"].isoformat()
        new_message["flag"] = "send"
        return new_message

    async def delete(self, uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Delete a message and return ``{"id": <id>, "flag": "delete"}``.

        :raises UserDontHavePermissionException: if the event's ``user_id`` is not
            the sender's id.
        """
        payload = SDeleteMessage.model_validate(message)
        # NOTE(review): payload.user_id is taken from the client's own event, so this
        # only checks what the client *claims*; the chat id (``_``) is not checked
        # either. The real author/chat should be loaded from storage before deleting.
        if payload.user_id != user_id:
            raise UserDontHavePermissionException
        await self.delete_message(uow=uow, message_id=payload.id)
        return {"id": payload.id, "flag": "delete"}

    async def edit(self, uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Replace a message's text/image and return the ``"edit"`` broadcast payload.

        :raises UserDontHavePermissionException: if the event's ``user_id`` is not
            the sender's id.
        """
        payload = SEditMessage.model_validate(message)
        # NOTE(review): same client-trusted ownership check as in ``delete``.
        if payload.user_id != user_id:
            raise UserDontHavePermissionException
        await self.edit_message(
            uow=uow, message_id=payload.id, new_message=payload.new_message, image_url=payload.new_image_url
        )
        return {
            "flag": "edit",
            "id": payload.id,
            "new_message": payload.new_message,
            "new_image_url": payload.new_image_url,
        }

    async def pin(self, uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Pin a message in ``chat_id`` and return the pinned message as the ``"pin"`` payload."""
        payload = SPinMessage.model_validate(message)
        # NOTE(review): user_id here is the client-supplied one, not the authenticated
        # sender (``_``); confirm which one MessageService.pin_message expects.
        pinned_message = await self.pin_message(
            uow=uow, chat_id=chat_id, user_id=payload.user_id, message_id=payload.id
        )
        new_message = pinned_message.model_dump()
        # send_json cannot encode datetime objects; ship the timestamp as ISO-8601 text.
        new_message["created_at"] = new_message["created_at"].isoformat()
        new_message["flag"] = "pin"
        return new_message

    async def unpin(self, uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Unpin a message in ``chat_id`` and return ``{"flag": "unpin", "id": <id>}``."""
        payload = SUnpinMessage.model_validate(message)
        await self.unpin_message(uow=uow, chat_id=chat_id, message_id=payload.id)
        return {"flag": "unpin", "id": payload.id}

    @staticmethod
    async def add_message_to_database(uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage) -> SMessage:
        """Persist ``message`` via ``MessageService.send_message`` and return the stored message."""
        new_message = await MessageService.send_message(
            uow=uow, user_id=user_id, chat_id=chat_id, message=message, image_url=message.image_url
        )
        return new_message

    @staticmethod
    async def delete_message(uow: UnitOfWork, message_id: int) -> None:
        """Delete message ``message_id`` via ``MessageService``."""
        await MessageService.delete_message(uow=uow, message_id=message_id)

    @staticmethod
    async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, image_url: str) -> None:
        """Update the text and image URL of message ``message_id`` via ``MessageService``."""
        await MessageService.edit_message(
            uow=uow, message_id=message_id, new_message=new_message, new_image_url=image_url
        )

    @staticmethod
    async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage:
        """Pin message ``message_id`` in ``chat_id`` via ``MessageService`` and return it."""
        pinned_message = await MessageService.pin_message(uow=uow, chat_id=chat_id, user_id=user_id, message_id=message_id)
        return pinned_message

    @staticmethod
    async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> None:
        """Unpin message ``message_id`` in ``chat_id`` via ``MessageService``."""
        await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message_id)
|
|
|
|
|
|
# Module-level, in-memory connection registry used by the chat WebSocket route.
# Only sockets accepted by this process are visible to it.
manager: ConnectionManager = ConnectionManager()
|
|
|
|
|
|
@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
    chat_id: int,
    websocket: WebSocket,
    user: SUser = Depends(get_current_user_ws),
    subprotocol: str = Depends(get_subprotocol_ws),
    uow: UnitOfWork = Depends(UnitOfWork),
):
    """Chat WebSocket: authorise the user, join the chat, then relay events until it ends.

    The user must pass ``check_verificated_user_with_exc`` and
    ``validate_user_access_to_chat`` before the socket is accepted. Every JSON
    frame received afterwards is handed to ``manager.broadcast``.
    """
    await AuthService.check_verificated_user_with_exc(uow=uow, user_id=user.id)
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    await manager.connect(chat_id, websocket, subprotocol)
    try:
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        # Regular client close: nothing to do beyond the cleanup below.
        pass
    finally:
        # Unregister on *every* exit path, including errors from receive_json or
        # broadcast (bad payload, permission denied, ...). Otherwise the dead socket
        # would stay registered and later broadcasts to this chat would write to it.
        manager.disconnect(chat_id, websocket)
|
|
|
|
|
|
@router.post(
    "/ws/{chat_id}",
)
async def chat_ws(
    chat_id: int,
    token: str = Depends(get_token),
):
    """HTTP stand-in for the chat WebSocket route; always raises ``UseWSException``.

    Presumably kept so the endpoint appears in the HTTP API docs (TODO confirm).
    Clients must open a WebSocket to the same path instead, for example with the
    ``websockets`` package::

        async with websockets.connect(
            f"ws://<host>/api/chat/ws/{chat_id}",
            extra_headers={"Authorization": f"Bearer {token}"},
        ) as websocket:
            print(await websocket.recv())
    """
    raise UseWSException
|