diff --git a/app/chat/router.py b/app/chat/router.py index 0244868..4f2ed17 100644 --- a/app/chat/router.py +++ b/app/chat/router.py @@ -43,13 +43,13 @@ async def create_chat( user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork) ): + if user.id == user_to_exclude: + raise UserCanNotReadThisChatException async with uow: - if user.id == user_to_exclude: - raise UserCanNotReadThisChatException chat_id = await uow.chat.create(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id) - user_added_to_chat = await uow.chat.add_user_to_chat(user.id, chat_id) + await uow.chat.add_user_to_chat(user.id, chat_id) await uow.chat.add_user_to_chat(settings.ADMIN_USER_ID, chat_id) - return user_added_to_chat + await uow.commit() @router.get( diff --git a/app/dao/chat.py b/app/dao/chat.py index 788d36e..e881fcd 100644 --- a/app/dao/chat.py +++ b/app/dao/chat.py @@ -1,4 +1,5 @@ from sqlalchemy import insert, select, update, delete +from sqlalchemy.exc import IntegrityError from app.dao.base import BaseDAO from app.database import engine # noqa @@ -19,26 +20,22 @@ class ChatDAO(BaseDAO): async def find_one_or_none(self, **filter_by): query = select(Chats.__table__.columns).filter_by(**filter_by) result = await self.session.execute(query) - result = result.scalar_one_or_none() + result = result.mappings().one_or_none() + return result async def create(self, user_id: int, chat_name: str, created_by: int) -> int: stmt = insert(Chats).values(chat_for=user_id, chat_name=chat_name, created_by=created_by).returning(Chats.id) result = await self.session.execute(stmt) - await self.session.commit() result = result.scalar() return result - async def add_user_to_chat(self, user_id: int, chat_id: int) -> bool: - query = select(UserChat.user_id).where(UserChat.chat_id == chat_id) - result = await self.session.execute(query) - result = result.scalars().all() - if user_id in result: + async def add_user_to_chat(self, user_id: int, chat_id: int) -> None: + try: + stmt = insert(UserChat).values(user_id=user_id, chat_id=chat_id) + await self.session.execute(stmt) + except IntegrityError: raise UserAlreadyInChatException - stmt = insert(UserChat).values(user_id=user_id, chat_id=chat_id) - await self.session.execute(stmt) - await self.session.commit() - return True async def send_message(self, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> SMessage: inserted_image = ( diff --git a/app/migrations/versions/2024-06-05_database_creation.py b/app/migrations/versions/2024-06-05_database_creation.py index b38e845..a8359da 100644 --- a/app/migrations/versions/2024-06-05_database_creation.py +++ b/app/migrations/versions/2024-06-05_database_creation.py @@ -1,8 +1,8 @@ """Database Creation -Revision ID: 00acc3992d64 +Revision ID: 40559c83c848 Revises: -Create Date: 2024-06-05 12:56:38.627620 +Create Date: 2024-06-05 19:00:07.382050 """ from typing import Sequence, Union @@ -12,7 +12,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = '00acc3992d64' +revision: str = '40559c83c848' down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -72,12 +72,11 @@ def upgrade() -> None: sa.PrimaryKeyConstraint('id') ) op.create_table('user_chat', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.Column('chat_id', sa.Integer(), nullable=True), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('chat_id', sa.Integer(), nullable=False), sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ), sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + sa.PrimaryKeyConstraint('user_id', 'chat_id') ) op.create_table('message_answer', sa.Column('self_id', sa.Integer(), nullable=False), diff --git a/app/models/user_chat.py b/app/models/user_chat.py index d89f929..bc654ad 100644 --- a/app/models/user_chat.py +++ b/app/models/user_chat.py @@ -7,6 +7,5 @@ from app.database import Base class UserChat(Base): __tablename__ = "user_chat" - id: Mapped[int] = mapped_column(primary_key=True) - user_id = mapped_column(ForeignKey("users.id")) - chat_id = mapped_column(ForeignKey("chat.id")) + user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), primary_key=True) + chat_id: Mapped[int] = mapped_column(ForeignKey("chat.id"), primary_key=True) diff --git a/app/utils/auth.py b/app/utils/auth.py index 2555ec9..eee29fd 100644 --- a/app/utils/auth.py +++ b/app/utils/auth.py @@ -17,6 +17,8 @@ from app.users.schemas import SUser, SConfirmationData, SInvitationData pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY) + def get_password_hash(password: str) -> str: return pwd_context.hash(password) @@ -35,27 +37,23 @@ def create_access_token(data: dict[str, str | datetime]) -> str: def encode_invitation_token(user_data: SInvitationData) -> str: - cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY) invitation_token = cipher_suite.encrypt(user_data.model_dump_json().encode()) return invitation_token.decode() def decode_invitation_token(invitation_token: str) -> SInvitationData: user_code = invitation_token.encode() - cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY) user_data = cipher_suite.decrypt(user_code) return SInvitationData.model_validate_json(user_data) def encode_confirmation_token(user_data: SConfirmationData) -> str: - cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY) invitation_token = cipher_suite.encrypt(user_data.model_dump_json().encode()) return invitation_token.decode() def decode_confirmation_token(invitation_token: str) -> SConfirmationData: user_code = invitation_token.encode() - cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY) user_data = cipher_suite.decrypt(user_code) return SConfirmationData.model_validate_json(user_data)