from pydantic import HttpUrl
from sqlalchemy import insert, select, update, delete, func
from sqlalchemy.dialects.postgresql import aggregate_order_by
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import aliased

from app.dao.base import BaseDAO
from app.chat.exceptions import (
    UserAlreadyInChatException,
    UserAlreadyPinnedChatException,
    MessageNotFoundException,
    MessageAlreadyPinnedException,
    ChatNotFoundException,
)
from app.chat.shemas import SMessage, SMessageList, SPinnedMessages, SPinnedChats, SChat
from app.models.users import Users
from app.models.message_answer import MessageAnswer
from app.models.chat import Chat
from app.models.message import Message
from app.models.pinned_chat import PinnedChat
from app.models.pinned_message import PinnedMessage
from app.models.user_chat import UserChat


def _answer_field(answer_msg, field):
    """Return a correlated scalar subquery selecting ``field`` of the answered message.

    ``answer_msg`` is an ``aliased(Message)`` and ``field`` one of its columns.
    The subquery looks up the ``MessageAnswer`` row whose ``self_id`` equals the
    enclosing query's ``Message.id`` and reads ``field`` from the message referenced
    by ``MessageAnswer.answer_id``. It evaluates to NULL when there is no such row
    or when that referenced message has ``visibility`` false.

    NOTE(review): if several ``MessageAnswer`` rows can share one ``self_id``,
    PostgreSQL will reject this subquery at runtime ("more than one row returned");
    presumably ``self_id`` is unique — confirm against the model.
    """
    return (
        select(field)
        .select_from(MessageAnswer)
        .join(answer_msg, answer_msg.id == MessageAnswer.answer_id, isouter=True)
        .where(MessageAnswer.self_id == Message.id, answer_msg.visibility == True)  # noqa: E712
        .scalar_subquery()
    )


def _message_columns(answer_msg):
    """Return the per-message columns selected by the message list queries.

    The enclosing query must join ``Users`` and (outer-join) ``MessageAnswer`` to
    ``Message``. ``answer_msg`` is the ``aliased(Message)`` used by the two
    answer subqueries, which are labelled ``answer_message`` / ``answer_image_url``.
    """
    return (
        Message.id,
        Message.message,
        Message.image_url,
        Message.chat_id,
        Message.user_id,
        Message.created_at,
        Users.username,
        Users.avatar_image,
        MessageAnswer.self_id,
        MessageAnswer.answer_id,
        _answer_field(answer_msg, answer_msg.message).label("answer_message"),
        _answer_field(answer_msg, answer_msg.image_url).label("answer_image_url"),
    )


def _message_json(rows):
    """Build a ``json_build_object`` for one row of ``rows`` (a CTE of ``_message_columns``).

    The keys and their order match the object built by ``ChatDAO.get_message_by_id``.
    """
    c = rows.c
    return func.json_build_object(
        "id", c.id,
        "message", c.message,
        "image_url", c.image_url,
        "chat_id", c.chat_id,
        "user_id", c.user_id,
        "created_at", c.created_at,
        "avatar_image", c.avatar_image,
        "username", c.username,
        "answer_id", c.answer_id,
        "answer_message", c.answer_message,
        "answer_image_url", c.answer_image_url,
    )


class ChatDAO(BaseDAO):
    """Data-access object for chats, chat membership, messages and pins.

    Every method executes its statement on ``self.session`` (provided by
    ``BaseDAO`` — presumably an ``AsyncSession``, given the ``await``). No method
    commits, so transaction control is left to the caller.
    """

    model = Chat

    async def find_one(self, chat_id: int, user_id: int) -> SChat:
        """Return chat ``chat_id`` if ``user_id`` has a ``UserChat`` row for it.

        Raises:
            ChatNotFoundException: no chat with that id has ``user_id`` as a member.
        """
        # NOTE(review): unlike get_pinned_chats this does not filter on
        # Chat.visibility, so chats hidden by delete_chat are still returned —
        # confirm that is intended.
        query = (
            select(
                Chat.id.label("chat_id"),
                Chat.chat_for,
                Chat.chat_name,
                Chat.created_by,
                Chat.avatar_image,
            )
            .select_from(Chat)
            .join(UserChat, Chat.id == UserChat.chat_id)
            .where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
        )
        result = await self.session.execute(query)
        try:
            row = result.mappings().one()
        except NoResultFound as exc:
            raise ChatNotFoundException from exc
        return SChat.model_validate(row)

    async def create_chat(self, user_id: int, chat_name: str, created_by: int, avatar_image: HttpUrl) -> int:
        """Insert a chat and return its generated id.

        ``user_id`` is stored in ``Chat.chat_for``; ``avatar_image`` is stored as a string.
        """
        stmt = (
            insert(Chat)
            .values(chat_for=user_id, chat_name=chat_name, created_by=created_by, avatar_image=str(avatar_image))
            .returning(Chat.id)
        )
        result = await self.session.execute(stmt)
        return result.scalar()

    async def add_user_to_chat(self, user_id: int, chat_id: int) -> None:
        """Insert a ``UserChat`` membership row.

        Raises:
            UserAlreadyInChatException: the insert violated an integrity constraint.
        """
        # NOTE(review): every IntegrityError (presumably including a foreign-key
        # violation for a nonexistent user/chat) is reported as "already in chat".
        stmt = insert(UserChat).values(user_id=user_id, chat_id=chat_id)
        try:
            await self.session.execute(stmt)
        except IntegrityError as exc:
            raise UserAlreadyInChatException from exc

    async def send_message(self, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> int:
        """Insert a message into ``chat_id`` authored by ``user_id`` and return its id."""
        stmt = (
            insert(Message)
            .values(chat_id=chat_id, user_id=user_id, message=message, image_url=image_url)
            .returning(Message.id)
        )
        result = await self.session.execute(stmt)
        return result.scalar()

    async def get_message_by_id(self, message_id: int) -> SMessage:
        """Return the visible message ``message_id`` with author and answer details.

        Raises:
            MessageNotFoundException: no message with that id and ``visibility`` true.
        """
        msg = aliased(Message, name="msg")
        query = (
            select(
                func.json_build_object(
                    "id", Message.id,
                    "message", Message.message,
                    "image_url", Message.image_url,
                    "chat_id", Message.chat_id,
                    "user_id", Message.user_id,
                    "created_at", Message.created_at,
                    "avatar_image", Users.avatar_image,
                    "username", Users.username,
                    "answer_id", MessageAnswer.answer_id,
                    "answer_message", _answer_field(msg, msg.message),
                    "answer_image_url", _answer_field(msg, msg.image_url),
                )
            )
            .select_from(Message)
            .join(Users, Users.id == Message.user_id)
            .join(MessageAnswer, Message.id == MessageAnswer.self_id, isouter=True)
            .where(Message.id == message_id, Message.visibility == True)  # noqa: E712
        )
        result = await self.session.execute(query)
        try:
            payload = result.scalar_one()
        except NoResultFound as exc:
            raise MessageNotFoundException from exc
        return SMessage.model_validate(payload)

    async def delete_message(self, message_id: int) -> None:
        """Soft-delete a message by setting its ``visibility`` to False."""
        stmt = update(Message).values(visibility=False).where(Message.id == message_id)
        await self.session.execute(stmt)

    async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageList:
        """Return one page of visible messages from ``chat_id``, newest first.

        Args:
            chat_id: chat to read.
            message_number_from: number of newest messages to skip (SQL OFFSET).
            messages_to_get: maximum number of messages to return (SQL LIMIT).
        """
        msg = aliased(Message, name="msg")
        page = (
            select(*_message_columns(msg))
            .select_from(Message)
            .join(Users, Message.user_id == Users.id)
            .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
            .where(Message.chat_id == chat_id, Message.visibility == True)  # noqa: E712
            .order_by(Message.created_at.desc())
            .limit(messages_to_get)
            .offset(message_number_from)
            .cte("messages_with_users")
        )
        # ORDER BY inside json_agg: PostgreSQL does not guarantee that aggregate
        # input follows the CTE's ORDER BY, so the order is stated explicitly.
        # NOTE(review): json_agg over zero rows is NULL, so "messages" is null
        # for an empty page — confirm SMessageList accepts that.
        query = select(
            func.json_build_object(
                "messages",
                func.json_agg(aggregate_order_by(_message_json(page), page.c.created_at.desc())),
            )
        ).select_from(page)
        result = await self.session.execute(query)
        return SMessageList.model_validate(result.scalar())

    async def edit_message(self, message_id: int, new_message: str, new_image_url: str | None) -> None:
        """Overwrite the text and image URL of message ``message_id``."""
        stmt = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url)
        await self.session.execute(stmt)

    async def add_answer(self, self_id: int, answer_id: int) -> None:
        """Insert a ``MessageAnswer`` row linking message ``self_id`` to message ``answer_id``."""
        stmt = insert(MessageAnswer).values(self_id=self_id, answer_id=answer_id)
        await self.session.execute(stmt)

    async def change_data(self, chat_id: int, chat_name: str, avatar_image: HttpUrl) -> None:
        """Update the name and avatar of chat ``chat_id``.

        ``avatar_image`` is stored in its string form, as in ``create_chat``.
        """
        stmt = (
            update(Chat)
            # pydantic v2 URL types are not ``str``; convert before binding,
            # matching create_chat.
            .values(chat_name=chat_name, avatar_image=str(avatar_image))
            .where(Chat.id == chat_id)
        )
        await self.session.execute(stmt)

    async def delete_chat(self, chat_id: int) -> None:
        """Soft-delete a chat by setting its ``visibility`` to False."""
        stmt = update(Chat).where(Chat.id == chat_id).values(visibility=False)
        await self.session.execute(stmt)

    async def delete_user_from_chat(self, chat_id: int, user_id: int) -> None:
        """Delete the ``UserChat`` membership row of ``user_id`` in ``chat_id``."""
        stmt = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
        await self.session.execute(stmt)

    async def pin_chat(self, chat_id: int, user_id: int) -> None:
        """Pin ``chat_id`` for ``user_id``.

        Raises:
            UserAlreadyPinnedChatException: the insert violated an integrity constraint.
        """
        stmt = insert(PinnedChat).values(chat_id=chat_id, user_id=user_id)
        try:
            await self.session.execute(stmt)
        except IntegrityError as exc:
            raise UserAlreadyPinnedChatException from exc

    async def unpin_chat(self, chat_id: int, user_id: int) -> None:
        """Remove the pin of ``chat_id`` for ``user_id`` (no-op if absent)."""
        stmt = delete(PinnedChat).where(PinnedChat.chat_id == chat_id, PinnedChat.user_id == user_id)
        await self.session.execute(stmt)

    async def get_pinned_chats(self, user_id: int) -> SPinnedChats:
        """Return the visible chats pinned by ``user_id``.

        NOTE(review): json_agg over zero rows is NULL, so "pinned_chats" is null
        when the user has no pins — confirm SPinnedChats accepts that.
        """
        query = (
            select(
                func.json_build_object(
                    "pinned_chats",
                    func.json_agg(
                        func.json_build_object(
                            "chat_id", Chat.id,
                            "chat_for", Chat.chat_for,
                            "chat_name", Chat.chat_name,
                            "created_by", Chat.created_by,
                            "avatar_image", Chat.avatar_image,
                        )
                    ),
                )
            )
            .select_from(PinnedChat)
            .join(Chat, PinnedChat.chat_id == Chat.id)
            .where(PinnedChat.user_id == user_id, Chat.visibility == True)  # noqa: E712
        )
        result = await self.session.execute(query)
        return SPinnedChats.model_validate(result.scalar_one())

    async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> None:
        """Pin ``message_id`` in ``chat_id``, recording ``user_id`` as the pinner.

        Raises:
            MessageAlreadyPinnedException: the insert violated an integrity constraint.
        """
        stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
        try:
            await self.session.execute(stmt)
        except IntegrityError as exc:
            raise MessageAlreadyPinnedException from exc

    async def unpin_message(self, chat_id: int, message_id: int) -> None:
        """Delete the pin rows of ``message_id`` in ``chat_id`` (no-op if absent)."""
        stmt = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id)
        await self.session.execute(stmt)

    async def get_pinned_messages(self, chat_id: int) -> SPinnedMessages:
        """Return the visible pinned messages of ``chat_id``, newest first.

        NOTE(review): json_agg over zero rows is NULL, so "pinned_messages" is
        null when nothing is pinned — confirm SPinnedMessages accepts that.
        """
        msg = aliased(Message, name="msg")
        pinned = (
            select(*_message_columns(msg))
            .select_from(PinnedMessage)
            .join(Message, PinnedMessage.message_id == Message.id)
            .join(Users, Message.user_id == Users.id)
            .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
            .where(PinnedMessage.chat_id == chat_id, Message.visibility == True)  # noqa: E712
            .order_by(Message.created_at.desc())
            .cte("messages_with_users")
        )
        # Explicit aggregate ORDER BY: the CTE's ordering is not guaranteed to
        # survive into json_agg.
        query = select(
            func.json_build_object(
                "pinned_messages",
                func.json_agg(aggregate_order_by(_message_json(pinned), pinned.c.created_at.desc())),
            )
        ).select_from(pinned)
        result = await self.session.execute(query)
        return SPinnedMessages.model_validate(result.scalar_one())