chat_back/app/dao/chat.py
2025-03-16 12:04:09 +03:00

155 lines
4.5 KiB
Python

from uuid import UUID
from pydantic import HttpUrl
from sqlalchemy import delete, func, insert, select, update
from sqlalchemy.exc import IntegrityError, NoResultFound
from app.chat.exceptions import (
ChatNotFoundException,
MessageAlreadyPinnedException,
UserAlreadyPinnedChatException,
)
from app.chat.shemas import SChat, SPinnedChats
from app.dao.base import BaseDAO
from app.models.chat import Chat
from app.models.pinned_chat import PinnedChat
from app.models.pinned_message import PinnedMessage
from app.models.user_chat import UserChat
class ChatDAO(BaseDAO):
    """Data-access object for chats, chat membership, pinned chats and pinned messages.

    All statements go through ``self.session``, which is inherited from
    ``BaseDAO``. It is awaited here, so it is presumably an SQLAlchemy
    ``AsyncSession`` (TODO confirm). No method in this class commits; the
    caller decides when to commit or roll back.
    """

    model = Chat

    @staticmethod
    def _chat_query():
        """Return a SELECT over ``Chat`` yielding the columns ``SChat`` is built from."""
        return select(
            Chat.id.label("chat_id"),
            Chat.chat_for,
            Chat.chat_name,
            Chat.created_by,
            Chat.avatar_image,
        ).select_from(Chat)

    @staticmethod
    def _url_to_db(url: HttpUrl | None) -> str | None:
        """Convert a pydantic URL to a plain string for binding; keep ``None`` as NULL.

        pydantic v2 ``HttpUrl`` is not a ``str`` subclass, so DB drivers cannot
        bind it directly.
        """
        return None if url is None else str(url)

    async def find_chat(self, chat_id: int, user_id: int) -> SChat:
        """Return chat ``chat_id`` if ``user_id`` has a ``UserChat`` row for it.

        Raises:
            ChatNotFoundException: the chat does not exist, or the user is not
                linked to it through ``UserChat``.
        """
        # NOTE(review): unlike get_pinned_chats, this does not filter on
        # Chat.visibility, so soft-deleted chats are still returned — confirm intent.
        query = (
            self._chat_query()
            .join(UserChat, Chat.id == UserChat.chat_id)
            .where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
        )
        result = await self.session.execute(query)
        try:
            row = result.mappings().one()
        except NoResultFound:
            raise ChatNotFoundException from None
        return SChat.model_validate(row)

    async def find_one(self, chat_id: int) -> SChat:
        """Return chat ``chat_id`` regardless of membership.

        Raises:
            ChatNotFoundException: no chat with that id exists.
        """
        query = self._chat_query().where(Chat.id == chat_id)
        result = await self.session.execute(query)
        try:
            row = result.mappings().one()
        except NoResultFound:
            raise ChatNotFoundException from None
        return SChat.model_validate(row)

    async def create_chat(self, user_id: int, chat_name: str, created_by: int, avatar_image: HttpUrl) -> int:
        """Insert a new chat row and return its generated id.

        ``user_id`` is stored in ``Chat.chat_for``.
        """
        stmt = (
            insert(Chat)
            .values(
                chat_for=user_id,
                chat_name=chat_name,
                created_by=created_by,
                avatar_image=self._url_to_db(avatar_image),
            )
            .returning(Chat.id)
        )
        result = await self.session.execute(stmt)
        # RETURNING on a single-row insert yields exactly one value.
        return result.scalar_one()

    async def add_user_to_chat(self, user_id: int, chat_id: int) -> None:
        """Link ``user_id`` to ``chat_id``, silently ignoring integrity violations.

        The insert runs inside a SAVEPOINT. Without it, a swallowed
        IntegrityError would leave the surrounding (PostgreSQL) transaction
        aborted, and the caller's later statements or commit would fail.
        """
        stmt = insert(UserChat).values(user_id=user_id, chat_id=chat_id)
        try:
            async with self.session.begin_nested():
                await self.session.execute(stmt)
        except IntegrityError:
            # Best-effort by design: e.g. the membership already exists.
            pass

    async def change_data(self, chat_id: int, chat_name: str, avatar_image: HttpUrl) -> None:
        """Update the name and avatar URL of chat ``chat_id``."""
        stmt = (
            update(Chat)
            .where(Chat.id == chat_id)
            # Convert the URL the same way create_chat does; a raw HttpUrl cannot be bound.
            .values(chat_name=chat_name, avatar_image=self._url_to_db(avatar_image))
        )
        await self.session.execute(stmt)

    async def delete_chat(self, chat_id: int) -> None:
        """Soft-delete chat ``chat_id`` by setting ``visibility`` to False."""
        stmt = update(Chat).where(Chat.id == chat_id).values(visibility=False)
        await self.session.execute(stmt)

    async def delete_user_from_chat(self, chat_id: int, user_id: int) -> None:
        """Remove the ``UserChat`` link between ``user_id`` and ``chat_id``."""
        stmt = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
        await self.session.execute(stmt)

    async def pin_chat(self, chat_id: int, user_id: int) -> None:
        """Pin ``chat_id`` for ``user_id``.

        Raises:
            UserAlreadyPinnedChatException: the insert violated an integrity
                constraint (e.g. the chat is already pinned by this user).
        """
        stmt = insert(PinnedChat).values(chat_id=chat_id, user_id=user_id)
        try:
            # SAVEPOINT keeps the outer transaction usable if the caller handles the error.
            async with self.session.begin_nested():
                await self.session.execute(stmt)
        except IntegrityError:
            raise UserAlreadyPinnedChatException from None

    async def unpin_chat(self, chat_id: int, user_id: int) -> None:
        """Remove the pin of ``chat_id`` for ``user_id`` (no-op if not pinned)."""
        stmt = delete(PinnedChat).where(PinnedChat.chat_id == chat_id, PinnedChat.user_id == user_id)
        await self.session.execute(stmt)

    async def get_pinned_chats(self, user_id: int) -> SPinnedChats:
        """Return the visible chats pinned by ``user_id`` as ``SPinnedChats``.

        The aggregation is done in PostgreSQL. The result is a single JSON
        object ``{"pinned_chats": [...]}``, and the list is empty when nothing
        is pinned.
        """
        chat_json = func.json_build_object(
            "chat_id", Chat.id,
            "chat_for", Chat.chat_for,
            "chat_name", Chat.chat_name,
            "created_by", Chat.created_by,
            "avatar_image", Chat.avatar_image,
        )
        query = (
            select(
                func.json_build_object(
                    "pinned_chats",
                    # json_agg over zero rows yields NULL; fall back to an empty JSON array.
                    func.coalesce(func.json_agg(chat_json), func.json_build_array()),
                )
            )
            .select_from(PinnedChat)
            .join(Chat, PinnedChat.chat_id == Chat.id)
            .where(PinnedChat.user_id == user_id, Chat.visibility == True)  # noqa: E712
        )
        result = await self.session.execute(query)
        # An aggregate without GROUP BY always returns exactly one row.
        return SPinnedChats.model_validate(result.scalar_one())

    async def pin_message(self, chat_id: int, message_id: UUID, user_id: int) -> None:
        """Pin ``message_id`` in ``chat_id``, recording ``user_id`` as the pinner.

        Raises:
            MessageAlreadyPinnedException: the insert violated an integrity
                constraint (e.g. the message is already pinned).
        """
        stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
        try:
            # SAVEPOINT keeps the outer transaction usable if the caller handles the error.
            async with self.session.begin_nested():
                await self.session.execute(stmt)
        except IntegrityError:
            raise MessageAlreadyPinnedException from None

    async def unpin_message(self, chat_id: int, message_id: UUID) -> None:
        """Remove the pin of ``message_id`` in ``chat_id`` (no-op if not pinned)."""
        stmt = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id)
        await self.session.execute(stmt)

    async def get_pinned_messages_ids(self, chat_id: int) -> list[UUID]:
        """Return the ids of all messages pinned in ``chat_id``."""
        query = select(PinnedMessage.message_id).where(PinnedMessage.chat_id == chat_id)
        result = await self.session.execute(query)
        # scalars().all() returns a Sequence; materialize to honor the list return type.
        return list(result.scalars().all())