Изменения DAO

This commit is contained in:
urec56 2024-06-08 18:36:54 +05:00
parent 9aa9da1af7
commit 8c7a502dc5
5 changed files with 109 additions and 47 deletions

View file

@ -31,3 +31,8 @@ class UserAlreadyInChatException(BlackPhoenixException):
class UserAlreadyPinnedChatException(BlackPhoenixException): class UserAlreadyPinnedChatException(BlackPhoenixException):
status_code = status.HTTP_409_CONFLICT status_code = status.HTTP_409_CONFLICT
detail = "Юзер уже закрепил чат" detail = "Юзер уже закрепил чат"
class MessageAlreadyPinnedException(BlackPhoenixException):
status_code = status.HTTP_409_CONFLICT
detail = "Сообщение уже закрепили"

View file

@ -47,7 +47,7 @@ async def create_chat(
if user.id == user_to_exclude: if user.id == user_to_exclude:
raise UserCanNotReadThisChatException raise UserCanNotReadThisChatException
async with uow: async with uow:
chat_id = await uow.chat.create(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id) chat_id = await uow.chat.create_chat(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id)
await uow.chat.add_user_to_chat(user.id, chat_id) await uow.chat.add_user_to_chat(user.id, chat_id)
await uow.chat.add_user_to_chat(settings.ADMIN_USER_ID, chat_id) await uow.chat.add_user_to_chat(settings.ADMIN_USER_ID, chat_id)
await uow.commit() await uow.commit()
@ -217,8 +217,5 @@ async def get_pinned_chats(user: SUser = Depends(check_verificated_user_with_exc
async def pinned_messages(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): async def pinned_messages(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id) await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id)
async with uow: async with uow:
messages = await uow.chat.get_pinned_messages(chat_id=chat_id) pinned_messages = await uow.chat.get_pinned_messages(chat_id=chat_id)
if messages: return pinned_messages
for mes in messages:
mes["created_at"] = mes["created_at"].isoformat()
return {"pinned_messages": messages}

View file

@ -1,13 +1,16 @@
import logging
from sqlalchemy import insert, select, update, delete, func from sqlalchemy import insert, select, update, delete, func
from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import aliased from sqlalchemy.orm import aliased
from app.dao.base import BaseDAO from app.dao.base import BaseDAO
from app.database import engine # noqa from app.database import engine # noqa
from app.chat.exceptions import UserAlreadyInChatException, UserAlreadyPinnedChatException, MessageNotFoundException from app.chat.exceptions import (
from app.chat.shemas import SMessage, SMessageList UserAlreadyInChatException,
UserAlreadyPinnedChatException,
MessageNotFoundException,
MessageAlreadyPinnedException,
)
from app.chat.shemas import SMessage, SMessageList, SPinnedMessages
from app.models.users import Users from app.models.users import Users
from app.models.message_answer import MessageAnswer from app.models.message_answer import MessageAnswer
from app.models.chat import Chats from app.models.chat import Chats
@ -26,12 +29,12 @@ class ChatDAO(BaseDAO):
result = result.mappings().one_or_none() result = result.mappings().one_or_none()
return result return result
async def create(self, user_id: int, chat_name: str, created_by: int) -> int: async def create_chat(self, user_id: int, chat_name: str, created_by: int) -> int:
stmt = insert(Chats).values(chat_for=user_id, chat_name=chat_name, created_by=created_by).returning(Chats.id) stmt = insert(Chats).values(chat_for=user_id, chat_name=chat_name, created_by=created_by).returning(Chats.id)
result = await self.session.execute(stmt) result = await self.session.execute(stmt)
result = result.scalar() chat_id = result.scalar()
return result return chat_id
async def add_user_to_chat(self, user_id: int, chat_id: int) -> None: async def add_user_to_chat(self, user_id: int, chat_id: int) -> None:
try: try:
@ -71,14 +74,14 @@ class ChatDAO(BaseDAO):
) )
.select_from(MessageAnswer) .select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True) .join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery() .scalar_subquery()
) )
) )
.select_from(Message) .select_from(Message)
.join(Users, Users.id == Message.user_id) .join(Users, Users.id == Message.user_id)
.join(MessageAnswer, Message.id == MessageAnswer.self_id, isouter=True) .join(MessageAnswer, Message.id == MessageAnswer.self_id, isouter=True)
.where(Message.id == message_id, Message.visibility == True) .where(Message.id == message_id, Message.visibility == True) # noqa: E712
) )
result = await self.session.execute(query) result = await self.session.execute(query)
@ -88,8 +91,8 @@ class ChatDAO(BaseDAO):
raise MessageNotFoundException raise MessageNotFoundException
async def delete_message(self, message_id: int) -> None: async def delete_message(self, message_id: int) -> None:
query = update(Message).where(Message.id == message_id).values(visibility=False) stmt = update(Message).where(Message.id == message_id).values(visibility=False)
await self.session.execute(query) await self.session.execute(stmt)
async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageList: async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageList:
@ -112,21 +115,21 @@ class ChatDAO(BaseDAO):
select(msg.message) select(msg.message)
.select_from(MessageAnswer) .select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True) .join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery() .scalar_subquery()
).label("answer_message"), ).label("answer_message"),
) )
.select_from(Message) .select_from(Message)
.join(Users, Message.user_id == Users.id) .join(Users, Message.user_id == Users.id)
.outerjoin(MessageAnswer, MessageAnswer.self_id == Message.id) .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
.where(Message.chat_id == chat_id, Message.visibility == True) .where(Message.chat_id == chat_id, Message.visibility == True) # noqa: E712
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(messages_to_get) .limit(messages_to_get)
.offset(message_number_from) .offset(message_number_from)
.cte("messages_with_users") .cte("messages_with_users")
) )
message_query = ( query = (
select( select(
func.json_build_object( func.json_build_object(
"messages", "messages",
@ -148,13 +151,13 @@ class ChatDAO(BaseDAO):
).select_from(messages_with_users) ).select_from(messages_with_users)
) )
result = await self.session.execute(message_query) result = await self.session.execute(query)
result = result.scalar() result = result.scalar()
return SMessageList.model_validate(result) return SMessageList.model_validate(result)
async def edit_message(self, message_id: int, new_message: str, new_image_url: str) -> None: async def edit_message(self, message_id: int, new_message: str, new_image_url: str) -> None:
query = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url) stmt = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url)
await self.session.execute(query) await self.session.execute(stmt)
async def add_answer(self, self_id: int, answer_id: int) -> None: async def add_answer(self, self_id: int, answer_id: int) -> None:
stmt = ( stmt = (
@ -164,12 +167,12 @@ class ChatDAO(BaseDAO):
await self.session.execute(stmt) await self.session.execute(stmt)
async def delete_chat(self, chat_id: int) -> None: async def delete_chat(self, chat_id: int) -> None:
query = update(Chats).where(Chats.id == chat_id).values(visibility=False) stmt = update(Chats).where(Chats.id == chat_id).values(visibility=False)
await self.session.execute(query) await self.session.execute(stmt)
async def delete_user_from_chat(self, chat_id: int, user_id: int) -> None: async def delete_user_from_chat(self, chat_id: int, user_id: int) -> None:
query = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id) stmt = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
await self.session.execute(query) await self.session.execute(stmt)
async def pin_chat(self, chat_id: int, user_id: int) -> None: async def pin_chat(self, chat_id: int, user_id: int) -> None:
try: try:
@ -179,8 +182,8 @@ class ChatDAO(BaseDAO):
raise UserAlreadyPinnedChatException raise UserAlreadyPinnedChatException
async def unpin_chat(self, chat_id: int, user_id: int) -> None: async def unpin_chat(self, chat_id: int, user_id: int) -> None:
query = delete(PinnedChat).where(PinnedChat.chat_id == chat_id, PinnedChat.user_id == user_id) stmt = delete(PinnedChat).where(PinnedChat.chat_id == chat_id, PinnedChat.user_id == user_id)
await self.session.execute(query) await self.session.execute(stmt)
async def get_pinned_chats(self, user_id: int): async def get_pinned_chats(self, user_id: int):
chats_with_descriptions = ( chats_with_descriptions = (
@ -213,7 +216,7 @@ class ChatDAO(BaseDAO):
) )
.distinct() .distinct()
.select_from(PinnedChat) .select_from(PinnedChat)
.join(chats_with_avatars, PinnedChat.chat_id == chats_with_avatars.c.chat_id) .join(chats_with_avatars, PinnedChat.chat_id == chats_with_avatars.c.chat_id) # noqa
.where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712 .where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712
) )
# print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса # print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса
@ -222,14 +225,74 @@ class ChatDAO(BaseDAO):
return result return result
async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> None: async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> None:
stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id) try:
await self.session.execute(stmt) stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
await self.session.execute(stmt)
except IntegrityError:
raise MessageAlreadyPinnedException
async def unpin_message(self, chat_id: int, message_id: int) -> None: async def unpin_message(self, chat_id: int, message_id: int) -> None:
stmt = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id) stmt = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id)
await self.session.execute(stmt) await self.session.execute(stmt)
async def get_pinned_messages(self, chat_id: int) -> list[dict]: async def get_pinned_messages(self, chat_id: int) -> SPinnedMessages:
msg = aliased(Message, name="msg")
messages_with_users = (
select(
Message.id,
Message.message,
Message.image_url,
Message.chat_id,
Message.user_id,
Message.created_at,
Message.visibility,
Users.username,
Users.avatar_image,
MessageAnswer.self_id,
MessageAnswer.answer_id,
(
select(msg.message)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery()
).label("answer_message"),
)
.select_from(PinnedMessage)
.join(Message, PinnedMessage.message_id == Message.id)
.join(Users, Message.user_id == Users.id)
.join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
.where(PinnedMessage.chat_id == chat_id, Message.visibility == True) # noqa: E712
.order_by(Message.created_at.desc())
.cte("messages_with_users")
)
query = (
select(
func.json_build_object(
"pinned_messages",
func.json_agg(
func.json_build_object(
"id", messages_with_users.c.id,
"message", messages_with_users.c.message,
"image_url", messages_with_users.c.image_url,
"chat_id", messages_with_users.c.chat_id,
"user_id", messages_with_users.c.user_id,
"created_at", messages_with_users.c.created_at,
"avatar_image", messages_with_users.c.avatar_image,
"username", messages_with_users.c.username,
"answer_id", messages_with_users.c.answer_id,
"answer_message", messages_with_users.c.answer_message,
)
)
)
).select_from(messages_with_users)
)
print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса
query = ( query = (
select( select(
Message.id, Message.id,
@ -251,7 +314,6 @@ class ChatDAO(BaseDAO):
) )
result = await self.session.execute(query) result = await self.session.execute(query)
result = result.mappings().all() result = result.scalar_one()
if result:
result = [dict(res) for res in result] return SPinnedMessages.model_validate(result)
return result

View file

@ -1,8 +1,8 @@
"""Database Creation """Database Creation
Revision ID: 1cc709a9c827 Revision ID: 4668313943c0
Revises: Revises:
Create Date: 2024-06-08 17:51:30.648467 Create Date: 2024-06-08 18:36:31.974804
""" """
from typing import Sequence, Union from typing import Sequence, Union
@ -12,7 +12,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '1cc709a9c827' revision: str = '4668313943c0'
down_revision: Union[str, None] = None down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@ -85,14 +85,13 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint('self_id') sa.PrimaryKeyConstraint('self_id')
) )
op.create_table('pinned_message', op.create_table('pinned_message',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('chat_id', sa.Integer(), nullable=True), sa.Column('chat_id', sa.Integer(), nullable=True),
sa.Column('message_id', sa.Integer(), nullable=True), sa.Column('message_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True), sa.Column('user_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ), sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['message_id'], ['message.id'], ), sa.ForeignKeyConstraint(['message_id'], ['message.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('message_id')
) )
# ### end Alembic commands ### # ### end Alembic commands ###

View file

@ -1,5 +1,5 @@
from sqlalchemy import ForeignKey from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import mapped_column
from app.database import Base from app.database import Base
@ -7,7 +7,6 @@ from app.database import Base
class PinnedMessage(Base): class PinnedMessage(Base):
__tablename__ = "pinned_message" __tablename__ = "pinned_message"
id: Mapped[int] = mapped_column(primary_key=True)
chat_id = mapped_column(ForeignKey("chat.id")) chat_id = mapped_column(ForeignKey("chat.id"))
message_id = mapped_column(ForeignKey("message.id")) message_id = mapped_column(ForeignKey("message.id"), primary_key=True)
user_id = mapped_column(ForeignKey("users.id")) user_id = mapped_column(ForeignKey("users.id"))