333 lines
11 KiB
Python
333 lines
11 KiB
Python
from pydantic import HttpUrl
|
|
from sqlalchemy import insert, select, update, delete, func
|
|
from sqlalchemy.exc import IntegrityError, NoResultFound
|
|
from sqlalchemy.orm import aliased
|
|
|
|
from app.dao.base import BaseDAO
|
|
from app.chat.exceptions import (
|
|
UserAlreadyInChatException,
|
|
UserAlreadyPinnedChatException,
|
|
MessageNotFoundException,
|
|
MessageAlreadyPinnedException,
|
|
ChatNotFoundException,
|
|
)
|
|
from app.chat.shemas import SMessage, SMessageList, SPinnedMessages, SPinnedChats, SChat
|
|
from app.models.users import Users
|
|
from app.models.message_answer import MessageAnswer
|
|
from app.models.chat import Chat
|
|
from app.models.message import Message
|
|
from app.models.pinned_chat import PinnedChat
|
|
from app.models.pinned_message import PinnedMessage
|
|
from app.models.user_chat import UserChat
|
|
|
|
|
|
class ChatDAO(BaseDAO):
    """Data-access layer for chats, chat membership, messages and pins.

    Every method runs its statements on ``self.session`` (provided by
    ``BaseDAO``) and never commits; transaction control is left to the caller.
    """

    model = Chat

    # ------------------------------------------------------------------ #
    # Private query-building helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _url_to_str(url: HttpUrl | None) -> str | None:
        """Convert a pydantic URL to the plain string stored in the DB.

        ``None`` is passed through unchanged instead of becoming ``"None"``.
        """
        # Pydantic v2 URL objects are not ``str`` subclasses, so they have to be
        # stringified before being bound as a statement parameter.
        return None if url is None else str(url)

    @staticmethod
    def _answer_subqueries() -> tuple:
        """Build correlated scalar subqueries for the message a row replies to.

        Both subqueries correlate with the enclosing query's ``Message`` row via
        ``MessageAnswer.self_id`` and evaluate to NULL when there is no answer
        link or the answered message's ``visibility`` is not true.

        Returns:
            ``(answer_message, answer_image_url)`` as unlabeled scalar subqueries.
        """
        msg = aliased(Message, name="msg")

        def answered_column(column):
            return (
                select(column)
                .select_from(MessageAnswer)
                .join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
                .where(MessageAnswer.self_id == Message.id, msg.visibility == True)  # noqa: E712
                .scalar_subquery()
            )

        return answered_column(msg.message), answered_column(msg.image_url)

    @classmethod
    def _message_row_columns(cls) -> tuple:
        """Return the columns selected for one message row.

        The calling query must have ``Message``, ``Users`` and ``MessageAnswer``
        in its FROM clause. The resulting column names are exactly the names
        read by ``_message_json``.
        """
        answer_message, answer_image_url = cls._answer_subqueries()
        return (
            Message.id,
            Message.message,
            Message.image_url,
            Message.chat_id,
            Message.user_id,
            Message.created_at,
            Users.username,
            Users.avatar_image,
            MessageAnswer.answer_id,
            answer_message.label("answer_message"),
            answer_image_url.label("answer_image_url"),
        )

    @staticmethod
    def _message_json(columns):
        """Build a ``json_build_object`` for one message.

        ``columns`` is the ``.c`` collection of a subquery/CTE selected with
        ``_message_row_columns``.
        """
        return func.json_build_object(
            "id", columns.id,
            "message", columns.message,
            "image_url", columns.image_url,
            "chat_id", columns.chat_id,
            "user_id", columns.user_id,
            "created_at", columns.created_at,
            "avatar_image", columns.avatar_image,
            "username", columns.username,
            "answer_id", columns.answer_id,
            "answer_message", columns.answer_message,
            "answer_image_url", columns.answer_image_url,
        )

    @classmethod
    def _messages_json_agg(cls, messages):
        """Aggregate the rows of the ``messages`` CTE into a JSON array, newest first.

        The order is stated inside the aggregate because PostgreSQL does not
        guarantee that ``json_agg`` consumes rows in the order of a sorted
        subquery/CTE. ``id`` breaks ties between equal ``created_at`` values.
        Note: ``json_agg`` over zero rows yields NULL, not an empty array.
        """
        from sqlalchemy.dialects.postgresql import aggregate_order_by

        return func.json_agg(
            aggregate_order_by(
                cls._message_json(messages.c),
                messages.c.created_at.desc(),
                messages.c.id.desc(),
            )
        )

    # ------------------------------------------------------------------ #
    # Chats
    # ------------------------------------------------------------------ #

    async def find_one(self, chat_id: int, user_id: int) -> SChat:
        """Return chat ``chat_id`` as seen by its member ``user_id``.

        Raises:
            ChatNotFoundException: no ``UserChat`` row links ``user_id`` to
                ``chat_id`` (this includes a non-existent chat).
        """
        # NOTE(review): unlike get_pinned_chats, soft-deleted chats
        # (visibility=False) are not filtered out here — confirm intended.
        query = (
            select(
                Chat.id.label("chat_id"),
                Chat.chat_for,
                Chat.chat_name,
                Chat.created_by,
                Chat.avatar_image,
            )
            .select_from(Chat)
            .join(UserChat, Chat.id == UserChat.chat_id)
            .where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
        )
        result = await self.session.execute(query)
        try:
            row = result.mappings().one()
        except NoResultFound as err:
            raise ChatNotFoundException from err
        return SChat.model_validate(row)

    async def create_chat(self, user_id: int, chat_name: str, created_by: int, avatar_image: HttpUrl) -> int:
        """Insert a chat and return its generated id.

        ``user_id`` is stored in ``Chat.chat_for``; ``avatar_image`` is stored
        as a plain string.
        """
        stmt = (
            insert(Chat)
            .values(
                chat_for=user_id,
                chat_name=chat_name,
                created_by=created_by,
                avatar_image=self._url_to_str(avatar_image),
            )
            .returning(Chat.id)
        )
        result = await self.session.execute(stmt)
        return result.scalar()

    async def add_user_to_chat(self, user_id: int, chat_id: int) -> None:
        """Add ``user_id`` as a member of ``chat_id``.

        Raises:
            UserAlreadyInChatException: the insert raised ``IntegrityError``.
        """
        stmt = insert(UserChat).values(user_id=user_id, chat_id=chat_id)
        try:
            await self.session.execute(stmt)
        except IntegrityError as err:
            # NOTE(review): any integrity violation (not only a duplicate
            # membership) is reported as "already in chat".
            raise UserAlreadyInChatException from err

    async def change_data(self, chat_id: int, chat_name: str, avatar_image: HttpUrl) -> None:
        """Update the name and avatar of chat ``chat_id``."""
        stmt = (
            update(Chat)
            # Stringify like create_chat does; a raw pydantic URL object is not a str.
            .values(chat_name=chat_name, avatar_image=self._url_to_str(avatar_image))
            .where(Chat.id == chat_id)
        )
        await self.session.execute(stmt)

    async def delete_chat(self, chat_id: int) -> None:
        """Soft-delete chat ``chat_id`` by setting its ``visibility`` to False."""
        stmt = update(Chat).where(Chat.id == chat_id).values(visibility=False)
        await self.session.execute(stmt)

    async def delete_user_from_chat(self, chat_id: int, user_id: int) -> None:
        """Remove the membership row linking ``user_id`` to ``chat_id``."""
        stmt = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
        await self.session.execute(stmt)

    async def pin_chat(self, chat_id: int, user_id: int) -> None:
        """Pin ``chat_id`` for ``user_id``.

        Raises:
            UserAlreadyPinnedChatException: the insert raised ``IntegrityError``.
        """
        stmt = insert(PinnedChat).values(chat_id=chat_id, user_id=user_id)
        try:
            await self.session.execute(stmt)
        except IntegrityError as err:
            raise UserAlreadyPinnedChatException from err

    async def unpin_chat(self, chat_id: int, user_id: int) -> None:
        """Remove the pin of ``chat_id`` for ``user_id`` (no-op if absent)."""
        stmt = delete(PinnedChat).where(PinnedChat.chat_id == chat_id, PinnedChat.user_id == user_id)
        await self.session.execute(stmt)

    async def get_pinned_chats(self, user_id: int) -> SPinnedChats:
        """Return the visible chats pinned by ``user_id``.

        Note: ``json_agg`` over zero rows yields NULL, so a user with no pins
        produces ``{"pinned_chats": null}``.
        """
        chat_json = func.json_build_object(
            "chat_id", Chat.id,
            "chat_for", Chat.chat_for,
            "chat_name", Chat.chat_name,
            "created_by", Chat.created_by,
            "avatar_image", Chat.avatar_image,
        )
        query = (
            select(func.json_build_object("pinned_chats", func.json_agg(chat_json)))
            .select_from(PinnedChat)
            .join(Chat, PinnedChat.chat_id == Chat.id)
            .where(PinnedChat.user_id == user_id, Chat.visibility == True)  # noqa: E712
        )
        result = await self.session.execute(query)
        return SPinnedChats.model_validate(result.scalar_one())

    # ------------------------------------------------------------------ #
    # Messages
    # ------------------------------------------------------------------ #

    async def send_message(self, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> int:
        """Insert a message into ``chat_id`` authored by ``user_id``; return its id."""
        stmt = (
            insert(Message)
            .values(chat_id=chat_id, user_id=user_id, message=message, image_url=image_url)
            .returning(Message.id)
        )
        result = await self.session.execute(stmt)
        return result.scalar()

    async def get_message_by_id(self, message_id: int) -> SMessage:
        """Return visible message ``message_id`` with author and answer details.

        Raises:
            MessageNotFoundException: no message with that id has
                ``visibility`` true.
        """
        message_row = (
            select(*self._message_row_columns())
            .select_from(Message)
            .join(Users, Users.id == Message.user_id)
            .join(MessageAnswer, Message.id == MessageAnswer.self_id, isouter=True)
            .where(Message.id == message_id, Message.visibility == True)  # noqa: E712
            .subquery("message_row")
        )
        query = select(self._message_json(message_row.c))
        result = await self.session.execute(query)
        try:
            payload = result.scalar_one()
        except NoResultFound as err:
            raise MessageNotFoundException from err
        return SMessage.model_validate(payload)

    async def delete_message(self, message_id: int) -> None:
        """Soft-delete message ``message_id`` by setting its ``visibility`` to False."""
        stmt = update(Message).values(visibility=False).where(Message.id == message_id)
        await self.session.execute(stmt)

    async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageList:
        """Return one page of visible messages of ``chat_id``, newest first.

        Args:
            message_number_from: number of newest messages to skip (OFFSET).
            messages_to_get: maximum number of messages to return (LIMIT).

        Note: an empty page yields ``{"messages": null}`` because ``json_agg``
        over zero rows is NULL.
        """
        page = (
            select(*self._message_row_columns())
            .select_from(Message)
            .join(Users, Message.user_id == Users.id)
            .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
            .where(Message.chat_id == chat_id, Message.visibility == True)  # noqa: E712
            # ``id`` breaks created_at ties so consecutive pages neither overlap nor skip rows.
            .order_by(Message.created_at.desc(), Message.id.desc())
            .limit(messages_to_get)
            .offset(message_number_from)
            .cte("messages_with_users")
        )
        query = select(
            func.json_build_object("messages", self._messages_json_agg(page))
        ).select_from(page)
        result = await self.session.execute(query)
        return SMessageList.model_validate(result.scalar())

    async def edit_message(self, message_id: int, new_message: str, new_image_url: str | None) -> None:
        """Overwrite the text and image URL of message ``message_id``."""
        stmt = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url)
        await self.session.execute(stmt)

    async def add_answer(self, self_id: int, answer_id: int) -> None:
        """Record that message ``self_id`` replies to message ``answer_id``."""
        stmt = insert(MessageAnswer).values(self_id=self_id, answer_id=answer_id)
        await self.session.execute(stmt)

    async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> None:
        """Pin ``message_id`` in ``chat_id``, recording ``user_id`` as the pinner.

        Raises:
            MessageAlreadyPinnedException: the insert raised ``IntegrityError``.
        """
        stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
        try:
            await self.session.execute(stmt)
        except IntegrityError as err:
            raise MessageAlreadyPinnedException from err

    async def unpin_message(self, chat_id: int, message_id: int) -> None:
        """Remove the pin of ``message_id`` in ``chat_id`` (no-op if absent)."""
        stmt = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id)
        await self.session.execute(stmt)

    async def get_pinned_messages(self, chat_id: int) -> SPinnedMessages:
        """Return the visible pinned messages of ``chat_id``, newest first.

        Note: a chat without pins yields ``{"pinned_messages": null}`` because
        ``json_agg`` over zero rows is NULL.
        """
        pinned = (
            select(*self._message_row_columns())
            .select_from(PinnedMessage)
            .join(Message, PinnedMessage.message_id == Message.id)
            .join(Users, Message.user_id == Users.id)
            .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
            .where(PinnedMessage.chat_id == chat_id, Message.visibility == True)  # noqa: E712
            .cte("messages_with_users")
        )
        query = select(
            func.json_build_object("pinned_messages", self._messages_json_agg(pinned))
        ).select_from(pinned)
        result = await self.session.execute(query)
        return SPinnedMessages.model_validate(result.scalar_one())
|