chat_back/chat_test/app/users/chat/websocket.py

56 lines
2 KiB
Python

import json
from typing import Dict, List
from fastapi import WebSocket, Depends, WebSocketDisconnect
from pydantic import parse_obj_as
from app.users.chat.dao import ChatDAO
from app.users.auth import validate_user_access_to_chat
from app.users.chat.shemas import SMessageSchema, SMessage
from app.users.models import Users
from app.users.chat.router import router
class ConnectionManager:
    """Track open chat WebSockets per chat and fan new messages out to them.

    Connections live in an in-memory ``dict`` mapping ``chat_id`` to the list of
    sockets currently connected to that chat, so the registry is only visible to
    code running in the same process as this object.
    """

    def __init__(self):
        # chat_id -> sockets currently connected to that chat.
        self.active_connections: dict[int, list[WebSocket]] = {}

    async def connect(self, chat_id: int, websocket: WebSocket):
        """Accept the handshake for *websocket* and register it under *chat_id*."""
        await websocket.accept()
        self.active_connections.setdefault(chat_id, []).append(websocket)

    def disconnect(self, chat_id: int, websocket: WebSocket):
        """Unregister *websocket* from *chat_id*.

        Tolerates a chat or socket that is not (or no longer) registered, e.g. a
        socket already dropped by :meth:`broadcast` after a failed send. The
        chat's entry is deleted once its last socket leaves so the registry does
        not accumulate empty lists.
        """
        connections = self.active_connections.get(chat_id)
        if connections is None:
            return
        # Match by identity: Starlette's WebSocket is a Mapping over its ASGI
        # scope, so ``==`` (used by list.remove) compares scope contents rather
        # than the connection object itself.
        remaining = [ws for ws in connections if ws is not websocket]
        if remaining:
            self.active_connections[chat_id] = remaining
        else:
            del self.active_connections[chat_id]

    async def broadcast(self, user_id: int, chat_id: int, message: str):
        """Store *message* from *user_id* and send it as JSON to every socket in *chat_id*.

        The stored message is re-read through :meth:`add_message_to_database`,
        converted with ``dict()`` and its ``created_at`` replaced by its
        ``isoformat()`` string so it is JSON-serialisable.
        NOTE(review): assumes the DAO row supports ``dict()`` and that
        ``created_at`` is a date/datetime — confirm against ChatDAO.

        A peer whose send fails with ``WebSocketDisconnect`` or ``RuntimeError``
        is unregistered instead of aborting delivery to the rest of the chat.
        """
        new_message = await self.add_message_to_database(user_id=user_id, chat_id=chat_id, message=message)
        new_message = dict(new_message)
        new_message['created_at'] = new_message['created_at'].isoformat()
        # Iterate over a snapshot: every ``await`` yields to the event loop, and
        # other handlers may connect/disconnect (mutating the registry) meanwhile.
        for websocket in list(self.active_connections.get(chat_id, ())):
            try:
                await websocket.send_json(new_message)
            except (WebSocketDisconnect, RuntimeError):
                # Peer has gone away (Starlette raises RuntimeError when sending
                # on a closed socket); drop it so later broadcasts skip it.
                self.disconnect(chat_id, websocket)

    @staticmethod
    async def add_message_to_database(user_id: int, chat_id: int, message: str):
        """Persist *message* via ``ChatDAO.send_message`` and return the stored row.

        ``send_message``'s result is passed to ``ChatDAO.get_message_by_id`` as
        the message id (presumably the new row's primary key — confirm).
        """
        result = await ChatDAO.send_message(user_id=user_id, chat_id=chat_id, message=message)
        new_message = await ChatDAO.get_message_by_id(message_id=result)
        return new_message
# Module-level ConnectionManager used by websocket_endpoint below to register
# sockets and broadcast messages.
# NOTE(review): its registry is an in-memory dict, so it is per-process; if the
# app runs with several workers, clients on different workers will not receive
# each other's messages — confirm deployment model.
manager = ConnectionManager()
@router.websocket("/ws/{chat_id}")
async def websocket_endpoint(chat_id: int, user_id: int, websocket: WebSocket):
    """Serve one client's WebSocket connection to chat *chat_id*.

    ``validate_user_access_to_chat`` runs before the handshake is accepted. Each
    text frame the client then sends is passed to ``manager.broadcast``, which
    stores it and sends it to every socket registered for the chat.

    NOTE(review): ``user_id`` is not a path parameter, so FastAPI reads it from
    the query string — i.e. the client chooses which user it speaks as.
    Presumably it should come from an authenticated token/session instead;
    confirm before relying on this for access control.
    """
    await validate_user_access_to_chat(user_id=user_id, chat_id=chat_id)
    await manager.connect(chat_id, websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(user_id=user_id, chat_id=chat_id, message=data)
    except WebSocketDisconnect:
        # Normal client disconnect; unregistration happens in ``finally``.
        pass
    finally:
        # Unregister on *every* exit path, including errors raised by
        # broadcast; otherwise the dead socket stays registered and later
        # broadcasts to this chat keep trying to send to it.
        manager.disconnect(chat_id, websocket)