chat_back/app/chat/websocket.py
2025-03-16 12:04:09 +03:00

325 lines
9.9 KiB
Python

import asyncio
import logging
from collections import defaultdict
import websockets
from fastapi import Depends, HTTPException, WebSocket, WebSocketDisconnect, status
from app.chat.exceptions import (
MessageAlreadyPinnedException,
MessageNotFoundException,
UserDontHavePermissionException,
UseWSException,
)
from app.chat.router import router
from app.chat.shemas import (
Responses,
SDeleteMessage,
SEditMessage,
SPinMessage,
SSendMessage,
SUnpinMessage,
)
from app.dependencies import (
get_current_user_ws,
get_subprotocol_ws,
get_token,
get_verificated_user,
)
from app.exceptions import IncorrectDataException
from app.services import auth_service, message_service
from app.users.schemas import SUser
from app.utils.unit_of_work import UnitOfWork
class ConnectionManager:
    """Keep the WebSocket connections of every chat and fan chat events out to them.

    An incoming event is a dict whose ``"flag"`` value selects a handler from
    ``message_methods``. The handler validates the event, calls ``message_service``
    and returns a JSON-friendly payload. ``broadcast`` pushes that payload to every
    WebSocket of the chat and to the long-polling waiters (``polling_manager``).
    """

    def __init__(self):
        # chat_id -> WebSockets that were accepted and registered via ``connect``.
        self.active_connections: dict[int, list[WebSocket]] = defaultdict(list)
        # Value of the event's "flag" key -> handler coroutine.
        self.message_methods = {
            "send": self._send,
            "delete": self._delete,
            "edit": self._edit,
            "pin": self._pin,
            "unpin": self._unpin,
        }

    async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None) -> None:
        """Accept the handshake (with *subprotocol*) and register *websocket* under *chat_id*."""
        await websocket.accept(subprotocol=subprotocol)
        self.active_connections[chat_id].append(websocket)

    async def disconnect(
        self,
        chat_id: int,
        websocket: WebSocket,
        code_and_reason: tuple[int, str] | None = None,
    ) -> None:
        """Unregister *websocket* from *chat_id* and optionally close it.

        Tolerates sockets that were never registered (e.g. the handshake was
        rejected before ``connect``) or were already dropped after a failed send.

        Args:
            chat_id: Chat the socket belongs to.
            websocket: The connection to forget.
            code_and_reason: ``(close_code, reason)``; when given, the socket is closed
                with them, otherwise it is only unregistered.
        """
        self._forget(chat_id, websocket)
        if code_and_reason:
            code, reason = code_and_reason
            await websocket.close(code=code, reason=reason)

    def _forget(self, chat_id: int, websocket: WebSocket) -> None:
        """Remove *websocket* from *chat_id* if present and drop the chat entry once empty."""
        sockets = self.active_connections.get(chat_id)
        if sockets is None:
            return
        if websocket in sockets:
            sockets.remove(websocket)
        if not sockets:
            # Avoid accumulating empty lists for chats nobody is connected to.
            del self.active_connections[chat_id]

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> None:
        """Apply *message* on behalf of *user_id* and deliver the result to the chat.

        The result goes to every WebSocket registered for *chat_id* and then to
        ``polling_manager``.

        Raises:
            IncorrectDataException: *message* is not a mapping or has no known ``"flag"``.

        Any other exception raised by a handler (validation, permission, service
        errors) propagates. The exceptions are ``MessageNotFoundException`` and
        ``MessageAlreadyPinnedException``, which are ignored without notifying anyone.
        """
        try:
            handler = self.message_methods[message["flag"]]
        except (KeyError, TypeError):
            # Missing/unknown flag, a non-dict payload, or an unhashable flag value.
            raise IncorrectDataException from None
        try:
            new_message = await handler(uow, user_id, chat_id, message)
        except (MessageNotFoundException, MessageAlreadyPinnedException):
            # Best effort: stale ids and repeated pins are silently dropped.
            return
        # Iterate over a snapshot: other coroutines may connect/disconnect while we await.
        for websocket in list(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(new_message)
            except Exception:
                # A peer that went away must neither stop delivery to the rest of the
                # chat nor fail the sender's own request; unregister it instead.
                logging.warning("Dropping unreachable WebSocket in chat %s", chat_id, exc_info=True)
                self._forget(chat_id, websocket)
        await polling_manager.send(chat_id, new_message)

    @staticmethod
    def _serialize(stored_message, flag: str) -> dict:
        """Dump a message model returned by ``message_service`` into a JSON-friendly dict.

        ``id`` and a truthy ``answer_id`` are stringified (a falsy ``answer_id``
        becomes ``None``), ``created_at`` is rendered with ``isoformat()``, and
        ``flag`` is added.
        """
        payload = stored_message.model_dump()
        payload["id"] = str(payload["id"])
        payload["answer_id"] = str(payload["answer_id"]) if payload["answer_id"] else None
        payload["created_at"] = payload["created_at"].isoformat()
        payload["flag"] = flag
        return payload

    @staticmethod
    async def _send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
        """Validate *message* as ``SSendMessage``, pass it to ``send_message`` and return its payload."""
        validated = SSendMessage.model_validate(message)
        stored = await message_service.send_message(
            uow=uow,
            user_id=user_id,
            chat_id=chat_id,
            message=validated,
        )
        return ConnectionManager._serialize(stored, "send")

    @staticmethod
    async def _delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Delete the message with the payload's ``id``; return ``{"id", "flag": "delete"}``.

        Raises:
            UserDontHavePermissionException: the payload's ``user_id`` differs from *user_id*.
        """
        message = SDeleteMessage.model_validate(message)
        # NOTE(review): ``message.user_id`` is client-supplied and the chat id is ignored.
        # Unless message_service.delete_message verifies the stored author/chat, a caller
        # can pass this check by sending their own id. Confirm.
        if message.user_id != user_id:
            raise UserDontHavePermissionException
        await message_service.delete_message(uow=uow, message_id=message.id)
        return {"id": str(message.id), "flag": "delete"}

    @staticmethod
    async def _edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Replace the text/image of the message with the payload's ``id``; return the edit event.

        Raises:
            UserDontHavePermissionException: the payload's ``user_id`` differs from *user_id*.
        """
        message = SEditMessage.model_validate(message)
        # NOTE(review): same client-supplied ``user_id`` caveat as in ``_delete``.
        if message.user_id != user_id:
            raise UserDontHavePermissionException
        await message_service.edit_message(
            uow=uow,
            message_id=message.id,
            new_message=message.new_message,
            new_image_url=message.new_image_url,
        )
        return {
            "flag": "edit",
            "id": str(message.id),
            "new_message": message.new_message,
            "new_image_url": message.new_image_url,
        }

    @staticmethod
    async def _pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Pin the message with the payload's ``id`` in *chat_id*; return the pinned message payload."""
        message = SPinMessage.model_validate(message)
        # NOTE(review): ``user_id`` is taken from the payload, not from the authenticated
        # caller (the second argument is ignored). Confirm that this is intended.
        pinned_message = await message_service.pin_message(
            uow=uow,
            chat_id=chat_id,
            user_id=message.user_id,
            message_id=message.id,
        )
        return ConnectionManager._serialize(pinned_message, "pin")

    @staticmethod
    async def _unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Unpin the message with the payload's ``id`` in *chat_id*; return ``{"flag": "unpin", "id"}``."""
        message = SUnpinMessage.model_validate(message)
        await message_service.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id)
        return {"flag": "unpin", "id": str(message.id)}


manager = ConnectionManager()
@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
    chat_id: int,
    websocket: WebSocket,
    user: SUser = Depends(get_current_user_ws),
    subprotocol: str = Depends(get_subprotocol_ws),
    uow=Depends(UnitOfWork),
):
    """Serve one chat over WebSocket: authorise the user, then relay every JSON event.

    Close codes:
        * client disconnect: the socket is just unregistered;
        * ``HTTPException``: 1011 if its status is 500, otherwise 1003, with
          ``detail`` as the reason;
        * any other exception: logged with traceback, then 1011
          "Internal Server Error".
    """
    registered = False

    async def close_with(code: int, reason: str) -> None:
        # Before ``manager.connect`` succeeded the socket is not registered with the
        # manager, so it is closed (rejected) directly instead.
        if registered:
            await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=(code, reason))
        else:
            await websocket.close(code=code, reason=reason)

    try:
        await auth_service.check_verificated_user(uow=uow, user_id=user.id)
        await auth_service.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
        await manager.connect(chat_id, websocket, subprotocol)
        registered = True
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        await manager.disconnect(chat_id, websocket)
    except HTTPException as e:
        # Presumably the auth/chat exceptions are HTTPException subclasses — 500 maps to 1011.
        code = status.WS_1011_INTERNAL_ERROR if e.status_code == 500 else status.WS_1003_UNSUPPORTED_DATA
        await close_with(code, e.detail)
    except Exception:
        logging.exception("Unexpected error in WebSocket of chat %s", chat_id)
        await close_with(status.WS_1011_INTERNAL_ERROR, "Internal Server Error")
@router.post(
    "/ws/{chat_id}",
    responses={
        status.HTTP_401_UNAUTHORIZED: {
            "model": Responses.STokenMissingException,
            "description": "Токен отсутствует"
        },
        status.HTTP_403_FORBIDDEN: {
            "model": Responses.SNotAuthenticated,
            "description": "Not authenticated"
        },
        status.HTTP_404_NOT_FOUND: {
            "model": Responses.SUserNotFoundException,
            "description": "Юзер не найден"
        },
        status.HTTP_409_CONFLICT: {
            "model": Responses.SUserMustConfirmEmailException,
            "description": "Сначала подтвердите почту"
        },
        status.HTTP_500_INTERNAL_SERVER_ERROR: {
            "model": Responses.SBlackPhoenixException,
            "description": "Внутренняя ошибка сервера"
        },
    },
)
async def chat_ws(
    chat_id: int,
    token: str = Depends(get_token),
):
    """HTTP stub that exists only to document the ``/ws/{chat_id}`` WebSocket route.

    The ``get_token`` dependency runs first (presumably rejecting requests
    without a token); after that the handler always refuses.

    Raises:
        UseWSException: always. Clients must open a WebSocket instead.
    """
    # The former client-side example code after this raise was unreachable and removed.
    raise UseWSException
class PollingManager:
    """Long-polling fallback: park HTTP requests until the next event of a chat."""

    def __init__(self):
        # chat_id -> futures of requests currently waiting for the next event.
        self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list)
        # Kept for backward compatibility only. Events are now handed to each waiter
        # through its future, so nothing is buffered here any more.
        self.messages: dict[int, list[dict]] = defaultdict(list)

    async def poll(self, chat_id: int) -> dict:
        """Wait for the next event published to *chat_id* and return it.

        Raises:
            HTTPException: 408 "Client disconnected" when the waiting task is cancelled.
        """
        future = asyncio.get_running_loop().create_future()
        self.waiters[chat_id].append(future)
        try:
            # ``send`` resolves the future with the event itself, so concurrent sends
            # cannot hand this poller a different event.
            return await future
        except asyncio.CancelledError:
            self._discard_waiter(chat_id, future)
            # NOTE(review): this converts cancellation into an HTTP error instead of
            # re-raising it; kept for compatibility.
            raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Client disconnected")

    def _discard_waiter(self, chat_id: int, future: asyncio.Future) -> None:
        """Remove *future* from *chat_id*'s waiters if ``send`` has not already taken it."""
        waiters = self.waiters.get(chat_id)
        if waiters is None:
            return
        if future in waiters:
            waiters.remove(future)
        if not waiters:
            del self.waiters[chat_id]

    async def prepare(
        self,
        uow: UnitOfWork,
        user_id: int,
        chat_id: int,
        message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    ) -> None:
        """Apply an event received over HTTP and publish it to the chat.

        ``manager.broadcast`` already delivers the processed event to WebSocket
        clients and to polling waiters (via ``send``). Pollers are therefore not
        notified a second time with the raw request body, and they are not notified
        at all when broadcast ignores the event.
        """
        await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=message.model_dump())

    async def send(
        self,
        chat_id: int,
        message: dict,
    ) -> None:
        """Resolve every request currently polling *chat_id* with *message*."""
        # Take all current waiters at once; requests arriving later start a new list.
        for waiter in self.waiters.pop(chat_id, []):
            if not waiter.done():
                waiter.set_result(message)
        # Yield once so the woken pollers get to run before the caller continues.
        await asyncio.sleep(0)


polling_manager = PollingManager()
@router.get(
    "/poll/{chat_id}",
    status_code=status.HTTP_200_OK,
    response_model=None,
    responses={
        status.HTTP_401_UNAUTHORIZED: {
            "model": Responses.STokenMissingException,
            "description": "Токен отсутствует",
        },
        status.HTTP_403_FORBIDDEN: {
            "model": Responses.SNotAuthenticated,
            "description": "Not authenticated",
        },
        status.HTTP_404_NOT_FOUND: {
            "model": Responses.SUserNotFoundException,
            "description": "Юзер не найден",
        },
        status.HTTP_409_CONFLICT: {
            "model": Responses.SUserMustConfirmEmailException,
            "description": "Сначала подтвердите почту",
        },
        status.HTTP_500_INTERNAL_SERVER_ERROR: {
            "model": Responses.SBlackPhoenixException,
            "description": "Внутренняя ошибка сервера",
        },
    },
)
async def poll(
    chat_id: int,
    user: SUser = Depends(get_verificated_user),
    uow=Depends(UnitOfWork),
):
    """Long-poll *chat_id*: after the access check, block until the next chat event.

    The response body is whatever ``polling_manager.poll`` returns for this chat.
    """
    await auth_service.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    next_event = await polling_manager.poll(chat_id)
    return next_event
@router.post(
    "/send/{chat_id}",
    status_code=status.HTTP_201_CREATED,
    response_model=None,
    responses={
        status.HTTP_401_UNAUTHORIZED: {
            "model": Responses.STokenMissingException,
            "description": "Токен отсутствует",
        },
        status.HTTP_403_FORBIDDEN: {
            "model": Responses.SNotAuthenticated,
            "description": "Not authenticated",
        },
        status.HTTP_404_NOT_FOUND: {
            "model": Responses.SUserNotFoundException,
            "description": "Юзер не найден",
        },
        status.HTTP_409_CONFLICT: {
            "model": Responses.SUserMustConfirmEmailException,
            "description": "Сначала подтвердите почту",
        },
        status.HTTP_500_INTERNAL_SERVER_ERROR: {
            "model": Responses.SBlackPhoenixException,
            "description": "Внутренняя ошибка сервера",
        },
    },
)
async def send(
    chat_id: int,
    message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    user: SUser = Depends(get_verificated_user),
    uow=Depends(UnitOfWork),
):
    """Submit a chat event over plain HTTP (the counterpart of the WebSocket route).

    After the chat-access check, the event is handed to ``polling_manager.prepare``.
    The endpoint returns no body (status 201).
    """
    caller_id = user.id
    await auth_service.validate_user_access_to_chat(uow=uow, user_id=caller_id, chat_id=chat_id)
    await polling_manager.prepare(uow=uow, user_id=caller_id, chat_id=chat_id, message=message)