chat_back/app/dao/chat.py
2024-06-12 20:02:55 +05:00

303 lines
9.8 KiB
Python

from sqlalchemy import insert, select, update, delete, func
from sqlalchemy.dialects.postgresql import aggregate_order_by
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import aliased

from app.dao.base import BaseDAO
from app.database import engine
from app.chat.exceptions import (
    UserAlreadyInChatException,
    UserAlreadyPinnedChatException,
    MessageNotFoundException,
    MessageAlreadyPinnedException,
    ChatNotFoundException,
)
from app.chat.shemas import SMessage, SMessageList, SPinnedMessages, SPinnedChats, SChat
from app.models.users import Users
from app.models.message_answer import MessageAnswer
from app.models.chat import Chat
from app.models.message import Message
from app.models.pinned_chat import PinnedChat
from app.models.pinned_message import PinnedMessage
from app.models.user_chat import UserChat
class ChatDAO(BaseDAO):
    """Data-access object for chats, chat membership, messages and pins.

    Every method runs its statement on ``self.session`` (not defined in this
    class -- presumably provided by ``BaseDAO``) and never commits;
    transaction control is left to the caller.
    """

    model = Chat

    @staticmethod
    def check_query_compile(query):
        """Print *query* compiled for ``engine`` with bound values inlined.

        Debugging aid for checking the SQL a statement produces.
        """
        print(query.compile(engine, compile_kwargs={"literal_binds": True}))

    # ------------------------------------------------------------------
    # Private query-building helpers shared by the message queries.
    # ------------------------------------------------------------------

    @staticmethod
    def _answer_message_subquery():
        """Return a scalar subquery with the text of the message being replied to.

        The subquery references ``Message.id`` without selecting from
        ``Message`` itself, so it correlates to the ``Message`` row of the
        enclosing query. It yields NULL when that message has no
        ``MessageAnswer`` row or the replied-to message is hidden
        (``visibility`` is false).

        NOTE(review): assumes at most one ``MessageAnswer`` row per
        ``self_id``; with more, PostgreSQL rejects the scalar subquery.
        """
        msg = aliased(Message, name="msg")
        return (
            select(msg.message)
            .select_from(MessageAnswer)
            .join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
            .where(MessageAnswer.self_id == Message.id, msg.visibility == True)  # noqa: E712
            .scalar_subquery()
        )

    @classmethod
    def _message_columns(cls):
        """Return the columns selected per message by the message-list CTEs."""
        return (
            Message.id,
            Message.message,
            Message.image_url,
            Message.chat_id,
            Message.user_id,
            Message.created_at,
            Message.visibility,
            Users.username,
            Users.avatar_image,
            MessageAnswer.self_id,
            MessageAnswer.answer_id,
            cls._answer_message_subquery().label("answer_message"),
        )

    @staticmethod
    def _message_json(source):
        """Build the JSON object for one row of a message-list CTE *source*."""
        c = source.c
        return func.json_build_object(
            "id", c.id,
            "message", c.message,
            "image_url", c.image_url,
            "chat_id", c.chat_id,
            "user_id", c.user_id,
            "created_at", c.created_at,
            "avatar_image", c.avatar_image,
            "username", c.username,
            "answer_id", c.answer_id,
            "answer_message", c.answer_message,
        )

    @classmethod
    def _message_json_array(cls, source):
        """Aggregate the rows of *source* into a JSON array, newest first.

        The ORDER BY is given inside the aggregate: PostgreSQL does not
        guarantee that rows keep a subquery/CTE's order when fed into an
        aggregate. ``id`` breaks ties between equal ``created_at`` values.
        """
        return func.json_agg(
            aggregate_order_by(
                cls._message_json(source),
                source.c.created_at.desc(),
                source.c.id.desc(),
            )
        )

    # ------------------------------------------------------------------
    # Chats and membership.
    # ------------------------------------------------------------------

    async def find_one(self, chat_id: int, user_id: int) -> SChat:
        """Return chat *chat_id* as seen through the membership of *user_id*.

        ``avatar_image`` is taken from the ``Users`` row of *user_id* itself
        (joined via ``UserChat.user_id``).

        NOTE(review): unlike ``get_pinned_chats``, this does not filter on
        ``Chat.visibility``, so chats hidden by ``delete_chat`` are still
        returned -- confirm this is intended.

        Raises:
            ChatNotFoundException: no row matches (e.g. *user_id* is not a
                member of *chat_id*).
        """
        query = (
            select(
                Chat.id.label("chat_id"),
                Chat.chat_for,
                Chat.chat_name,
                Users.avatar_image,
            )
            .select_from(Chat)
            .join(UserChat, Chat.id == UserChat.chat_id)
            .join(Users, UserChat.user_id == Users.id)
            .where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
        )
        result = await self.session.execute(query)
        try:
            row = result.mappings().one()
        except NoResultFound as err:
            raise ChatNotFoundException from err
        return SChat.model_validate(row)

    async def create_chat(self, user_id: int, chat_name: str, created_by: int) -> int:
        """Insert a chat with ``chat_for=user_id`` and return its new id."""
        stmt = (
            insert(Chat)
            .values(chat_for=user_id, chat_name=chat_name, created_by=created_by)
            .returning(Chat.id)
        )
        result = await self.session.execute(stmt)
        return result.scalar()

    async def add_user_to_chat(self, user_id: int, chat_id: int) -> None:
        """Add a ``UserChat`` membership row for *user_id* in *chat_id*.

        Raises:
            UserAlreadyInChatException: the insert violated a constraint.
                NOTE(review): every ``IntegrityError`` is mapped to this, not
                only a duplicate membership. On PostgreSQL the failed INSERT
                also aborts the surrounding transaction, so the caller must
                roll back before reusing the session.
        """
        stmt = insert(UserChat).values(user_id=user_id, chat_id=chat_id)
        try:
            await self.session.execute(stmt)
        except IntegrityError as err:
            raise UserAlreadyInChatException from err

    async def delete_chat(self, chat_id: int) -> None:
        """Soft-delete *chat_id* by setting ``visibility`` to false."""
        stmt = update(Chat).where(Chat.id == chat_id).values(visibility=False)
        await self.session.execute(stmt)

    async def delete_user_from_chat(self, chat_id: int, user_id: int) -> None:
        """Delete the ``UserChat`` membership row of *user_id* in *chat_id*."""
        stmt = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
        await self.session.execute(stmt)

    # ------------------------------------------------------------------
    # Messages.
    # ------------------------------------------------------------------

    async def send_message(self, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> int:
        """Insert a message into *chat_id* from *user_id* and return its id."""
        stmt = (
            insert(Message)
            .values(chat_id=chat_id, user_id=user_id, message=message, image_url=image_url)
            .returning(Message.id)
        )
        result = await self.session.execute(stmt)
        return result.scalar()

    async def get_message_by_id(self, message_id: int) -> SMessage:
        """Return visible message *message_id* with author and reply details.

        Raises:
            MessageNotFoundException: the message does not exist or is hidden
                (``visibility`` is false).
        """
        query = (
            select(
                func.json_build_object(
                    "id", Message.id,
                    "message", Message.message,
                    "image_url", Message.image_url,
                    "chat_id", Message.chat_id,
                    "user_id", Message.user_id,
                    "created_at", Message.created_at,
                    "avatar_image", Users.avatar_image,
                    "username", Users.username,
                    "answer_id", MessageAnswer.answer_id,
                    "answer_message", self._answer_message_subquery(),
                )
            )
            .select_from(Message)
            .join(Users, Users.id == Message.user_id)
            .join(MessageAnswer, Message.id == MessageAnswer.self_id, isouter=True)
            .where(Message.id == message_id, Message.visibility == True)  # noqa: E712
        )
        result = await self.session.execute(query)
        try:
            payload = result.scalar_one()
        except NoResultFound as err:
            raise MessageNotFoundException from err
        return SMessage.model_validate(payload)

    async def delete_message(self, message_id: int) -> None:
        """Soft-delete *message_id* by setting ``visibility`` to false."""
        stmt = update(Message).values(visibility=False).where(Message.id == message_id)
        await self.session.execute(stmt)

    async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageList:
        """Return one page of visible messages of *chat_id*, newest first.

        Args:
            chat_id: Chat whose messages are listed.
            message_number_from: Number of newest messages to skip (``OFFSET``).
            messages_to_get: Maximum number of messages returned (``LIMIT``).

        Returns:
            ``SMessageList`` validated from ``{"messages": [...]}``.
            ``json_agg`` over zero rows yields NULL, so ``messages`` is null
            (not an empty list) when the page is empty.
        """
        messages_with_users = (
            select(*self._message_columns())
            .select_from(Message)
            .join(Users, Message.user_id == Users.id)
            .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
            .where(Message.chat_id == chat_id, Message.visibility == True)  # noqa: E712
            # id breaks created_at ties, so LIMIT/OFFSET pages are stable and
            # never overlap or skip rows.
            .order_by(Message.created_at.desc(), Message.id.desc())
            .limit(messages_to_get)
            .offset(message_number_from)
            .cte("messages_with_users")
        )
        query = select(
            func.json_build_object(
                "messages", self._message_json_array(messages_with_users),
            )
        ).select_from(messages_with_users)
        result = await self.session.execute(query)
        return SMessageList.model_validate(result.scalar_one())

    async def edit_message(self, message_id: int, new_message: str, new_image_url: str | None) -> None:
        """Overwrite the text and image URL of *message_id*.

        Only the id is matched; hidden messages are updated too.
        """
        stmt = (
            update(Message)
            .where(Message.id == message_id)
            .values(message=new_message, image_url=new_image_url)
        )
        await self.session.execute(stmt)

    async def add_answer(self, self_id: int, answer_id: int) -> None:
        """Record that message *self_id* is a reply to message *answer_id*."""
        stmt = insert(MessageAnswer).values(self_id=self_id, answer_id=answer_id)
        await self.session.execute(stmt)

    # ------------------------------------------------------------------
    # Pinned chats.
    # ------------------------------------------------------------------

    async def pin_chat(self, chat_id: int, user_id: int) -> None:
        """Pin *chat_id* for *user_id*.

        Raises:
            UserAlreadyPinnedChatException: the insert violated a constraint
                (NOTE(review): any ``IntegrityError`` is mapped to this).
        """
        stmt = insert(PinnedChat).values(chat_id=chat_id, user_id=user_id)
        try:
            await self.session.execute(stmt)
        except IntegrityError as err:
            raise UserAlreadyPinnedChatException from err

    async def unpin_chat(self, chat_id: int, user_id: int) -> None:
        """Remove the pin of *chat_id* made by *user_id*."""
        stmt = delete(PinnedChat).where(PinnedChat.chat_id == chat_id, PinnedChat.user_id == user_id)
        await self.session.execute(stmt)

    async def get_pinned_chats(self, user_id: int) -> SPinnedChats:
        """Return the visible chats pinned by *user_id*.

        ``avatar_image`` comes from the pinning user's own ``Users`` row
        (joined on ``PinnedChat.user_id``). With no visible pinned chats,
        ``json_agg`` yields NULL, so ``pinned_chats`` is null rather than an
        empty list.
        """
        query = (
            select(
                func.json_build_object(
                    "pinned_chats", func.json_agg(
                        func.json_build_object(
                            "chat_id", Chat.id,
                            "chat_for", Chat.chat_for,
                            "chat_name", Chat.chat_name,
                            "avatar_image", Users.avatar_image,
                        )
                    )
                )
            )
            .select_from(PinnedChat)
            .join(Users, PinnedChat.user_id == Users.id)
            .join(Chat, PinnedChat.chat_id == Chat.id)
            .where(PinnedChat.user_id == user_id, Chat.visibility == True)  # noqa: E712
        )
        result = await self.session.execute(query)
        return SPinnedChats.model_validate(result.scalar_one())

    # ------------------------------------------------------------------
    # Pinned messages.
    # ------------------------------------------------------------------

    async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> None:
        """Pin *message_id* in *chat_id* on behalf of *user_id*.

        Raises:
            MessageAlreadyPinnedException: the insert violated a constraint
                (NOTE(review): any ``IntegrityError`` is mapped to this).
        """
        stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
        try:
            await self.session.execute(stmt)
        except IntegrityError as err:
            raise MessageAlreadyPinnedException from err

    async def unpin_message(self, chat_id: int, message_id: int) -> None:
        """Remove every pin of *message_id* in *chat_id*, whoever pinned it."""
        stmt = delete(PinnedMessage).where(
            PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id
        )
        await self.session.execute(stmt)

    async def get_pinned_messages(self, chat_id: int) -> SPinnedMessages:
        """Return the visible pinned messages of *chat_id*, newest first.

        Returns:
            ``SPinnedMessages`` validated from ``{"pinned_messages": [...]}``.
            With no visible pinned messages, ``json_agg`` yields NULL, so
            ``pinned_messages`` is null rather than an empty list.
        """
        messages_with_users = (
            select(*self._message_columns())
            .select_from(PinnedMessage)
            .join(Message, PinnedMessage.message_id == Message.id)
            .join(Users, Message.user_id == Users.id)
            .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
            .where(PinnedMessage.chat_id == chat_id, Message.visibility == True)  # noqa: E712
            .order_by(Message.created_at.desc(), Message.id.desc())
            .cte("messages_with_users")
        )
        query = select(
            func.json_build_object(
                "pinned_messages", self._message_json_array(messages_with_users),
            )
        ).select_from(messages_with_users)
        result = await self.session.execute(query)
        return SPinnedMessages.model_validate(result.scalar_one())