chat_back/app/chat/websocket.py
2024-07-13 14:41:33 +04:00

227 lines
7.4 KiB
Python

import asyncio
import logging
from collections import defaultdict

import websockets
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException, status

from app.chat.exceptions import UseWSException
from app.chat.exceptions import UserDontHavePermissionException
from app.chat.router import router
from app.chat.shemas import SSendMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage
from app.exceptions import IncorrectDataException
from app.services.message_service import MessageService
from app.unit_of_work import UnitOfWork
from app.users.dependencies import get_current_user_ws, get_token, get_subprotocol_ws, get_verificated_user
from app.users.schemas import SUser
from app.utils.auth import AuthService

class ConnectionManager:
    """Track open chat WebSockets and fan processed chat events out to them.

    An incoming event is a dict whose ``"flag"`` key selects a handler
    (``send`` / ``delete`` / ``edit`` / ``pin`` / ``unpin``). The handler validates
    the payload, applies the change through ``MessageService`` and returns the
    dict that is pushed to every WebSocket of the chat and to long-poll clients.
    """

    def __init__(self):
        # chat_id -> WebSockets currently registered for that chat
        self.active_connections: dict[int, list[WebSocket]] = {}
        # event flag -> coroutine handler producing the outgoing event dict
        self.message_methods = {
            "send": self.send,
            "delete": self.delete,
            "edit": self.edit,
            "pin": self.pin,
            "unpin": self.unpin,
        }

    async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None):
        """Accept ``websocket`` (echoing ``subprotocol``) and register it under ``chat_id``."""
        await websocket.accept(subprotocol=subprotocol)
        self.active_connections.setdefault(chat_id, []).append(websocket)

    async def disconnect(self, chat_id: int, websocket: WebSocket, code_and_reason: tuple[int, str] | None = None):
        """Unregister ``websocket``; if ``code_and_reason`` is given, also close it with that code/reason.

        Safe to call for a socket that is not (or no longer) registered, e.g. one
        already pruned by ``broadcast`` after a failed delivery.
        """
        self._forget(chat_id, websocket)
        if code_and_reason:
            code, reason = code_and_reason
            await websocket.close(code=code, reason=reason)

    def _forget(self, chat_id: int, websocket: WebSocket) -> None:
        """Remove ``websocket`` from ``chat_id``'s list and drop the list once it is empty."""
        connections = self.active_connections.get(chat_id)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            del self.active_connections[chat_id]

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict):
        """Apply ``message`` on behalf of ``user_id`` and deliver the result to everyone in ``chat_id``.

        The handler's result is sent to each registered WebSocket of the chat
        (best-effort) and then to the long-polling clients via ``polling_manager``.

        Raises:
            IncorrectDataException: ``message`` is not a mapping, or its ``"flag"``
                is missing or not a known event type.
            Any exception raised by the selected handler (validation, permission,
            service errors) propagates unchanged.
        """
        try:
            handler = self.message_methods[message["flag"]]
        except (KeyError, TypeError):
            # Only a missing/unknown flag (or a non-dict payload) is a client data
            # error; KeyErrors raised later must not be disguised as one.
            raise IncorrectDataException from None
        new_message = await handler(uow, user_id, chat_id, message)
        # Iterate over a snapshot: ``disconnect`` may run while we await a send.
        # ``.get`` because HTTP-submitted events may target a chat with no sockets.
        for websocket in list(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(new_message)
            except Exception:
                # Best-effort fan-out: a dead recipient must not abort delivery to
                # the others or tear down the sender's own connection.
                self._forget(chat_id, websocket)
        await polling_manager.send(chat_id, new_message)

    @staticmethod
    async def send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
        """Store a new message from ``user_id`` in ``chat_id`` and return it as a ``"send"`` event."""
        payload = SSendMessage.model_validate(message)
        stored = await MessageService.send_message(
            uow=uow, user_id=user_id, chat_id=chat_id, message=payload, image_url=payload.image_url
        )
        event = stored.model_dump()
        # Convert to ISO-8601 text so the event can be JSON-encoded by send_json.
        event["created_at"] = event["created_at"].isoformat()
        event["flag"] = "send"
        return event

    @staticmethod
    async def delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Delete message ``message["id"]`` and return a ``"delete"`` event.

        Raises:
            UserDontHavePermissionException: the payload's ``user_id`` differs from ``user_id``.
        """
        payload = SDeleteMessage.model_validate(message)
        # NOTE(review): payload.user_id is read from the client-supplied payload, so
        # this check only proves the client *claims* authorship. Presumably it should
        # compare against the stored message's author — confirm with MessageService.
        if payload.user_id != user_id:
            raise UserDontHavePermissionException
        await MessageService.delete_message(uow=uow, message_id=payload.id)
        return {"id": payload.id, "flag": "delete"}

    @staticmethod
    async def edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Replace the text/image of message ``message["id"]`` and return an ``"edit"`` event.

        Raises:
            UserDontHavePermissionException: the payload's ``user_id`` differs from ``user_id``.
        """
        payload = SEditMessage.model_validate(message)
        # NOTE(review): same client-supplied ownership check as in ``delete`` — confirm.
        if payload.user_id != user_id:
            raise UserDontHavePermissionException
        await MessageService.edit_message(
            uow=uow, message_id=payload.id, new_message=payload.new_message, new_image_url=payload.new_image_url
        )
        return {
            "flag": "edit",
            "id": payload.id,
            "new_message": payload.new_message,
            "new_image_url": payload.new_image_url,
        }

    @staticmethod
    async def pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Pin message ``message["id"]`` in ``chat_id`` and return the pinned message as a ``"pin"`` event."""
        payload = SPinMessage.model_validate(message)
        # NOTE(review): user_id is taken from the payload, not from the authenticated
        # user (the second argument is ignored). Confirm this is intended.
        pinned = await MessageService.pin_message(
            uow=uow, chat_id=chat_id, user_id=payload.user_id, message_id=payload.id
        )
        event = pinned.model_dump()
        event["created_at"] = event["created_at"].isoformat()
        event["flag"] = "pin"
        return event

    @staticmethod
    async def unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Unpin message ``message["id"]`` in ``chat_id`` and return an ``"unpin"`` event."""
        payload = SUnpinMessage.model_validate(message)
        await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=payload.id)
        return {"flag": "unpin", "id": payload.id}


# Process-wide instance shared by the WebSocket endpoint and PollingManager.prepare.
manager = ConnectionManager()
def _ws_close_reason(detail) -> str:
    """Render an ``HTTPException.detail`` as a WebSocket close reason.

    A close frame carries a 2-byte code plus a UTF-8 reason inside a 125-byte
    control-frame payload (RFC 6455), so the reason is cut to 123 bytes at a
    character boundary; non-string details are stringified.
    """
    encoded = str(detail).encode("utf-8")[:123]
    return encoded.decode("utf-8", errors="ignore")


@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
    chat_id: int,
    websocket: WebSocket,
    user: SUser = Depends(get_current_user_ws),
    subprotocol: str = Depends(get_subprotocol_ws),
    uow=Depends(UnitOfWork),
):
    """Chat WebSocket: every JSON frame received is applied and broadcast to the chat.

    The user must pass the verification and chat-access checks before the socket is
    accepted. The loop runs until the client disconnects or an error occurs:
    an ``HTTPException`` closes the socket with 1011 when its status is 500 and with
    1003 otherwise; any other exception is logged and closes it with 1011.
    """
    await AuthService.check_verificated_user(uow=uow, user_id=user.id)
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    await manager.connect(chat_id, websocket, subprotocol)
    try:
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        # Client already went away: just unregister, nothing to close.
        await manager.disconnect(chat_id, websocket)
    except HTTPException as e:
        code = status.WS_1011_INTERNAL_ERROR if e.status_code == 500 else status.WS_1003_UNSUPPORTED_DATA
        code_and_reason = (code, _ws_close_reason(e.detail))
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=code_and_reason)
    except Exception:
        # Keep the generic close for the client, but record the cause server-side.
        logging.getLogger(__name__).exception("Unhandled error in websocket for chat %s", chat_id)
        code_and_reason = (status.WS_1011_INTERNAL_ERROR, "Internal Server Error")
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=code_and_reason)
@router.post(
    "/ws/{chat_id}",
)
async def chat_ws(
    chat_id: int,
    token: str = Depends(get_token),
):
    """HTTP placeholder for the chat WebSocket route; always raises ``UseWSException``.

    The route exists so the WebSocket path is visible to HTTP clients. The token is
    still resolved via ``get_token``. Clients must open a WebSocket instead, e.g.::

        url = f"ws://localhost:8000/api/chat/ws/{chat_id}"
        async with websockets.connect(url, extra_headers={"Authorization": f"Bearer {token}"}) as ws:
            print(await ws.recv())
    """
    raise UseWSException
class PollingManager:
    """Long-polling fan-out: park HTTP pollers per chat and wake them with the next event.

    Each ``poll`` call registers a future for its chat. ``send`` resolves every
    future currently registered for that chat, using the event payload as the result.
    """

    def __init__(self):
        # chat_id -> futures of pollers waiting for that chat's next event
        self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list)

    async def poll(self, chat_id: int):
        """Wait for the next event sent to ``chat_id`` and return its payload.

        Raises:
            HTTPException: 408 when the waiting request is cancelled
                (e.g. the client disconnected).
        """
        future = asyncio.get_running_loop().create_future()
        self.waiters[chat_id].append(future)
        try:
            # The payload travels inside the future, so concurrent sends cannot
            # hand this poller someone else's event.
            return await future
        except asyncio.CancelledError:
            # ``send`` may already have popped this future before the cancel landed.
            waiters = self.waiters.get(chat_id)
            if waiters and future in waiters:
                waiters.remove(future)
                if not waiters:
                    del self.waiters[chat_id]
            raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Client disconnected")

    async def prepare(
        self,
        uow: UnitOfWork,
        user_id: int,
        chat_id: int,
        message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    ):
        """Apply and broadcast an event that was submitted over HTTP.

        ``manager.broadcast`` already forwards the processed event to pollers through
        ``polling_manager.send``. It is therefore not sent a second time here, which
        would leak the raw, unprocessed payload to pollers.
        """
        await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=message.model_dump())

    async def send(
        self,
        chat_id: int,
        message: dict,
    ):
        """Wake every poller currently waiting on ``chat_id`` with ``message``."""
        # Pop the whole list so no empty per-chat lists accumulate.
        for waiter in self.waiters.pop(chat_id, []):
            if not waiter.done():
                waiter.set_result(message)


# Process-wide instance used by ConnectionManager.broadcast and the /poll and /send routes.
polling_manager = PollingManager()
@router.get(
    "/poll/{chat_id}",
    status_code=status.HTTP_200_OK,
    response_model=SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
)
async def poll(
    chat_id: int,
    user: SUser = Depends(get_verificated_user),
    uow=Depends(UnitOfWork),
):
    """Long-poll route: after the chat-access check, block until the chat's next event and return it."""
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    next_event = await polling_manager.poll(chat_id)
    return next_event
@router.post(
    "/send/{chat_id}",
    status_code=status.HTTP_201_CREATED,
    response_model=None,
)
async def send(
    chat_id: int,
    message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    user: SUser = Depends(get_verificated_user),
    uow=Depends(UnitOfWork),
):
    """HTTP alternative to the WebSocket for submitting a chat event.

    After the chat-access check, the event is handed to ``polling_manager.prepare``.
    """
    sender_id = user.id
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=sender_id, chat_id=chat_id)
    await polling_manager.prepare(uow=uow, user_id=sender_id, chat_id=chat_id, message=message)