diff --git a/chat_test/app/users/chat/dao.py b/chat_test/app/users/chat/dao.py index b713dd1..f271f3b 100644 --- a/chat_test/app/users/chat/dao.py +++ b/chat_test/app/users/chat/dao.py @@ -28,21 +28,24 @@ class ChatDAO(BaseDAO): @classmethod async def send_message(cls, user_id, chat_id: int, message: str): - query = insert(Messages).values(chat_id=chat_id, user_id=user_id, message=message, image_url=None) + query = (insert(Messages).values(chat_id=chat_id, user_id=user_id, message=message, image_url=None) + .returning(Messages.id)) async with async_session_maker() as session: - await session.execute(query) + result = await session.execute(query) await session.commit() - return True + return result.scalar() @classmethod async def get_message_by_id(cls, message_id: int): - query = select(Messages.__table__.columns).where( + query = (select(Messages.message, Messages.image_url, Messages.chat_id, Messages.user_id, + Messages.created_at, Users.avatar_image, Users.username).select_from(Messages) + .join(Users, Users.id == Messages.user_id) + .where( and_( Messages.id == message_id, Messages.visibility == True ) - - ) + )) async with async_session_maker() as session: result = await session.execute(query) result = result.mappings().all() @@ -97,4 +100,3 @@ class ChatDAO(BaseDAO): result = result.mappings().all() if result: return result - diff --git a/chat_test/app/users/chat/websocket.py b/chat_test/app/users/chat/websocket.py index 09f5f03..a487176 100644 --- a/chat_test/app/users/chat/websocket.py +++ b/chat_test/app/users/chat/websocket.py @@ -1,3 +1,4 @@ +import json from typing import Dict, List from fastapi import WebSocket, Depends, WebSocketDisconnect @@ -5,7 +6,7 @@ from pydantic import parse_obj_as from app.users.chat.dao import ChatDAO from app.users.auth import validate_user_access_to_chat -from app.users.chat.shemas import SMessageSchema +from app.users.chat.shemas import SMessageSchema, SMessage from app.users.models import Users from app.users.chat.router import router @@ -24,14 +25,17 @@ class ConnectionManager(WebSocket): self.active_connections[chat_id].remove(websocket) async def broadcast(self, user_id: int, chat_id: int, message: str): - await self.add_message_to_database(user_id=user_id, chat_id=chat_id, message=message) + new_message = await self.add_message_to_database(user_id=user_id, chat_id=chat_id, message=message) + new_message = dict(new_message) + new_message['created_at'] = new_message['created_at'].isoformat() for websocket in self.active_connections[chat_id]: - await websocket.send_json({'message': message}) + await websocket.send_json(new_message) @staticmethod async def add_message_to_database(user_id: int, chat_id: int, message: str): result = await ChatDAO.send_message(user_id=user_id, chat_id=chat_id, message=message) - return result + new_message = await ChatDAO.get_message_by_id(message_id=result) + return new_message manager = ConnectionManager()