from sqlalchemy import insert, select, update, delete

from app.dao.base import BaseDAO
from app.database import engine  # noqa
from app.exceptions import UserAlreadyInChatException, UserAlreadyPinnedChatException
from app.chat.shemas import SMessage
from app.models.users import Users
from app.models.message_answer import MessageAnswer
from app.models.chat import Chats
from app.models.message import Message
from app.models.pinned_chat import PinnedChat
from app.models.pinned_message import PinnedMessage
from app.models.user_chat import UserChat


def _message_columns() -> tuple:
    """Return the columns of a message row joined with its author and reply link.

    Used by ``get_message_by_id``, ``get_some_messages`` and ``get_pinned_messages``
    so that all three return rows of the same shape.
    """
    return (
        Message.id,
        Message.message,
        Message.image_url,
        Message.chat_id,
        Message.user_id,
        Message.created_at,
        Users.avatar_image,
        Users.username,
        Users.avatar_hex,
        MessageAnswer.self_id,
        MessageAnswer.answer_id,
    )


class ChatDAO(BaseDAO):
    """Data-access object for chats, chat membership, messages and pins.

    All queries run on ``self.session`` (presumably an async SQLAlchemy session
    supplied by ``BaseDAO`` — not visible here).
    """

    model = Chats

    async def find_one_or_none(self, **filter_by):
        """Return the single chat matching ``filter_by`` as a row mapping, or ``None``.

        Raises ``sqlalchemy.exc.MultipleResultsFound`` if more than one chat matches.
        """
        query = select(Chats.__table__.columns).filter_by(**filter_by)
        result = await self.session.execute(query)
        # mappings() keeps every selected column; scalar_one_or_none() would keep only the first.
        return result.mappings().one_or_none()

    async def create(self, user_id: int, chat_name: str, created_by: int) -> int:
        """Insert a chat (``chat_for=user_id``), commit, and return its new id."""
        stmt = (
            insert(Chats)
            .values(chat_for=user_id, chat_name=chat_name, created_by=created_by)
            .returning(Chats.id)
        )
        result = await self.session.execute(stmt)
        await self.session.commit()
        return result.scalar()

    async def add_user_to_chat(self, user_id: int, chat_id: int) -> bool:
        """Add ``user_id`` to ``chat_id`` and commit.

        Raises:
            UserAlreadyInChatException: if the user is already a member of the chat.
        """
        # Look up only this (chat, user) pair instead of loading every member id.
        query = (
            select(UserChat.user_id)
            .where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
            .limit(1)
        )
        result = await self.session.execute(query)
        if result.first() is not None:
            raise UserAlreadyInChatException
        stmt = insert(UserChat).values(user_id=user_id, chat_id=chat_id)
        await self.session.execute(stmt)
        await self.session.commit()
        return True

    async def send_message(
        self, user_id: int, chat_id: int, message: str, image_url: str | None = None
    ) -> SMessage:
        """Insert a message, commit, and return it enriched with author info as ``SMessage``.

        The INSERT ... RETURNING runs as a CTE so the new row can be joined with
        ``Users`` (and ``MessageAnswer``) in the same statement.
        """
        inserted_message = (
            insert(Message)
            .values(chat_id=chat_id, user_id=user_id, message=message, image_url=image_url)
            .returning(
                Message.id,
                Message.message,
                Message.image_url,
                Message.chat_id,
                Message.user_id,
                Message.created_at,
            )
            .cte("inserted_image")
        )
        query = (
            select(
                inserted_message.c.id,
                inserted_message.c.message,
                inserted_message.c.image_url,
                inserted_message.c.chat_id,
                inserted_message.c.user_id,
                inserted_message.c.created_at,
                Users.avatar_image,
                Users.username,
                Users.avatar_hex,
                MessageAnswer.self_id,
                MessageAnswer.answer_id,
            )
            .select_from(inserted_message)
            .join(Users, Users.id == inserted_message.c.user_id)
            .join(MessageAnswer, MessageAnswer.self_id == inserted_message.c.id, isouter=True)
        )
        result = await self.session.execute(query)
        await self.session.commit()
        row = result.mappings().one()
        return SMessage.model_validate(row, from_attributes=True)

    async def get_message_by_id(self, message_id: int):
        """Return the visible message ``message_id`` with author info as a row mapping, or ``None``."""
        query = (
            select(*_message_columns())
            .select_from(Message)
            .join(Users, Users.id == Message.user_id)
            .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
            .where(Message.id == message_id, Message.visibility == True)  # noqa: E712
        )
        result = await self.session.execute(query)
        return result.mappings().first()

    async def delete_message(self, message_id: int) -> bool:
        """Soft-delete a message by setting ``visibility=False``, then commit."""
        query = update(Message).where(Message.id == message_id).values(visibility=False)
        await self.session.execute(query)
        await self.session.commit()
        return True

    async def get_some_messages(
        self, chat_id: int, message_number_from: int, messages_to_get: int
    ) -> list[dict]:
        """Return one page of visible messages in ``chat_id``, newest first.

        Args:
            chat_id: chat to read from.
            message_number_from: number of newest messages to skip (OFFSET).
            messages_to_get: maximum number of messages to return (LIMIT).

        Returns:
            A list of dicts with the columns of ``_message_columns()``; empty if none.
        """
        query = (
            select(*_message_columns())
            .select_from(Message)
            .join(Users, Message.user_id == Users.id)
            .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
            .where(Message.chat_id == chat_id, Message.visibility == True)  # noqa: E712
            # id breaks ties between equal timestamps so LIMIT/OFFSET pages are stable.
            .order_by(Message.created_at.desc(), Message.id.desc())
            .limit(messages_to_get)
            .offset(message_number_from)
        )
        result = await self.session.execute(query)
        return [dict(row) for row in result.mappings().all()]

    async def edit_message(self, message_id: int, new_message: str, new_image_url: str) -> bool:
        """Overwrite a message's text and image URL, then commit."""
        query = (
            update(Message)
            .where(Message.id == message_id)
            .values(message=new_message, image_url=new_image_url)
        )
        await self.session.execute(query)
        await self.session.commit()
        return True

    async def add_answer(self, self_id: int, answer_id: int) -> SMessage:
        """Record that message ``self_id`` answers ``answer_id``, commit, and return ``self_id`` as ``SMessage``."""
        answer = (
            insert(MessageAnswer)
            .values(self_id=self_id, answer_id=answer_id)
            .returning(MessageAnswer.self_id, MessageAnswer.answer_id)
            .cte("answer")
        )
        query = (
            select(
                Message.id,
                Message.message,
                Message.image_url,
                Message.chat_id,
                Message.user_id,
                Message.created_at,
                Users.avatar_image,
                Users.username,
                Users.avatar_hex,
                answer.c.self_id,
                answer.c.answer_id,
            )
            .select_from(Message)
            .join(Users, Users.id == Message.user_id)
            .join(answer, answer.c.self_id == Message.id, isouter=True)
            .where(Message.id == self_id)
        )
        result = await self.session.execute(query)
        await self.session.commit()
        row = result.mappings().one()
        return SMessage.model_validate(row)

    async def delete_chat(self, chat_id: int) -> bool:
        """Soft-delete a chat by setting ``visibility=False``, then commit."""
        query = update(Chats).where(Chats.id == chat_id).values(visibility=False)
        await self.session.execute(query)
        await self.session.commit()
        return True

    async def delete_user(self, chat_id: int, user_id: int) -> bool:
        """Remove ``user_id`` from ``chat_id`` and commit."""
        query = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
        await self.session.execute(query)
        await self.session.commit()
        return True

    async def pin_chat(self, chat_id: int, user_id: int) -> bool:
        """Pin ``chat_id`` for ``user_id`` and commit.

        Raises:
            UserAlreadyPinnedChatException: if the user has already pinned this chat.
        """
        # Look up only this (chat, user) pin instead of loading every pinned chat id.
        query = (
            select(PinnedChat.chat_id)
            .where(PinnedChat.user_id == user_id, PinnedChat.chat_id == chat_id)
            .limit(1)
        )
        result = await self.session.execute(query)
        if result.first() is not None:
            raise UserAlreadyPinnedChatException
        stmt = insert(PinnedChat).values(chat_id=chat_id, user_id=user_id)
        await self.session.execute(stmt)
        await self.session.commit()
        return True

    async def unpin_chat(self, chat_id: int, user_id: int) -> bool:
        """Remove ``user_id``'s pin on ``chat_id`` and commit."""
        query = delete(PinnedChat).where(PinnedChat.chat_id == chat_id, PinnedChat.user_id == user_id)
        await self.session.execute(query)
        await self.session.commit()
        return True

    async def get_pinned_chats(self, user_id: int):
        """Return the chats ``user_id`` has pinned, is a member of, and whose visibility is true.

        Each row carries ``chat_id``, ``chat_for``, ``chat_name``, ``avatar_image``
        and ``avatar_hex``.
        """
        chats_with_descriptions = (
            select(UserChat.__table__.columns, Chats.__table__.columns)
            .select_from(UserChat)
            .join(Chats, UserChat.chat_id == Chats.id)
            .cte("chats_with_descriptions")
        )
        chats_with_avatars = (
            select(
                chats_with_descriptions.c.chat_id,
                chats_with_descriptions.c.chat_for,
                chats_with_descriptions.c.chat_name,
                chats_with_descriptions.c.visibility,
                Users.id,
                Users.avatar_image,
                Users.avatar_hex,
            )
            .select_from(chats_with_descriptions)
            .join(Users, Users.id == chats_with_descriptions.c.user_id)
            .cte("chats_with_avatars")
        )
        query = (
            select(
                chats_with_avatars.c.chat_id,
                chats_with_avatars.c.chat_for,
                chats_with_avatars.c.chat_name,
                chats_with_avatars.c.avatar_image,
                chats_with_avatars.c.avatar_hex,
            )
            .distinct()
            .select_from(PinnedChat)
            .join(chats_with_avatars, PinnedChat.chat_id == chats_with_avatars.c.chat_id)
            .where(
                # Pins are per user (pin_chat / unpin_chat key on user_id); without this,
                # a chat pinned by any other member would be reported as pinned here too.
                PinnedChat.user_id == user_id,
                # NOTE(review): restricts the member row to user_id, so the avatar columns
                # are user_id's own avatar — confirm that is the intended chat avatar.
                chats_with_avatars.c.id == user_id,
                chats_with_avatars.c.visibility == True,  # noqa: E712
            )
        )
        # Debug helper: print the compiled SQL.
        # print(query.compile(engine, compile_kwargs={"literal_binds": True}))
        result = await self.session.execute(query)
        return result.mappings().all()

    async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> bool:
        """Pin ``message_id`` in ``chat_id`` on behalf of ``user_id`` and commit."""
        query = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
        await self.session.execute(query)
        # Commit like every other mutating method here, so the pin is not lost on session close.
        await self.session.commit()
        return True

    async def unpin_message(self, chat_id: int, message_id: int) -> bool:
        """Remove the pin of ``message_id`` in ``chat_id`` and commit."""
        query = delete(PinnedMessage).where(
            PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id
        )
        await self.session.execute(query)
        # Commit like every other mutating method here, so the unpin is not lost on session close.
        await self.session.commit()
        return True

    async def get_pinned_messages(self, chat_id: int) -> list[dict]:
        """Return the visible pinned messages of ``chat_id``, newest first.

        Returns:
            A list of dicts with the columns of ``_message_columns()``; empty if none.
        """
        query = (
            select(*_message_columns())
            .select_from(PinnedMessage)
            .join(Message, PinnedMessage.message_id == Message.id, isouter=True)
            # Author info must describe Message.user_id (as in the other message reads),
            # not whoever pinned the message.
            .join(Users, Message.user_id == Users.id, isouter=True)
            .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
            .where(PinnedMessage.chat_id == chat_id, Message.visibility == True)  # noqa: E712
            .order_by(Message.created_at.desc())
        )
        result = await self.session.execute(query)
        return [dict(row) for row in result.mappings().all()]