227 lines
7.4 KiB
Python
227 lines
7.4 KiB
Python
import asyncio
|
|
from collections import defaultdict
|
|
|
|
import websockets
|
|
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException, status
|
|
|
|
from app.chat.exceptions import UseWSException
|
|
from app.exceptions import IncorrectDataException
|
|
from app.chat.exceptions import UserDontHavePermissionException
|
|
from app.services.message_service import MessageService
|
|
from app.unit_of_work import UnitOfWork
|
|
from app.utils.auth import AuthService
|
|
from app.chat.router import router
|
|
from app.chat.shemas import SSendMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage
|
|
from app.users.dependencies import get_current_user_ws, get_token, get_subprotocol_ws, get_verificated_user
|
|
from app.users.schemas import SUser
|
|
|
|
|
|
class ConnectionManager:
    """In-memory registry of chat WebSockets plus the dispatcher for chat events.

    ``broadcast`` routes an incoming event to the handler named by its
    ``"flag"`` field, then fans the handler's result out to every WebSocket
    registered for the chat and to ``polling_manager``.
    """

    def __init__(self):
        # chat_id -> WebSockets currently connected to that chat, in connect order.
        self.active_connections: dict[int, list[WebSocket]] = {}
        # Value of an event's "flag" field -> coroutine that applies the event and
        # returns the JSON-ready payload to fan out.
        self.message_methods = {
            "send": self.send,
            "delete": self.delete,
            "edit": self.edit,
            "pin": self.pin,
            "unpin": self.unpin,
        }

    async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None):
        """Accept ``websocket`` (with ``subprotocol`` if given) and register it under ``chat_id``."""
        await websocket.accept(subprotocol=subprotocol)
        self.active_connections.setdefault(chat_id, []).append(websocket)

    async def disconnect(self, chat_id: int, websocket: WebSocket, code_and_reason: tuple[int, str] | None = None):
        """Unregister ``websocket`` from ``chat_id`` and optionally close it.

        Safe to call for a socket that is not (or no longer) registered. The
        chat's entry is removed once its last connection leaves, so the registry
        does not keep an empty list for every chat ever opened.

        Args:
            chat_id: Chat the socket was registered under.
            websocket: The connection to drop.
            code_and_reason: ``(close_code, reason)``. When given, the socket is
                closed with these values after it is unregistered.
        """
        connections = self.active_connections.get(chat_id)
        if connections is not None:
            if websocket in connections:
                connections.remove(websocket)
            if not connections:
                del self.active_connections[chat_id]
        if code_and_reason:
            code, reason = code_and_reason
            await websocket.close(code=code, reason=reason)

    async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict):
        """Apply a chat event and fan the result out to the chat's listeners.

        Args:
            uow: Unit of work handed to the event handler.
            user_id: Id of the authenticated user who sent the event.
            chat_id: Chat the event belongs to.
            message: Raw event payload; its ``"flag"`` selects the handler.

        Raises:
            IncorrectDataException: if ``message`` is not a dict or its ``"flag"``
                is missing or unknown. Exceptions raised by the handler itself
                propagate unchanged.
        """
        flag = message.get("flag") if isinstance(message, dict) else None
        handler = self.message_methods.get(flag) if isinstance(flag, str) else None
        if handler is None:
            raise IncorrectDataException

        new_message = await handler(uow, user_id, chat_id, message)

        # Iterate over a snapshot: disconnect() may mutate the live list while we
        # are suspended in send_json. A chat with no open WebSockets (e.g. when
        # the event arrived over HTTP) simply has no recipients here.
        for websocket in tuple(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(new_message)
            except Exception:
                # One unreachable peer must not abort delivery to the others or
                # tear down the sender's connection. The peer's own receive loop
                # is expected to see the disconnect and unregister it.
                continue

        await polling_manager.send(chat_id, new_message)

    @staticmethod
    def _to_event(model, flag: str) -> dict:
        """Dump ``model`` to a dict whose ``created_at`` went through ``isoformat()``, tagged with ``flag``."""
        event = model.model_dump()
        event["created_at"] = event["created_at"].isoformat()
        event["flag"] = flag
        return event

    @staticmethod
    async def send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
        """Persist a new message from ``user_id`` in ``chat_id`` and return it as a ``"send"`` event."""
        message = SSendMessage.model_validate(message)

        new_message = await MessageService.send_message(
            uow=uow, user_id=user_id, chat_id=chat_id, message=message, image_url=message.image_url
        )
        return ConnectionManager._to_event(new_message, "send")

    @staticmethod
    async def delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Delete a message and return a ``"delete"`` event carrying its id.

        Raises:
            UserDontHavePermissionException: if the payload's ``user_id`` differs
                from the caller's ``user_id``.
        """
        message = SDeleteMessage.model_validate(message)

        # NOTE(review): this compares the caller with the ``user_id`` the client
        # put in the payload. Nothing visible here checks that the stored message
        # belongs to that user (delete_message only receives the id). Confirm
        # MessageService enforces ownership; otherwise a client could delete
        # another user's message by sending its own user_id.
        if message.user_id != user_id:
            raise UserDontHavePermissionException

        await MessageService.delete_message(uow=uow, message_id=message.id)
        return {"id": message.id, "flag": "delete"}

    @staticmethod
    async def edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
        """Edit a message's text/image and return an ``"edit"`` event with the new values.

        Raises:
            UserDontHavePermissionException: if the payload's ``user_id`` differs
                from the caller's ``user_id``.
        """
        message = SEditMessage.model_validate(message)

        # NOTE(review): same caveat as delete(). The ownership check trusts the
        # client-supplied ``user_id``; confirm edit_message verifies the author.
        if message.user_id != user_id:
            raise UserDontHavePermissionException

        await MessageService.edit_message(
            uow=uow, message_id=message.id, new_message=message.new_message, new_image_url=message.new_image_url
        )
        return {
            "flag": "edit",
            "id": message.id,
            "new_message": message.new_message,
            "new_image_url": message.new_image_url,
        }

    @staticmethod
    async def pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Pin a message in ``chat_id`` and return the pinned message as a ``"pin"`` event."""
        message = SPinMessage.model_validate(message)
        # NOTE(review): ``user_id`` is taken from the client payload, not from the
        # authenticated caller (the ignored ``_`` argument). Confirm that is intended.
        pinned_message = await MessageService.pin_message(
            uow=uow, chat_id=chat_id, user_id=message.user_id, message_id=message.id
        )
        return ConnectionManager._to_event(pinned_message, "pin")

    @staticmethod
    async def unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
        """Unpin a message in ``chat_id`` and return an ``"unpin"`` event carrying its id."""
        message = SUnpinMessage.model_validate(message)
        await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id)
        return {"flag": "unpin", "id": message.id}
|
|
|
|
|
|
# Module-level connection registry shared by websocket_endpoint and
# PollingManager.prepare. Its state lives in this process's memory only.
manager: ConnectionManager = ConnectionManager()
|
|
|
|
|
|
# RFC 6455 caps a close frame's payload at 125 bytes, 2 of which hold the status code.
_WS_CLOSE_REASON_MAX_BYTES = 123


def _ws_close_reason(detail) -> str:
    """Render an ``HTTPException.detail`` as a WebSocket close reason.

    ``detail`` is not guaranteed to be a string and may be longer than a close
    frame allows. It is stringified and cut to at most 123 UTF-8 bytes without
    splitting a multi-byte character.
    """
    text = detail if isinstance(detail, str) else str(detail)
    return text.encode("utf-8")[:_WS_CLOSE_REASON_MAX_BYTES].decode("utf-8", errors="ignore")


@router.websocket(
    "/ws/{chat_id}",
)
async def websocket_endpoint(
    chat_id: int,
    websocket: WebSocket,
    user: SUser = Depends(get_current_user_ws),
    subprotocol: str = Depends(get_subprotocol_ws),
    uow=Depends(UnitOfWork),
):
    """Serve one chat WebSocket connection.

    Runs the AuthService verification and chat-access checks, accepts the
    socket, then hands every JSON frame the client sends to ``manager.broadcast``
    until the connection ends.

    Close behaviour:
      * client disconnect: the socket is unregistered, nothing is sent;
      * ``HTTPException`` with a 5xx status: closed with 1011 (internal error);
      * any other ``HTTPException``: closed with 1003 (unsupported data);
      * any other exception: logged and closed with 1011.
    """
    await AuthService.check_verificated_user(uow=uow, user_id=user.id)
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
    await manager.connect(chat_id, websocket, subprotocol)
    try:
        while True:
            data = await websocket.receive_json()
            await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        await manager.disconnect(chat_id, websocket)
    except HTTPException as e:
        # Server-side failures (any 5xx, not only 500) are internal errors;
        # everything else means the client sent something we reject.
        if e.status_code >= 500:
            code = status.WS_1011_INTERNAL_ERROR
        else:
            code = status.WS_1003_UNSUPPORTED_DATA
        code_and_reason = (code, _ws_close_reason(e.detail))
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=code_and_reason)
    except Exception:
        import logging

        # Keep the traceback: the client only ever sees a generic 1011 close.
        logging.getLogger(__name__).exception("Unhandled error in chat %s websocket", chat_id)
        code_and_reason = (status.WS_1011_INTERNAL_ERROR, "Internal Server Error")
        await manager.disconnect(chat_id=chat_id, websocket=websocket, code_and_reason=code_and_reason)
|
|
|
|
|
|
@router.post(
    "/ws/{chat_id}",
)
async def chat_ws(
    chat_id: int,
    token: str = Depends(get_token),
):
    """HTTP POST counterpart of the chat WebSocket route at the same path.

    It never serves data. ``token`` is resolved through ``get_token`` first,
    then the handler always raises ``UseWSException`` to tell the caller to
    open a WebSocket to ``/ws/{chat_id}`` instead.

    Raises:
        UseWSException: always.
    """
    raise UseWSException
|
|
|
|
|
|
class PollingManager:
    """Long-polling fan-out: parks HTTP pollers per chat and wakes them with the next event."""

    def __init__(self):
        # chat_id -> futures of pollers currently waiting for that chat's next event.
        self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list)

    async def poll(self, chat_id: int):
        """Wait for the next event published to ``chat_id`` and return it.

        Raises:
            HTTPException: 408 if the waiting task is cancelled.
        """
        future = asyncio.get_running_loop().create_future()
        self.waiters[chat_id].append(future)
        try:
            # The event travels as the future's result, so each poller gets exactly
            # the event it was woken for even when several sends interleave.
            return await future
        except asyncio.CancelledError:
            self._discard_waiter(chat_id, future)
            raise HTTPException(status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Client disconnected")

    def _discard_waiter(self, chat_id: int, future: asyncio.Future) -> None:
        """Remove ``future`` from ``chat_id``'s waiters if it is still queued.

        ``send`` may already have popped it, for example when cancellation lands
        after the result was set, so its absence is not an error.
        """
        waiters = self.waiters.get(chat_id)
        if waiters is None:
            return
        if future in waiters:
            waiters.remove(future)
        if not waiters:
            del self.waiters[chat_id]

    async def prepare(
        self,
        uow: UnitOfWork,
        user_id: int,
        chat_id: int,
        message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    ):
        """Apply an event received over HTTP exactly as if it came from a WebSocket.

        The event is dumped to a dict and passed to ``manager.broadcast``. That
        call runs the matching ConnectionManager handler, pushes the result to
        the chat's WebSockets and wakes pollers via ``polling_manager.send``.
        """
        payload = message.model_dump()
        # No separate self.send() here: broadcast() already publishes the
        # processed event to the module-level polling_manager. A second publish
        # would hand late pollers the raw, unprocessed request payload.
        await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=payload)

    async def send(
        self,
        chat_id: int,
        message: dict,
    ):
        """Wake every poller currently waiting on ``chat_id`` with ``message``."""
        # Detach the whole waiter list at once. Pollers that arrive afterwards
        # wait for the next event instead of seeing this one.
        for waiter in self.waiters.pop(chat_id, []):
            if not waiter.done():
                waiter.set_result(message)
|
|
|
|
|
|
# Module-level long-poll hub. ConnectionManager.broadcast publishes to it, and
# the /poll and /send routes use it.
polling_manager: PollingManager = PollingManager()
|
|
|
|
|
|
@router.get("/poll/{chat_id}", status_code=status.HTTP_200_OK, response_model=None)
async def poll(
    chat_id: int,
    user: SUser = Depends(get_verificated_user),
    uow=Depends(UnitOfWork),
):
    """Long-poll for the next event in ``chat_id``.

    ``user`` is resolved by ``get_verificated_user``. After the chat-access
    check, the request stays open until ``polling_manager`` publishes an event
    for this chat, and that event is returned as the response body.
    """
    await AuthService.validate_user_access_to_chat(
        uow=uow,
        chat_id=chat_id,
        user_id=user.id,
    )
    next_event = await polling_manager.poll(chat_id)
    return next_event
|
|
|
|
|
|
@router.post("/send/{chat_id}", status_code=status.HTTP_201_CREATED, response_model=None)
async def send(
    chat_id: int,
    message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
    user: SUser = Depends(get_verificated_user),
    uow=Depends(UnitOfWork),
):
    """Submit a chat event over plain HTTP.

    After the chat-access check, the event is handed to
    ``polling_manager.prepare``, which processes it like a WebSocket frame. The
    response has status 201 and no body.
    """
    sender_id = user.id
    await AuthService.validate_user_access_to_chat(uow=uow, user_id=sender_id, chat_id=chat_id)
    await polling_manager.prepare(
        uow=uow,
        user_id=sender_id,
        chat_id=chat_id,
        message=message,
    )
|