chat_back/app/chat/websocket.py
2024-08-20 14:54:31 +04:00

294 lines
9.6 KiB
Python

import asyncio
import logging
from collections import defaultdict
import websockets
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException, status
from app.chat.exceptions import UseWSException, MessageNotFoundException, MessageAlreadyPinnedException
from app.exceptions import IncorrectDataException
from app.chat.exceptions import UserDontHavePermissionException
from app.services.message_service import MessageService
from app.utils.unit_of_work import UnitOfWork
from app.services.auth_service import AuthService
from app.chat.router import router
from app.chat.shemas import SSendMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage, Responses
from app.dependencies import get_current_user_ws, get_token, get_subprotocol_ws, get_verificated_user
from app.users.schemas import SUser
class ConnectionManager:
    """Registry of chat WebSocket connections and dispatcher for chat events.

    Incoming event dicts are routed on their ``"flag"`` key to one of the
    ``_send``/``_delete``/``_edit``/``_pin``/``_unpin`` handlers. The outgoing
    event a handler returns is pushed to every WebSocket registered for the chat,
    and then to ``polling_manager`` for HTTP long-polling clients.
    """

    def __init__(self):
        # chat_id -> WebSockets currently registered for that chat (in-memory,
        # so only connections held by this process are visible).
        self.active_connections: dict[int, list[WebSocket]] = defaultdict(list)
        # Incoming "flag" value -> coroutine that applies the event and returns
        # the dict to publish.
        self.message_methods = {
            "send": self._send,
            "delete": self._delete,
            "edit": self._edit,
            "pin": self._pin,
            "unpin": self._unpin,
        }

    async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None) -> None:
        """Accept the handshake (with *subprotocol*) and register *websocket* for *chat_id*."""
        await websocket.accept(subprotocol=subprotocol)
        self.active_connections[chat_id].append(websocket)

    def _forget(self, chat_id: int, websocket: WebSocket) -> None:
        """Unregister *websocket* if present; drop the chat entry once it is empty."""
        connections = self.active_connections.get(chat_id)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            del self.active_connections[chat_id]

    async def disconnect(
        self, chat_id: int, websocket: WebSocket, code_and_reason: tuple[int, str] | None = None
    ) -> None:
        """Unregister *websocket* from *chat_id* and optionally close it.

        Safe to call for a socket that was never registered (e.g. when access
        checks fail before ``connect``) or that has already been removed.

        :param code_and_reason: ``(close_code, reason)``. When given, the socket
            is closed with these values after being unregistered.
        """
        self._forget(chat_id, websocket)
        if code_and_reason:
            code, reason = code_and_reason
            await websocket.close(code=code, reason=reason)

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> None:
        """Apply the event in *message* on behalf of *user_id* and publish the result.

        The handler is chosen by ``message["flag"]``. Its result is sent to every
        WebSocket of *chat_id*, then passed to ``polling_manager.send``.
        ``MessageNotFoundException`` and ``MessageAlreadyPinnedException`` are
        ignored, and nothing is published in that case.

        :raises IncorrectDataException: if *message* has no usable ``"flag"``
            (missing, unknown, or the payload is not a mapping).
        """
        # Only the dispatch lookup is treated as a client-data error; a KeyError
        # raised inside a handler is a server bug and must not be disguised.
        try:
            handler = self.message_methods[message["flag"]]
        except (KeyError, TypeError):
            raise IncorrectDataException from None

        try:
            new_message = await handler(uow, user_id, chat_id, message)
        except (MessageNotFoundException, MessageAlreadyPinnedException):
            # Deliberate best-effort: stale ids and double pins are dropped silently.
            return

        # Iterate over a snapshot: each send yields to the event loop, during which
        # other sessions may connect/disconnect and mutate the live list.
        for websocket in list(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(new_message)
            except Exception:
                # One unreachable peer must not abort delivery to the rest (nor
                # propagate to, and close, the sender's own session); drop it.
                logging.warning("Dropping unreachable websocket in chat %s", chat_id, exc_info=True)
                self._forget(chat_id, websocket)
        await polling_manager.send(chat_id, new_message)

    @staticmethod
    async def _send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
        """Create a message via ``MessageService.send_message``; return it JSON-ready, flagged ``"send"``."""
        message = SSendMessage.model_validate(message)
        new_message = await MessageService.send_message(
            uow=uow, user_id=user_id, chat_id=chat_id, message=message, image_url=message.image_url
        )
        new_message = new_message.model_dump()
        # datetime is not JSON-serializable by send_json; ship it as ISO-8601 text.
        new_message["created_at"] = new_message["created_at"].isoformat()
        new_message["flag"] = "send"
        return new_message

    @staticmethod
    async def _delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Delete a message and return a ``{"id", "flag": "delete"}`` event.

        :raises UserDontHavePermissionException: if the payload's ``user_id``
            differs from the acting user.
        """
        message = SDeleteMessage.model_validate(message)
        # NOTE(review): ownership is compared against the user_id from the client
        # payload, not the stored message's author, so it appears a client can
        # delete others' messages by sending its own id. Verify against MessageService.
        if message.user_id != user_id:
            raise UserDontHavePermissionException
        await MessageService.delete_message(uow=uow, message_id=message.id)
        new_message = {"id": message.id, "flag": "delete"}
        return new_message

    @staticmethod
    async def _edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Edit a message's text/image and return an ``"edit"`` event carrying the new values.

        :raises UserDontHavePermissionException: if the payload's ``user_id``
            differs from the acting user.
        """
        message = SEditMessage.model_validate(message)
        # NOTE(review): same client-supplied user_id ownership check as _delete.
        if message.user_id != user_id:
            raise UserDontHavePermissionException
        await MessageService.edit_message(
            uow=uow, message_id=message.id, new_message=message.new_message, new_image_url=message.new_image_url
        )
        new_message = {
            "flag": "edit",
            "id": message.id,
            "new_message": message.new_message,
            "new_image_url": message.new_image_url,
        }
        return new_message

    @staticmethod
    async def _pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Pin a message in *chat_id*; return the pinned message JSON-ready, flagged ``"pin"``."""
        message = SPinMessage.model_validate(message)
        # NOTE(review): passes the payload's user_id rather than the authenticated
        # user (ignored as `_`). Confirm which user pin_message expects.
        pinned_message = await MessageService.pin_message(
            uow=uow, chat_id=chat_id, user_id=message.user_id, message_id=message.id
        )
        new_message = pinned_message.model_dump()
        new_message["created_at"] = new_message["created_at"].isoformat()
        new_message["flag"] = "pin"
        return new_message

    @staticmethod
    async def _unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Unpin a message in *chat_id* and return a ``{"flag": "unpin", "id"}`` event."""
        message = SUnpinMessage.model_validate(message)
        await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id)
        new_message = {"flag": "unpin", "id": message.id}
        return new_message


# Module-level registry used by the WebSocket endpoint and by PollingManager.prepare.
manager = ConnectionManager()
async def _close_websocket(
    chat_id: int, websocket: WebSocket, registered: bool, code_and_reason: tuple[int, str]
) -> None:
    """Close *websocket* with *code_and_reason*, unregistering it from ``manager`` only if it was registered."""
    if registered:
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=code_and_reason)
    else:
        # Never added by manager.connect(); disconnecting would try to remove it.
        code, reason = code_and_reason
        await websocket.close(code=code, reason=reason)


@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
    chat_id: int,
    websocket: WebSocket,
    user: SUser = Depends(get_current_user_ws),
    subprotocol: str = Depends(get_subprotocol_ws),
    uow=Depends(UnitOfWork),
):
    """Serve one client's WebSocket session for chat *chat_id*.

    Checks the user (``check_verificated_user``) and their access to the chat,
    accepts and registers the socket, then feeds every received JSON event to
    ``manager.broadcast`` until the client disconnects.

    An ``HTTPException`` closes the socket with 1011 for 5xx statuses and 1003
    otherwise, using its detail as the reason. Any other error is logged with its
    traceback and closes the socket with 1011.
    """
    registered = False  # becomes True once manager.connect() has added the socket
    try:
        await AuthService.check_verificated_user(uow=uow, user_id=user.id)
        await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
        await manager.connect(chat_id, websocket, subprotocol)
        registered = True
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        if registered:
            await manager.disconnect(chat_id, websocket)
    except HTTPException as e:
        code = status.WS_1011_INTERNAL_ERROR if e.status_code >= 500 else status.WS_1003_UNSUPPORTED_DATA
        # close() needs a str reason; HTTPException.detail may be any JSON-able value.
        await _close_websocket(chat_id, websocket, registered, (code, str(e.detail)))
    except Exception:
        logging.exception("Unhandled error in websocket session for chat %s", chat_id)
        await _close_websocket(
            chat_id, websocket, registered, (status.WS_1011_INTERNAL_ERROR, "Internal Server Error")
        )
@router.post(
    "/ws/{chat_id}",
    responses={
        status.HTTP_401_UNAUTHORIZED: {
            "model": Responses.STokenMissingException,
            "description": "Токен отсутствует"
        },
        status.HTTP_403_FORBIDDEN: {
            "model": Responses.SNotAuthenticated,
            "description": "Not authenticated"
        },
        status.HTTP_404_NOT_FOUND: {
            "model": Responses.SUserNotFoundException,
            "description": "Юзер не найден"
        },
        status.HTTP_409_CONFLICT: {
            "model": Responses.SUserMustConfirmEmailException,
            "description": "Сначала подтвердите почту"
        },
        status.HTTP_500_INTERNAL_SERVER_ERROR: {
            "model": Responses.SBlackPhoenixException,
            "description": "Внутренняя ошибка сервера"
        },
    },
)
async def chat_ws(
    chat_id: int,
    token: str = Depends(get_token),
):
    # HTTP placeholder for the WebSocket route at the same path. It always raises
    # UseWSException, presumably so API-docs users are pointed at the WebSocket
    # endpoint. `token` is still resolved by get_token before the body runs.
    #
    # Example client using the `websockets` package (newer releases name the
    # keyword `additional_headers`):
    #     url = f"ws://localhost:8000/api/chat/ws/{chat_id}"
    #     async with websockets.connect(url, extra_headers={"Authorization": f"Bearer {token}"}) as ws:
    #         print(await ws.recv())
    raise UseWSException
class PollingManager:
    """Deliver chat events to HTTP long-polling clients.

    Every pending :meth:`poll` parks on a future registered under its chat id.
    :meth:`send` resolves all parked futures of that chat with the event itself,
    so each waiting client receives exactly the event that woke it.
    """

    def __init__(self):
        # chat_id -> futures of poll() calls waiting for the next event.
        self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list)
        # Retained so existing references keep working. Events are now handed to
        # waiters through their futures, so this buffer stays empty.
        self.messages: dict[int, list[dict]] = defaultdict(list)

    async def poll(self, chat_id: int) -> dict:
        """Wait for the next event published to *chat_id* and return it.

        :raises HTTPException: 408 ("Client disconnected") if the waiting task is
            cancelled before an event is received.
        """
        future = asyncio.get_running_loop().create_future()
        self.waiters[chat_id].append(future)
        try:
            return await future
        except asyncio.CancelledError:
            # send() may already have popped this future (cancellation racing a
            # delivery), so discard it tolerantly instead of list.remove().
            self._discard_waiter(chat_id, future)
            raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Client disconnected") from None

    def _discard_waiter(self, chat_id: int, future: asyncio.Future) -> None:
        """Forget *future* if it is still registered; drop the chat entry once empty."""
        waiters = self.waiters.get(chat_id)
        if waiters is None:
            return
        if future in waiters:
            waiters.remove(future)
        if not waiters:
            del self.waiters[chat_id]

    async def prepare(
        self,
        uow: UnitOfWork,
        user_id: int,
        chat_id: int,
        message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    ) -> None:
        """Apply an event submitted over HTTP and publish it.

        ``manager.broadcast`` already pushes the processed event to the chat's
        WebSockets and to this manager's long-pollers. Sending the raw input a
        second time would hand pollers an unprocessed (or failed) event.
        """
        await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=message.model_dump())

    async def send(
        self,
        chat_id: int,
        message: dict,
    ) -> None:
        """Wake every client currently long-polling *chat_id*, handing it *message*."""
        # Take the whole list at once: later poll() calls register into a fresh
        # list and wait for the next event rather than receiving this one.
        for waiter in self.waiters.pop(chat_id, []):
            if not waiter.done():
                waiter.set_result(message)


# Module-level long-polling hub; ConnectionManager.broadcast publishes into it.
polling_manager = PollingManager()
@router.get(
    "/poll/{chat_id}",
    status_code=status.HTTP_200_OK,
    response_model=None,
    responses={
        status.HTTP_401_UNAUTHORIZED: dict(
            model=Responses.STokenMissingException,
            description="Токен отсутствует",  # "Token is missing"
        ),
        status.HTTP_403_FORBIDDEN: dict(
            model=Responses.SNotAuthenticated,
            description="Not authenticated",
        ),
        status.HTTP_404_NOT_FOUND: dict(
            model=Responses.SUserNotFoundException,
            description="Юзер не найден",  # "User not found"
        ),
        status.HTTP_409_CONFLICT: dict(
            model=Responses.SUserMustConfirmEmailException,
            description="Сначала подтвердите почту",  # "Confirm your email first"
        ),
        status.HTTP_500_INTERNAL_SERVER_ERROR: dict(
            model=Responses.SBlackPhoenixException,
            description="Внутренняя ошибка сервера",  # "Internal server error"
        ),
    },
)
async def poll(chat_id: int, user: SUser = Depends(get_verificated_user), uow=Depends(UnitOfWork)):
    # Long-poll endpoint: once the user from get_verificated_user is confirmed to
    # have access to the chat, hold the request open until polling_manager
    # publishes the next event for it, then return that event.
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    next_event = await polling_manager.poll(chat_id)
    return next_event
@router.post(
    "/send/{chat_id}",
    status_code=status.HTTP_201_CREATED,
    response_model=None,
    responses={
        status.HTTP_401_UNAUTHORIZED: dict(
            model=Responses.STokenMissingException,
            description="Токен отсутствует",  # "Token is missing"
        ),
        status.HTTP_403_FORBIDDEN: dict(
            model=Responses.SNotAuthenticated,
            description="Not authenticated",
        ),
        status.HTTP_404_NOT_FOUND: dict(
            model=Responses.SUserNotFoundException,
            description="Юзер не найден",  # "User not found"
        ),
        status.HTTP_409_CONFLICT: dict(
            model=Responses.SUserMustConfirmEmailException,
            description="Сначала подтвердите почту",  # "Confirm your email first"
        ),
        status.HTTP_500_INTERNAL_SERVER_ERROR: dict(
            model=Responses.SBlackPhoenixException,
            description="Внутренняя ошибка сервера",  # "Internal server error"
        ),
    },
)
async def send(
    chat_id: int,
    message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    user: SUser = Depends(get_verificated_user),
    uow=Depends(UnitOfWork),
):
    # HTTP counterpart of the WebSocket endpoint: after checking chat access,
    # hand the event to polling_manager.prepare, which applies and publishes it.
    call_kwargs = dict(uow=uow, user_id=user.id, chat_id=chat_id)
    await AuthService.validate_user_access_to_chat(**call_kwargs)
    await polling_manager.prepare(**call_kwargs, message=message)