From 8c7a502dc5081c1863e098bd1f2acdb5e08f8a53 Mon Sep 17 00:00:00 2001 From: urec56 Date: Sat, 8 Jun 2024 18:36:54 +0500 Subject: [PATCH] =?UTF-8?q?=D0=98=D0=B7=D0=BC=D0=B5=D0=BD=D0=B5=D0=BD?= =?UTF-8?q?=D0=B8=D1=8F=20DAO?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/chat/exceptions.py | 5 + app/chat/router.py | 9 +- app/dao/chat.py | 126 +++++++++++++----- .../versions/2024-06-08_database_creation.py | 11 +- app/models/pinned_message.py | 5 +- 5 files changed, 109 insertions(+), 47 deletions(-) diff --git a/app/chat/exceptions.py b/app/chat/exceptions.py index 94589c6..e4f615f 100644 --- a/app/chat/exceptions.py +++ b/app/chat/exceptions.py @@ -31,3 +31,8 @@ class UserAlreadyInChatException(BlackPhoenixException): class UserAlreadyPinnedChatException(BlackPhoenixException): status_code = status.HTTP_409_CONFLICT detail = "Юзер уже закрепил чат" + + +class MessageAlreadyPinnedException(BlackPhoenixException): + status_code = status.HTTP_409_CONFLICT + detail = "Сообщение уже закрепили" diff --git a/app/chat/router.py b/app/chat/router.py index fada1d3..ac3fd80 100644 --- a/app/chat/router.py +++ b/app/chat/router.py @@ -47,7 +47,7 @@ async def create_chat( if user.id == user_to_exclude: raise UserCanNotReadThisChatException async with uow: - chat_id = await uow.chat.create(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id) + chat_id = await uow.chat.create_chat(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id) await uow.chat.add_user_to_chat(user.id, chat_id) await uow.chat.add_user_to_chat(settings.ADMIN_USER_ID, chat_id) await uow.commit() @@ -217,8 +217,5 @@ async def get_pinned_chats(user: SUser = Depends(check_verificated_user_with_exc async def pinned_messages(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id) async with uow: - messages = await uow.chat.get_pinned_messages(chat_id=chat_id) - if messages: - for mes in messages: - mes["created_at"] = mes["created_at"].isoformat() - return {"pinned_messages": messages} + pinned_messages = await uow.chat.get_pinned_messages(chat_id=chat_id) + return pinned_messages diff --git a/app/dao/chat.py b/app/dao/chat.py index ce02365..dd5fcc6 100644 --- a/app/dao/chat.py +++ b/app/dao/chat.py @@ -1,13 +1,16 @@ -import logging - from sqlalchemy import insert, select, update, delete, func from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.orm import aliased from app.dao.base import BaseDAO from app.database import engine # noqa -from app.chat.exceptions import UserAlreadyInChatException, UserAlreadyPinnedChatException, MessageNotFoundException -from app.chat.shemas import SMessage, SMessageList +from app.chat.exceptions import ( + UserAlreadyInChatException, + UserAlreadyPinnedChatException, + MessageNotFoundException, + MessageAlreadyPinnedException, +) +from app.chat.shemas import SMessage, SMessageList, SPinnedMessages from app.models.users import Users from app.models.message_answer import MessageAnswer from app.models.chat import Chats @@ -26,12 +29,12 @@ class ChatDAO(BaseDAO): result = result.mappings().one_or_none() return result - async def create(self, user_id: int, chat_name: str, created_by: int) -> int: + async def create_chat(self, user_id: int, chat_name: str, created_by: int) -> int: stmt = insert(Chats).values(chat_for=user_id, chat_name=chat_name, created_by=created_by).returning(Chats.id) result = await self.session.execute(stmt) - result = result.scalar() - return result + chat_id = result.scalar() + return chat_id async def add_user_to_chat(self, user_id: int, chat_id: int) -> None: try: @@ -71,14 +74,14 @@ class ChatDAO(BaseDAO): ) .select_from(MessageAnswer) .join(msg, msg.id == MessageAnswer.answer_id, isouter=True) - .where(MessageAnswer.self_id == Message.id, msg.visibility == True) + .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712 .scalar_subquery() ) ) .select_from(Message) .join(Users, Users.id == Message.user_id) .join(MessageAnswer, Message.id == MessageAnswer.self_id, isouter=True) - .where(Message.id == message_id, Message.visibility == True) + .where(Message.id == message_id, Message.visibility == True) # noqa: E712 ) result = await self.session.execute(query) @@ -88,8 +91,8 @@ class ChatDAO(BaseDAO): raise MessageNotFoundException async def delete_message(self, message_id: int) -> None: - query = update(Message).where(Message.id == message_id).values(visibility=False) - await self.session.execute(query) + stmt = update(Message).where(Message.id == message_id).values(visibility=False) + await self.session.execute(stmt) async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageList: @@ -112,21 +115,21 @@ class ChatDAO(BaseDAO): select(msg.message) .select_from(MessageAnswer) .join(msg, msg.id == MessageAnswer.answer_id, isouter=True) - .where(MessageAnswer.self_id == Message.id, msg.visibility == True) + .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712 .scalar_subquery() ).label("answer_message"), ) .select_from(Message) .join(Users, Message.user_id == Users.id) - .outerjoin(MessageAnswer, MessageAnswer.self_id == Message.id) - .where(Message.chat_id == chat_id, Message.visibility == True) + .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True) + .where(Message.chat_id == chat_id, Message.visibility == True) # noqa: E712 .order_by(Message.created_at.desc()) .limit(messages_to_get) .offset(message_number_from) .cte("messages_with_users") ) - message_query = ( + query = ( select( func.json_build_object( "messages", @@ -148,13 +151,13 @@ class ChatDAO(BaseDAO): ).select_from(messages_with_users) ) - result = await self.session.execute(message_query) + result = await self.session.execute(query) result = result.scalar() return SMessageList.model_validate(result) async def edit_message(self, message_id: int, new_message: str, new_image_url: str) -> None: - query = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url) - await self.session.execute(query) + stmt = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url) + await self.session.execute(stmt) async def add_answer(self, self_id: int, answer_id: int) -> None: stmt = ( @@ -164,12 +167,12 @@ class ChatDAO(BaseDAO): await self.session.execute(stmt) async def delete_chat(self, chat_id: int) -> None: - query = update(Chats).where(Chats.id == chat_id).values(visibility=False) - await self.session.execute(query) + stmt = update(Chats).where(Chats.id == chat_id).values(visibility=False) + await self.session.execute(stmt) async def delete_user_from_chat(self, chat_id: int, user_id: int) -> None: - query = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id) - await self.session.execute(query) + stmt = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id) + await self.session.execute(stmt) async def pin_chat(self, chat_id: int, user_id: int) -> None: try: @@ -179,8 +182,8 @@ class ChatDAO(BaseDAO): raise UserAlreadyPinnedChatException async def unpin_chat(self, chat_id: int, user_id: int) -> None: - query = delete(PinnedChat).where(PinnedChat.chat_id == chat_id, PinnedChat.user_id == user_id) - await self.session.execute(query) + stmt = delete(PinnedChat).where(PinnedChat.chat_id == chat_id, PinnedChat.user_id == user_id) + await self.session.execute(stmt) async def get_pinned_chats(self, user_id: int): chats_with_descriptions = ( @@ -213,7 +216,7 @@ class ChatDAO(BaseDAO): ) .distinct() .select_from(PinnedChat) - .join(chats_with_avatars, PinnedChat.chat_id == chats_with_avatars.c.chat_id) + .join(chats_with_avatars, PinnedChat.chat_id == chats_with_avatars.c.chat_id) # noqa .where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712 ) # print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса @@ -222,14 +225,74 @@ class ChatDAO(BaseDAO): return result async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> None: - stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id) - await self.session.execute(stmt) + try: + stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id) + await self.session.execute(stmt) + except IntegrityError: + raise MessageAlreadyPinnedException async def unpin_message(self, chat_id: int, message_id: int) -> None: stmt = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id) await self.session.execute(stmt) - async def get_pinned_messages(self, chat_id: int) -> list[dict]: + async def get_pinned_messages(self, chat_id: int) -> SPinnedMessages: + + msg = aliased(Message, name="msg") + + messages_with_users = ( + select( + Message.id, + Message.message, + Message.image_url, + Message.chat_id, + Message.user_id, + Message.created_at, + Message.visibility, + Users.username, + Users.avatar_image, + MessageAnswer.self_id, + MessageAnswer.answer_id, + ( + select(msg.message) + .select_from(MessageAnswer) + .join(msg, msg.id == MessageAnswer.answer_id, isouter=True) + .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712 + .scalar_subquery() + ).label("answer_message"), + ) + .select_from(PinnedMessage) + .join(Message, PinnedMessage.message_id == Message.id) + .join(Users, Message.user_id == Users.id) + .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True) + .where(PinnedMessage.chat_id == chat_id, Message.visibility == True) # noqa: E712 + .order_by(Message.created_at.desc()) + .cte("messages_with_users") + ) + + query = ( + select( + func.json_build_object( + "pinned_messages", + func.json_agg( + func.json_build_object( + "id", messages_with_users.c.id, + "message", messages_with_users.c.message, + "image_url", messages_with_users.c.image_url, + "chat_id", messages_with_users.c.chat_id, + "user_id", messages_with_users.c.user_id, + "created_at", messages_with_users.c.created_at, + "avatar_image", messages_with_users.c.avatar_image, + "username", messages_with_users.c.username, + "answer_id", messages_with_users.c.answer_id, + "answer_message", messages_with_users.c.answer_message, + ) + ) + ) + ).select_from(messages_with_users) + ) + + print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса + query = ( select( Message.id, @@ -251,7 +314,6 @@ class ChatDAO(BaseDAO): ) result = await self.session.execute(query) - result = result.mappings().all() - if result: - result = [dict(res) for res in result] - return result + result = result.scalar_one() + + return SPinnedMessages.model_validate(result) diff --git a/app/migrations/versions/2024-06-08_database_creation.py b/app/migrations/versions/2024-06-08_database_creation.py index bb45b7b..d93391c 100644 --- a/app/migrations/versions/2024-06-08_database_creation.py +++ b/app/migrations/versions/2024-06-08_database_creation.py @@ -1,8 +1,8 @@ """Database Creation -Revision ID: 1cc709a9c827 +Revision ID: 4668313943c0 Revises: -Create Date: 2024-06-08 17:51:30.648467 +Create Date: 2024-06-08 18:36:31.974804 """ from typing import Sequence, Union @@ -12,7 +12,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = '1cc709a9c827' +revision: str = '4668313943c0' down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -85,14 +85,13 @@ def upgrade() -> None: sa.PrimaryKeyConstraint('self_id') ) op.create_table('pinned_message', - sa.Column('id', sa.Integer(), nullable=False), sa.Column('chat_id', sa.Integer(), nullable=True), - sa.Column('message_id', sa.Integer(), nullable=True), + sa.Column('message_id', sa.Integer(), nullable=False), sa.Column('user_id', sa.Integer(), nullable=True), sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ), sa.ForeignKeyConstraint(['message_id'], ['message.id'], ), sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + sa.PrimaryKeyConstraint('message_id') ) # ### end Alembic commands ### diff --git a/app/models/pinned_message.py b/app/models/pinned_message.py index ae1d55b..d6889da 100644 --- a/app/models/pinned_message.py +++ b/app/models/pinned_message.py @@ -1,5 +1,5 @@ from sqlalchemy import ForeignKey -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import mapped_column from app.database import Base @@ -7,7 +7,6 @@ from app.database import Base class PinnedMessage(Base): __tablename__ = "pinned_message" - id: Mapped[int] = mapped_column(primary_key=True) chat_id = mapped_column(ForeignKey("chat.id")) - message_id = mapped_column(ForeignKey("message.id")) + message_id = mapped_column(ForeignKey("message.id"), primary_key=True) user_id = mapped_column(ForeignKey("users.id"))