diff --git a/app/chat/router.py b/app/chat/router.py index e69f087..9218fed 100644 --- a/app/chat/router.py +++ b/app/chat/router.py @@ -197,7 +197,7 @@ async def get_some_messages( uow=uow, chat_id=chat_id, message_number_from=last_messages.messages_loaded, - messages_to_get=last_messages.messages_to_get + messages_to_get=last_messages.messages_to_get, ) return messages diff --git a/app/chat/shemas.py b/app/chat/shemas.py index dedb765..25620dc 100644 --- a/app/chat/shemas.py +++ b/app/chat/shemas.py @@ -1,10 +1,11 @@ from datetime import datetime +from uuid import UUID from pydantic import BaseModel, HttpUrl class SMessage(BaseModel): - id: int + id: int | UUID # TODO: Заменить на UUID message: str | None = None image_url: str | None = None chat_id: int @@ -12,19 +13,19 @@ class SMessage(BaseModel): username: str created_at: datetime avatar_image: str - answer_id: int | None + answer_id: int | None | UUID # TODO: Заменить на UUID answer_message: str | None answer_image_url: str | None class SMessageRaw(BaseModel): - _id: str + id: UUID message: str | None = None image_url: str | None = None chat_id: int user_id: int created_at: datetime - answer_id: int | None = None + answer_id: UUID | None = None answer_message: str | None = None answer_image_url: str | None = None @@ -76,7 +77,7 @@ class SSendMessage(BaseModel): flag: str message: str image_url: str | None - answer: int | None + answer: UUID | None class SDeleteMessage(BaseModel): diff --git a/app/chat/websocket.py b/app/chat/websocket.py index 1b521f1..d82378a 100644 --- a/app/chat/websocket.py +++ b/app/chat/websocket.py @@ -4,7 +4,7 @@ from collections import defaultdict import websockets from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException, status -from app.chat.exceptions import UseWSException +from app.chat.exceptions import UseWSException, MessageNotFoundException, MessageAlreadyPinnedException from app.exceptions import IncorrectDataException from app.chat.exceptions import UserDontHavePermissionException from app.services.message_service import MessageService @@ -20,33 +20,37 @@ class ConnectionManager: def __init__(self): self.active_connections: dict[int, list[WebSocket]] = defaultdict(list) self.message_methods = { - "send": self.send, - "delete": self.delete, - "edit": self.edit, - "pin": self.pin, - "unpin": self.unpin, + "send": self._send, + "delete": self._delete, + "edit": self._edit, + "pin": self._pin, + "unpin": self._unpin, } - async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None): + async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None) -> None: await websocket.accept(subprotocol=subprotocol) self.active_connections[chat_id].append(websocket) - async def disconnect(self, chat_id: int, websocket: WebSocket, code_and_reason: tuple[int, str] | None = None): + async def disconnect( + self, chat_id: int, websocket: WebSocket, code_and_reason: tuple[int, str] | None = None + ) -> None: self.active_connections[chat_id].remove(websocket) if code_and_reason: await websocket.close(code=code_and_reason[0], reason=code_and_reason[1]) - async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict): + async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> None: try: new_message = await self.message_methods[message["flag"]](uow, user_id, chat_id, message) for websocket in self.active_connections[chat_id]: await websocket.send_json(new_message) await polling_manager.send(chat_id, new_message) + except (MessageNotFoundException, MessageAlreadyPinnedException): + pass except KeyError: raise IncorrectDataException @staticmethod - async def send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict: + async def _send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict: message = SSendMessage.model_validate(message) new_message = await MessageService.send_message( @@ -58,7 +62,7 @@ class ConnectionManager: return new_message @staticmethod - async def delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict: + async def _delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict: message = SDeleteMessage.model_validate(message) if message.user_id != user_id: @@ -69,7 +73,7 @@ class ConnectionManager: return new_message @staticmethod - async def edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict: + async def _edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict: message = SEditMessage.model_validate(message) if message.user_id != user_id: @@ -87,7 +91,7 @@ class ConnectionManager: return new_message @staticmethod - async def pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict: + async def _pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict: message = SPinMessage.model_validate(message) pinned_message = await MessageService.pin_message( uow=uow, chat_id=chat_id, user_id=message.user_id, message_id=message.id @@ -98,7 +102,7 @@ class ConnectionManager: return new_message @staticmethod - async def unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict: + async def _unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict: message = SUnpinMessage.model_validate(message) await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id) new_message = {"flag": "unpin", "id": message.id} @@ -179,7 +183,7 @@ class PollingManager: self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list) self.messages: dict[int, list[dict]] = defaultdict(list) - async def poll(self, chat_id: int): + async def poll(self, chat_id: int) -> dict: future = asyncio.Future() self.waiters[chat_id].append(future) try: @@ -196,7 +200,7 @@ class PollingManager: user_id: int, chat_id: int, message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage, - ): + ) -> None: message = message.model_dump() await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=message) @@ -206,7 +210,7 @@ class PollingManager: self, chat_id: int, message: dict, - ): + ) -> None: self.messages[chat_id].append(message) while self.waiters[chat_id]: waiter = self.waiters[chat_id].pop(0) diff --git a/app/dao/chat.py b/app/dao/chat.py index 5a26df7..cd24531 100644 --- a/app/dao/chat.py +++ b/app/dao/chat.py @@ -1,22 +1,18 @@ +from uuid import UUID + from pydantic import HttpUrl from sqlalchemy import insert, select, update, delete, func from sqlalchemy.exc import IntegrityError, NoResultFound -from sqlalchemy.orm import aliased from app.dao.base import BaseDAO from app.chat.exceptions import ( UserAlreadyInChatException, UserAlreadyPinnedChatException, - MessageNotFoundException, MessageAlreadyPinnedException, ChatNotFoundException, ) -from app.chat.shemas import SMessage, SMessageList, SPinnedMessages, SPinnedChats, SChat -from app.database import db -from app.models.users import Users -from app.models.message_answer import MessageAnswer +from app.chat.shemas import SPinnedChats, SChat from app.models.chat import Chat -from app.models.message import Message from app.models.pinned_chat import PinnedChat from app.models.pinned_message import PinnedMessage from app.models.user_chat import UserChat @@ -64,142 +60,6 @@ class ChatDAO(BaseDAO): except IntegrityError: raise UserAlreadyInChatException - async def send_message(self, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> int: - stmt = ( - insert(Message) - .values(chat_id=chat_id, user_id=user_id, message=message, image_url=image_url) - .returning(Message.id) - ) - - result = await self.session.execute(stmt) - return result.scalar() - - async def get_message_by_id(self, message_id: int) -> SMessage: - try: - msg = aliased(Message, name="msg") - - query = ( - select( - func.json_build_object( - "id", Message.id, - "message", Message.message, - "image_url", Message.image_url, - "chat_id", Message.chat_id, - "user_id", Message.user_id, - "created_at", Message.created_at, - "avatar_image", Users.avatar_image, - "username", Users.username, - "answer_id", MessageAnswer.answer_id, - "answer_message", select( - msg.message - ) - .select_from(MessageAnswer) - .join(msg, msg.id == MessageAnswer.answer_id, isouter=True) - .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712 - .scalar_subquery(), - "answer_image_url", select( - msg.image_url - ) - .select_from(MessageAnswer) - .join(msg, msg.id == MessageAnswer.answer_id, isouter=True) - .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712 - .scalar_subquery(), - ) - ) - .select_from(Message) - .join(Users, Users.id == Message.user_id) - .join(MessageAnswer, Message.id == MessageAnswer.self_id, isouter=True) - .where(Message.id == message_id, Message.visibility == True) # noqa: E712 - ) - - result = await self.session.execute(query) - result = result.scalar_one() - return SMessage.model_validate(result) - except NoResultFound: - raise MessageNotFoundException - - async def delete_message(self, message_id: int) -> None: - stmt = update(Message).values(visibility=False).where(Message.id == message_id) - await self.session.execute(stmt) - - async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageList: - - msg = aliased(Message, name="msg") - - messages_with_users = ( - select( - Message.id, - Message.message, - Message.image_url, - Message.chat_id, - Message.user_id, - Message.created_at, - Users.username, - Users.avatar_image, - MessageAnswer.answer_id, - ( - select(msg.message) - .select_from(MessageAnswer) - .join(msg, msg.id == MessageAnswer.answer_id, isouter=True) - .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712 - .scalar_subquery() - ).label("answer_message"), - ( - select(msg.image_url) - .select_from(MessageAnswer) - .join(msg, msg.id == MessageAnswer.answer_id, isouter=True) - .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712 - .scalar_subquery() - ).label("answer_image_url"), - ) - .select_from(Message) - .join(Users, Message.user_id == Users.id) - .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True) - .where(Message.chat_id == chat_id, Message.visibility == True) # noqa: E712 - .order_by(Message.created_at.desc()) - .limit(messages_to_get) - .offset(message_number_from) - .cte("messages_with_users") - ) - - query = ( - select( - func.json_build_object( - "messages", - func.json_agg( - func.json_build_object( - "id", messages_with_users.c.id, - "message", messages_with_users.c.message, - "image_url", messages_with_users.c.image_url, - "chat_id", messages_with_users.c.chat_id, - "user_id", messages_with_users.c.user_id, - "created_at", messages_with_users.c.created_at, - "avatar_image", messages_with_users.c.avatar_image, - "username", messages_with_users.c.username, - "answer_id", messages_with_users.c.answer_id, - "answer_message", messages_with_users.c.answer_message, - "answer_image_url", messages_with_users.c.answer_image_url, - ) - ) - ) - ).select_from(messages_with_users) - ) - - result = await self.session.execute(query) - result = result.scalar() - return SMessageList.model_validate(result) - - async def edit_message(self, message_id: int, new_message: str, new_image_url: str) -> None: - stmt = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url) - await self.session.execute(stmt) - - async def add_answer(self, self_id: int, answer_id: int) -> None: - stmt = ( - insert(MessageAnswer) - .values(self_id=self_id, answer_id=answer_id) - ) - await self.session.execute(stmt) - async def change_data(self, chat_id: int, chat_name: str, avatar_image: HttpUrl) -> None: stmt = ( update(Chat) @@ -252,94 +112,26 @@ class ChatDAO(BaseDAO): result = result.scalar_one() return SPinnedChats.model_validate(result) - async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> None: + async def pin_message(self, chat_id: int, message_id: UUID, user_id: int) -> None: try: stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id) await self.session.execute(stmt) except IntegrityError: raise MessageAlreadyPinnedException - async def unpin_message(self, chat_id: int, message_id: int) -> None: + async def unpin_message(self, chat_id: int, message_id: UUID) -> None: stmt = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id) await self.session.execute(stmt) - async def get_pinned_messages(self, chat_id: int) -> SPinnedMessages: - - msg = aliased(Message, name="msg") - - messages_with_users = ( - select( - Message.id, - Message.message, - Message.image_url, - Message.chat_id, - Message.user_id, - Message.created_at, - Message.visibility, - Users.username, - Users.avatar_image, - MessageAnswer.self_id, - MessageAnswer.answer_id, - ( - select(msg.message) - .select_from(MessageAnswer) - .join(msg, msg.id == MessageAnswer.answer_id, isouter=True) # noqa - .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712 - .scalar_subquery() - ).label("answer_message"), - ( - select(msg.image_url) - .select_from(MessageAnswer) - .join(msg, msg.id == MessageAnswer.answer_id, isouter=True) - .where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712 - .scalar_subquery() - ).label("answer_image_url"), - ) - .select_from(PinnedMessage) - .join(Message, PinnedMessage.message_id == Message.id) - .join(Users, Message.user_id == Users.id) - .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True) - .where(PinnedMessage.chat_id == chat_id, Message.visibility == True) # noqa: E712 - .order_by(Message.created_at.desc()) - .cte("messages_with_users") - ) - + async def get_pinned_messages_ids(self, chat_id: int) -> list[UUID]: query = ( - select( - func.json_build_object( - "pinned_messages", - func.json_agg( - func.json_build_object( - "id", messages_with_users.c.id, - "message", messages_with_users.c.message, - "image_url", messages_with_users.c.image_url, - "chat_id", messages_with_users.c.chat_id, - "user_id", messages_with_users.c.user_id, - "created_at", messages_with_users.c.created_at, - "avatar_image", messages_with_users.c.avatar_image, - "username", messages_with_users.c.username, - "answer_id", messages_with_users.c.answer_id, - "answer_message", messages_with_users.c.answer_message, - "answer_image_url", messages_with_users.c.answer_image_url - ) - ) - ) - ).select_from(messages_with_users) + select(PinnedMessage.message_id) + .where(PinnedMessage.chat_id == chat_id) ) result = await self.session.execute(query) - result = result.scalar_one() - pinned_messages = SPinnedMessages.model_validate(result) - return pinned_messages + result = result.scalars().all() + return result # noqa + + - async def test_mongo(self): - await db.message.insert_one( - { - "message": "пизда", - "chat_id": 1, - "user_id": 1, - "image_url": "", - "created_at": "2024-06-13 13:49:59.14757+00", - "visibility": True, - } - ) diff --git a/app/dao/message.py b/app/dao/message.py new file mode 100644 index 0000000..8d047fa --- /dev/null +++ b/app/dao/message.py @@ -0,0 +1,60 @@ +from datetime import datetime, UTC +from uuid import uuid4, UUID + +from app.chat.shemas import SMessageRaw, SMessageRawList + + +class MessageDAO: + def __init__(self, mongo_db): + self.message = mongo_db.message + + async def send_message( + self, + user_id: int, + chat_id: int, + message: str | None, + image_url: str | None, + answer_id: UUID | None = None, + answer_message: str | None = None, + answer_image_url: str | None = None, + ) -> UUID: + message_id = uuid4() + await self.message.insert_one( + { + "id": str(message_id), + "message": message, + "image_url": image_url, + "chat_id": chat_id, + "user_id": user_id, + "created_at": datetime.now(UTC), + "answer_id": str(answer_id) if answer_id else None, + "answer_message": answer_message, + "answer_image_url": answer_image_url, + "visibility": True, + } + ) + return message_id + + async def get_message_by_id(self, message_id: UUID) -> SMessageRaw: + message = self.message.find_one({"id": str(message_id)}) + return SMessageRaw.model_validate(message) + + async def delete_message(self, message_id: UUID) -> None: + await self.message.update_one({"id": str(message_id)}, {"$set": {"visibility": False}}) + + async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageRawList: + cursor = self.message.find({"visibility": True, "chat_id": chat_id}) + cursor.sort("created_at").skip(message_number_from) + return SMessageRawList.model_validate({"message_raw_list": await cursor.to_list(length=messages_to_get)}) + + async def edit_message(self, message_id: UUID, new_message: str, new_image_url: str) -> None: + await self.message.update_one( + {"id": str(message_id)}, + {"$set": {"message": new_message, "image_url": new_image_url}} + ) + + async def get_messages_from_ids(self, messages_ids: list[UUID]) -> SMessageRawList: + cursor = self.message.find({"visibility": True, "id": {"$in": [str(message_id) for message_id in messages_ids]}}) + return SMessageRawList.model_validate({"message_raw_list": await cursor.to_list(length=None) or None}) + + diff --git a/app/dao/user.py b/app/dao/user.py index 5ca0b85..ad50023 100644 --- a/app/dao/user.py +++ b/app/dao/user.py @@ -1,4 +1,4 @@ -from pydantic import HttpUrl, EmailStr +from pydantic import EmailStr from sqlalchemy import update, select, insert, func, or_ from sqlalchemy.exc import MultipleResultsFound, IntegrityError, NoResultFound diff --git a/app/database.py b/app/database.py index 68d2a1d..4d04d5f 100644 --- a/app/database.py +++ b/app/database.py @@ -29,5 +29,5 @@ MONGO_URL = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}" mongo_client = AsyncIOMotorClient(MONGO_URL) -db = mongo_client.test_db +mongo_db = mongo_client.test_db diff --git a/app/migrations/env.py b/app/migrations/env.py index 69dfdb1..baf231c 100644 --- a/app/migrations/env.py +++ b/app/migrations/env.py @@ -6,9 +6,7 @@ from alembic import context from app.database import DATABASE_URL, Base from app.models.users import Users # noqa -from app.models.message_answer import MessageAnswer # noqa from app.models.chat import Chat # noqa -from app.models.message import Message # noqa from app.models.pinned_chat import PinnedChat # noqa from app.models.pinned_message import PinnedMessage # noqa from app.models.user_chat import UserChat # noqa diff --git a/app/migrations/versions/2024-06-13_database_creation.py b/app/migrations/versions/2024-08-20_12-30-52_database_creation.py similarity index 71% rename from app/migrations/versions/2024-06-13_database_creation.py rename to app/migrations/versions/2024-08-20_12-30-52_database_creation.py index 4b251a9..753f398 100644 --- a/app/migrations/versions/2024-06-13_database_creation.py +++ b/app/migrations/versions/2024-08-20_12-30-52_database_creation.py @@ -1,8 +1,8 @@ """Database Creation -Revision ID: c69369209bab +Revision ID: 53fd6f2b93a4 Revises: -Create Date: 2024-06-13 18:40:03.297322 +Create Date: 2024-08-20 12:30:52.435668 """ from typing import Sequence, Union @@ -12,7 +12,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = 'c69369209bab' +revision: str = '53fd6f2b93a4' down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -52,18 +52,6 @@ def upgrade() -> None: sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.PrimaryKeyConstraint('id') ) - op.create_table('message', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('message', sa.String(), nullable=True), - sa.Column('image_url', sa.String(), nullable=True), - sa.Column('chat_id', sa.Integer(), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('visibility', sa.Boolean(), server_default='true', nullable=False), - sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') - ) op.create_table('pinned_chat', sa.Column('user_id', sa.Integer(), nullable=False), sa.Column('chat_id', sa.Integer(), nullable=False), @@ -71,6 +59,14 @@ def upgrade() -> None: sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.PrimaryKeyConstraint('user_id', 'chat_id') ) + op.create_table('pinned_message', + sa.Column('chat_id', sa.Integer(), nullable=True), + sa.Column('message_id', sa.Uuid(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('message_id', 'user_id') + ) op.create_table('user_chat', sa.Column('user_id', sa.Integer(), nullable=False), sa.Column('chat_id', sa.Integer(), nullable=False), @@ -78,32 +74,14 @@ def upgrade() -> None: sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.PrimaryKeyConstraint('user_id', 'chat_id') ) - op.create_table('message_answer', - sa.Column('self_id', sa.Integer(), nullable=False), - sa.Column('answer_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['answer_id'], ['message.id'], ), - sa.ForeignKeyConstraint(['self_id'], ['message.id'], ), - sa.PrimaryKeyConstraint('self_id') - ) - op.create_table('pinned_message', - sa.Column('chat_id', sa.Integer(), nullable=True), - sa.Column('message_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ), - sa.ForeignKeyConstraint(['message_id'], ['message.id'], ), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('message_id') - ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('pinned_message') - op.drop_table('message_answer') op.drop_table('user_chat') + op.drop_table('pinned_message') op.drop_table('pinned_chat') - op.drop_table('message') op.drop_table('user_avatar') op.drop_table('chat') op.drop_table('users') diff --git a/app/models/message.py b/app/models/message.py deleted file mode 100644 index d2d97ad..0000000 --- a/app/models/message.py +++ /dev/null @@ -1,18 +0,0 @@ -from datetime import datetime - -from sqlalchemy import ForeignKey, func, DateTime -from sqlalchemy.orm import Mapped, mapped_column - -from app.database import Base - - -class Message(Base): - __tablename__ = "message" - - id: Mapped[int] = mapped_column(primary_key=True) - chat_id = mapped_column(ForeignKey("chat.id")) - user_id = mapped_column(ForeignKey("users.id")) - message: Mapped[str | None] - image_url: Mapped[str | None] - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) - visibility: Mapped[bool] = mapped_column(server_default="true") diff --git a/app/models/message_answer.py b/app/models/message_answer.py deleted file mode 100644 index 4a4f1ea..0000000 --- a/app/models/message_answer.py +++ /dev/null @@ -1,11 +0,0 @@ -from sqlalchemy import ForeignKey -from sqlalchemy.orm import Mapped, mapped_column - -from app.database import Base - - -class MessageAnswer(Base): - __tablename__ = "message_answer" - - self_id: Mapped[int] = mapped_column(ForeignKey("message.id"), primary_key=True) - answer_id: Mapped[int] = mapped_column(ForeignKey("message.id")) diff --git a/app/models/pinned_message.py b/app/models/pinned_message.py index d6889da..eb9807c 100644 --- a/app/models/pinned_message.py +++ b/app/models/pinned_message.py @@ -1,5 +1,7 @@ +from uuid import UUID + from sqlalchemy import ForeignKey -from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import mapped_column, Mapped from app.database import Base @@ -8,5 +10,5 @@ class PinnedMessage(Base): __tablename__ = "pinned_message" chat_id = mapped_column(ForeignKey("chat.id")) - message_id = mapped_column(ForeignKey("message.id"), primary_key=True) - user_id = mapped_column(ForeignKey("users.id")) + message_id: Mapped[UUID] = mapped_column(primary_key=True) + user_id = mapped_column(ForeignKey("users.id"), primary_key=True) diff --git a/app/services/chat_service.py b/app/services/chat_service.py index 5dad71f..9a45d80 100644 --- a/app/services/chat_service.py +++ b/app/services/chat_service.py @@ -126,5 +126,4 @@ class ChatService: async def get_pinned_chats(uow: UnitOfWork, user_id: int) -> SPinnedChats: async with uow: pinned_chats = await uow.chat.get_pinned_chats(user_id=user_id) - await uow.chat.test_mongo() return pinned_chats diff --git a/app/services/message_service.py b/app/services/message_service.py index 2fcfc5a..ef7a1b9 100644 --- a/app/services/message_service.py +++ b/app/services/message_service.py @@ -1,54 +1,93 @@ -from app.chat.shemas import SMessage, SSendMessage, SPinnedMessages, SMessageList, SMessageRaw +from datetime import datetime, UTC, timedelta + +from app.chat.shemas import SMessage, SSendMessage, SPinnedMessages, SMessageList, SMessageRaw, SMessageRawList from app.services.user_service import UserService +from app.users.schemas import SUser from app.utils.unit_of_work import UnitOfWork class MessageService: - @staticmethod - async def add_avatar_image_and_username_to_message(uow: UnitOfWork, message: SMessageRaw) -> SMessage: + users: dict[int, tuple[SUser, datetime]] = {} + + @classmethod + async def _get_cached_user(cls, uow: UnitOfWork, user_id: int) -> SUser: + if user_id in cls.users and cls.users[user_id][1] > datetime.now(UTC): + return cls.users[user_id][0] + + user = await UserService.find_user(uow=uow, id=user_id) + cls.users[user_id] = user, datetime.now(UTC) + timedelta(minutes=5) + return user + + @classmethod + async def add_avatar_image_and_username_to_message(cls, uow: UnitOfWork, message: SMessageRaw) -> SMessage: + user = await cls._get_cached_user(uow=uow, user_id=message.user_id) message = message.model_dump() - user = await UserService.find_user(uow=uow, id=message["user_id"]) - message["id"] = message["_id"] - message["avatar_image"] = user.avatar_image + message["avatar_image"] = str(user.avatar_image) message["username"] = user.username return SMessage.model_validate(message) - @staticmethod + @classmethod + async def add_avatar_image_and_username_to_message_list( + cls, uow: UnitOfWork, messages: SMessageRawList + ) -> SMessageList: + return SMessageList.model_validate( + {"messages": [ + await cls.add_avatar_image_and_username_to_message(uow=uow, message=message) + for message in messages.message_raw_list + ] if messages.message_raw_list else None + } + ) + + @classmethod async def send_message( - uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage, image_url: str | None = None + cls, uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage, image_url: str | None = None ) -> SMessage: async with uow: - message_id = await uow.chat.send_message( - user_id=user_id, chat_id=chat_id, message=message.message, image_url=image_url - ) if message.answer: - await uow.chat.add_answer(self_id=message_id, answer_id=message.answer) - new_message = await uow.chat.get_message_by_id(message_id=message_id) - await uow.commit() + answer_message = await uow.message.get_message_by_id(message_id=message.answer) + message_id = await uow.message.send_message( + user_id=user_id, + chat_id=chat_id, + message=message.message, + image_url=image_url, + answer_id=answer_message.id, + answer_message=answer_message.message, + answer_image_url=answer_message.answer_image_url, + ) + else: + message_id = await uow.message.send_message( + user_id=user_id, + chat_id=chat_id, + message=message.message, + image_url=image_url, + ) + raw_message = await uow.message.get_message_by_id(message_id=message_id) + new_message = await cls.add_avatar_image_and_username_to_message(uow=uow, message=raw_message) - return new_message + return new_message @staticmethod async def delete_message(uow: UnitOfWork, message_id: int) -> None: async with uow: - await uow.chat.delete_message(message_id=message_id) - await uow.commit() + await uow.message.delete_message(message_id=message_id) @staticmethod async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, new_image_url: str) -> None: async with uow: - await uow.chat.edit_message( + await uow.message.edit_message( message_id=message_id, new_message=new_message, new_image_url=new_image_url ) - await uow.commit() - @staticmethod - async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage: + @classmethod + async def pin_message(cls, uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage: async with uow: await uow.chat.pin_message(chat_id=chat_id, message_id=message_id, user_id=user_id) - pinned_message = await uow.chat.get_message_by_id(message_id=message_id) await uow.commit() - return pinned_message + + raw_message = await uow.message.get_message_by_id(message_id=message_id) + pinned_message = await cls.add_avatar_image_and_username_to_message(uow=uow, message=raw_message) + + return pinned_message @staticmethod async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> None: @@ -56,24 +95,28 @@ class MessageService: await uow.chat.unpin_message(chat_id=chat_id, message_id=message_id) await uow.commit() - @staticmethod - async def get_message_by_id(uow: UnitOfWork, message_id: int) -> SMessage: + @classmethod + async def get_message_by_id(cls, uow: UnitOfWork, message_id: int) -> SMessage: async with uow: - message = await uow.chat.get_message_by_id(message_id=message_id) + raw_message = await uow.message.get_message_by_id(message_id=message_id) + message = await cls.add_avatar_image_and_username_to_message(uow=uow, message=raw_message) return message - @staticmethod - async def get_pinned_messages(uow: UnitOfWork, chat_id: int) -> SPinnedMessages: + @classmethod + async def get_pinned_messages(cls, uow: UnitOfWork, chat_id: int) -> SPinnedMessages: async with uow: - pinned_messages = await uow.chat.get_pinned_messages(chat_id=chat_id) - return pinned_messages + pinned_messages_ids = await uow.chat.get_pinned_messages_ids(chat_id=chat_id) + raw_messages = await uow.message.get_messages_from_ids(messages_ids=pinned_messages_ids) + pinned_messages = await cls.add_avatar_image_and_username_to_message_list(uow=uow, messages=raw_messages) + return SPinnedMessages.model_validate({"pinned_messages": pinned_messages.messages}) - @staticmethod + @classmethod async def get_some_messages( - uow: UnitOfWork, chat_id: int, message_number_from: int, messages_to_get: int + cls, uow: UnitOfWork, chat_id: int, message_number_from: int, messages_to_get: int ) -> SMessageList: async with uow: - message = await uow.chat.get_some_messages( + messages = await uow.message.get_some_messages( chat_id=chat_id, message_number_from=message_number_from, messages_to_get=messages_to_get ) - return message + messages = await cls.add_avatar_image_and_username_to_message_list(uow=uow, messages=messages) + return messages diff --git a/app/services/user_service.py b/app/services/user_service.py index 3721e4b..299e5c1 100644 --- a/app/services/user_service.py +++ b/app/services/user_service.py @@ -112,7 +112,7 @@ class UserService: avatar_image=str(user_data.avatar_url or user.avatar_image), hashed_password=hashed_password ) - (await uow.user.add_user_avatar(user_id=user.id, avatar=str(user_data.avatar_url))) if user_data.avatar_url else None + (await uow.user.add_user_avatar(user_id=user.id, avatar=str(user_data.avatar_url))) if user_data.avatar_url else None # noqa await uow.commit() async with RedisService() as redis: diff --git a/app/tasks/email_templates.py b/app/tasks/email_templates.py index c5dbf2a..6b5b74d 100644 --- a/app/tasks/email_templates.py +++ b/app/tasks/email_templates.py @@ -23,7 +23,7 @@ def create_registration_confirmation_template( - +