diff --git a/.env_template b/.env_template index e90a93a..d10992d 100644 --- a/.env_template +++ b/.env_template @@ -17,6 +17,7 @@ ALGORITHM= REDIS_HOST= REDIS_PORT= +REDIS_DB= SMTP_HOST= SMTP_PORT= @@ -28,4 +29,4 @@ IMAGE_UPLOAD_SERVER= INVITATION_LINK_HOST= INVITATION_LINK_TOKEN_KEY= -SENTRY_DSN= \ No newline at end of file +SENTRY_DSN= diff --git a/app/chat/dao.py b/app/chat/dao.py index 9eb5a86..146a417 100644 --- a/app/chat/dao.py +++ b/app/chat/dao.py @@ -1,7 +1,7 @@ from sqlalchemy import insert, select, update, delete from app.dao.base import BaseDAO -from app.database import async_session_maker, engine # noqa +from app.database import engine # noqa from app.exceptions import UserAlreadyInChatException, UserAlreadyPinnedChatException from app.chat.shemas import SMessage from app.models.users import Users @@ -18,22 +18,21 @@ class ChatDAO(BaseDAO): async def create(self, user_id: int, chat_name: str, created_by: int) -> int: query = insert(Chats).values(chat_for=user_id, chat_name=chat_name, created_by=created_by).returning(Chats.id) - async with async_session_maker() as session: - result = await session.execute(query) - await session.commit() + + result = await self.session.execute(query) + await self.session.commit() result = result.scalar() return result async def add_user_to_chat(self, user_id: int, chat_id: int) -> bool: query = select(UserChat.user_id).where(UserChat.chat_id == chat_id) - async with async_session_maker() as session: - result = await session.execute(query) - result = result.scalars().all() - if user_id in result: - raise UserAlreadyInChatException - query = insert(UserChat).values(user_id=user_id, chat_id=chat_id) - await session.execute(query) - await session.commit() + result = await self.session.execute(query) + result = result.scalars().all() + if user_id in result: + raise UserAlreadyInChatException + query = insert(UserChat).values(user_id=user_id, chat_id=chat_id) + await self.session.execute(query) + await self.session.commit() return True async def send_message(self, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> SMessage: @@ -65,10 +64,9 @@ class ChatDAO(BaseDAO): .join(Answer, Answer.self_id == inserted_image.c.id, isouter=True) ) - async with async_session_maker() as session: - result = await session.execute(query) - await session.commit() - result = result.mappings().one() + result = await self.session.execute(query) + await self.session.commit() + result = result.mappings().one() return SMessage.model_validate(result, from_attributes=True) async def get_message_by_id(self, message_id: int): @@ -91,17 +89,15 @@ class ChatDAO(BaseDAO): .join(Answer, Answer.self_id == Message.id, isouter=True) .where(Message.id == message_id, Message.visibility == True) # noqa: E712 ) - async with async_session_maker() as session: - result = await session.execute(query) + result = await self.session.execute(query) result = result.mappings().all() if result: return result[0] async def delete_message(self, message_id: int) -> bool: query = update(Message).where(Message.id == message_id).values(visibility=False) - async with async_session_maker() as session: - await session.execute(query) - await session.commit() + await self.session.execute(query) + await self.session.commit() return True async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> list[dict]: @@ -145,8 +141,7 @@ class ChatDAO(BaseDAO): .limit(messages_to_get) .offset(message_number_from) ) - async with async_session_maker() as session: - result = await session.execute(messages) + result = await self.session.execute(messages) result = result.mappings().all() if result: result = [dict(res) for res in result] @@ -154,10 +149,9 @@ class ChatDAO(BaseDAO): async def edit_message(self, message_id: int, new_message: str, new_image_url: str) -> bool: query = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url) - async with async_session_maker() as session: - await session.execute(query) - await session.commit() - return True + await self.session.execute(query) + await self.session.commit() + return True async def add_answer(self, self_id: int, answer_id: int) -> SMessage: answer = ( @@ -187,44 +181,39 @@ class ChatDAO(BaseDAO): .where(Message.id == self_id) ) - async with async_session_maker() as session: - result = await session.execute(query) - await session.commit() + result = await self.session.execute(query) + await self.session.commit() result = result.mappings().one() return SMessage.model_validate(result) async def delete_chat(self, chat_id: int) -> bool: query = update(Chats).where(Chats.id == chat_id).values(visibility=False) - async with async_session_maker() as session: - await session.execute(query) - await session.commit() - return True + await self.session.execute(query) + await self.session.commit() + return True async def delete_user(self, chat_id: int, user_id: int) -> bool: query = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id) - async with async_session_maker() as session: - await session.execute(query) - await session.commit() - return True + await self.session.execute(query) + await self.session.commit() + return True async def pin_chat(self, chat_id: int, user_id: int) -> bool: query = select(PinnedChats.chat_id).where(PinnedChats.user_id == user_id) - async with async_session_maker() as session: - result = await session.execute(query) - result = result.scalars().all() - if chat_id in result: - raise UserAlreadyPinnedChatException - query = insert(PinnedChats).values(chat_id=chat_id, user_id=user_id) - await session.execute(query) - await session.commit() - return True + result = await self.session.execute(query) + result = result.scalars().all() + if chat_id in result: + raise UserAlreadyPinnedChatException + query = insert(PinnedChats).values(chat_id=chat_id, user_id=user_id) + await self.session.execute(query) + await self.session.commit() + return True async def unpin_chat(self, chat_id: int, user_id: int) -> bool: query = delete(PinnedChats).where(PinnedChats.chat_id == chat_id, PinnedChats.user_id == user_id) - async with async_session_maker() as session: - await session.execute(query) - await session.commit() - return True + await self.session.execute(query) + await self.session.commit() + return True async def get_pinned_chats(self, user_id: int): chats_with_descriptions = ( @@ -263,24 +252,19 @@ class ChatDAO(BaseDAO): .where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712 ) # print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса - async with async_session_maker() as session: - result = await session.execute(query) - result = result.mappings().all() - return result + result = await self.session.execute(query) + result = result.mappings().all() + return result async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> bool: query = insert(PinnedMessages).values(chat_id=chat_id, message_id=message_id, user_id=user_id) - async with async_session_maker() as session: - await session.execute(query) - await session.commit() - return True + await self.session.execute(query) + return True async def unpin_message(self, chat_id: int, message_id: int) -> bool: query = delete(PinnedMessages).where(PinnedMessages.chat_id == chat_id, PinnedMessages.message_id == message_id) - async with async_session_maker() as session: - await session.execute(query) - await session.commit() - return True + await self.session.execute(query) + return True async def get_pinned_messages(self, chat_id: int) -> list[dict]: query = ( @@ -304,9 +288,9 @@ class ChatDAO(BaseDAO): .where(PinnedMessages.chat_id == chat_id, Message.visibility == True) # noqa: E712 .order_by(Message.created_at.desc()) ) - async with async_session_maker() as session: - result = await session.execute(query) - result = result.mappings().all() - if result: - result = [dict(res) for res in result] - return result + + result = await self.session.execute(query) + result = result.mappings().all() + if result: + result = [dict(res) for res in result] + return result diff --git a/app/chat/router.py b/app/chat/router.py index 064a0da..58c42cf 100644 --- a/app/chat/router.py +++ b/app/chat/router.py @@ -4,10 +4,9 @@ from fastapi import APIRouter, Depends, status from app.config import settings from app.exceptions import UserDontHavePermissionException, MessageNotFoundException, UserCanNotReadThisChatException -from app.chat.dao import ChatDAO from app.chat.shemas import SMessage, SLastMessages, SPinnedChat, SDeletedUser, SChat, SDeletedChat +from app.unit_of_work import UnitOfWork -from app.users.dao import UserDAO from app.users.dependencies import check_verificated_user_with_exc from app.users.auth import ADMIN_USER_ID, AuthService from app.users.schemas import SCreateInvitationLink, SUserAddedToChat, SUser @@ -20,42 +19,10 @@ router = APIRouter(prefix="/chat", tags=["Чат"]) status_code=status.HTTP_200_OK, response_model=list[SChat] ) -async def get_all_chats(user: SUser = Depends(check_verificated_user_with_exc)): - result = await UserDAO.get_user_allowed_chats(user.id) - return result - - -@router.post( - "", - status_code=status.HTTP_201_CREATED, - response_model=None, -) -async def add_message_to_chat(chat_id: int, message: str, user: SUser = Depends(check_verificated_user_with_exc)): - chats = await AuthService.get_user_allowed_chats_id(user.id) - if chat_id not in chats: - raise UserDontHavePermissionException - send_message_to_chat = await ChatDAO.send_message( - user_id=user.id, - chat_id=chat_id, - message=message, - ) - return send_message_to_chat - - -@router.delete( - "/delete_message", - status_code=status.HTTP_200_OK, - response_model=None, -) -async def delete_message_from_chat(message_id: int, user: SUser = Depends(check_verificated_user_with_exc)): - get_message_sender = await ChatDAO.get_message_by_id(message_id=message_id) - if get_message_sender is None: - raise MessageNotFoundException - if get_message_sender["user_id"] != user.id: - if not await AuthService.validate_user_admin(user_id=user.id): - raise UserDontHavePermissionException - deleted_message = await ChatDAO.delete_message(message_id=message_id) - return deleted_message +async def get_all_chats(user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + async with uow: + result = await uow.user.get_user_allowed_chats(user.id) + return result @router.post( @@ -63,13 +30,14 @@ async def delete_message_from_chat(message_id: int, user: SUser = Depends(check_ status_code=status.HTTP_201_CREATED, response_model=None, ) -async def create_chat(user_to_exclude: int, chat_name: str, user: SUser = Depends(check_verificated_user_with_exc)): - if user.id == user_to_exclude: - raise UserCanNotReadThisChatException - chat_id = await ChatDAO.create(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id) - user_added_to_chat = await ChatDAO.add_user_to_chat(user.id, chat_id) - await ChatDAO.add_user_to_chat(ADMIN_USER_ID, chat_id) - return user_added_to_chat +async def create_chat(user_to_exclude: int, chat_name: str, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + async with uow: + if user.id == user_to_exclude: + raise UserCanNotReadThisChatException + chat_id = await uow.chat.create(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id) + user_added_to_chat = await uow.chat.add_user_to_chat(user.id, chat_id) + await uow.chat.add_user_to_chat(ADMIN_USER_ID, chat_id) + return user_added_to_chat @router.get( @@ -77,14 +45,15 @@ async def create_chat(user_to_exclude: int, chat_name: str, user: SUser = Depend status_code=status.HTTP_200_OK, response_model=list[SMessage] ) -async def get_last_message(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)): - await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id) - message = await ChatDAO.get_some_messages(chat_id=chat_id, message_number_from=0, messages_to_get=1) - if message is None: - raise MessageNotFoundException - for mes in message: - mes["created_at"] = mes["created_at"].isoformat() - return message +async def get_last_message(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id) + async with uow: + message = await uow.chat.get_some_messages(chat_id=chat_id, message_number_from=0, messages_to_get=1) + if message is None: + raise MessageNotFoundException + for mes in message: + mes["created_at"] = mes["created_at"].isoformat() + return message @router.get( @@ -93,18 +62,20 @@ async def get_last_message(chat_id: int, user: SUser = Depends(check_verificated response_model=list[SMessage] ) async def get_some_messages( - chat_id: int, last_messages: SLastMessages = Depends(), user: SUser = Depends(check_verificated_user_with_exc) + chat_id: int, last_messages: SLastMessages = Depends(), user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork) ): - await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id) - messages = await ChatDAO.get_some_messages( - chat_id=chat_id, message_number_from=last_messages.messages_loaded, - messages_to_get=last_messages.messages_to_get - ) - if not messages: - raise MessageNotFoundException - for mes in messages: - mes["created_at"] = mes["created_at"].isoformat() - return messages + await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id) + async with uow: + messages = await uow.chat.get_some_messages( + chat_id=chat_id, + message_number_from=last_messages.messages_loaded, + messages_to_get=last_messages.messages_to_get + ) + if not messages: + raise MessageNotFoundException + for mes in messages: + mes["created_at"] = mes["created_at"].isoformat() + return messages @router.get( @@ -112,14 +83,15 @@ async def get_some_messages( status_code=status.HTTP_200_OK, response_model=SMessage ) -async def get_message_by_id(chat_id: int, message_id: int, user: SUser = Depends(check_verificated_user_with_exc)): - await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id) - message = await ChatDAO.get_message_by_id(message_id=message_id) - if not message: - raise MessageNotFoundException - message = dict(message) - message["created_at"] = message["created_at"].isoformat() - return message +async def get_message_by_id(chat_id: int, message_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id) + async with uow: + message = await uow.chat.get_message_by_id(message_id=message_id) + if not message: + raise MessageNotFoundException + message = dict(message) + message["created_at"] = message["created_at"].isoformat() + return message @router.get( @@ -127,8 +99,8 @@ async def get_message_by_id(chat_id: int, message_id: int, user: SUser = Depends status_code=status.HTTP_201_CREATED, response_model=SCreateInvitationLink, ) -async def create_invitation_link(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)): - await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id) +async def create_invitation_link(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id) cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY) invitation_token = cipher_suite.encrypt(str(chat_id).encode()) invitation_link = settings.INVITATION_LINK_HOST + "/api/chat/invite_to_chat/" + str(invitation_token).split("'")[1] @@ -140,14 +112,15 @@ async def create_invitation_link(chat_id: int, user: SUser = Depends(check_verif status_code=status.HTTP_200_OK, response_model=SUserAddedToChat, ) -async def invite_to_chat(invitation_token: str, user: SUser = Depends(check_verificated_user_with_exc)): - invitation_token = invitation_token.encode() - cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY) - chat_id = int(cipher_suite.decrypt(invitation_token)) - chat = await ChatDAO.find_one_or_none(id=chat_id) - if user.id == chat.chat_for: - raise UserCanNotReadThisChatException - return {"user_added_to_chat": await ChatDAO.add_user_to_chat(chat_id=chat_id, user_id=user.id)} +async def invite_to_chat(invitation_token: str, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + async with uow: + invitation_token = invitation_token.encode() + cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY) + chat_id = int(cipher_suite.decrypt(invitation_token)) + chat = await uow.chat.find_one_or_none(id=chat_id) + if user.id == chat.chat_for: + raise UserCanNotReadThisChatException + return {"user_added_to_chat": await uow.chat.add_user_to_chat(chat_id=chat_id, user_id=user.id)} @router.delete( @@ -155,11 +128,12 @@ async def invite_to_chat(invitation_token: str, user: SUser = Depends(check_veri status_code=status.HTTP_200_OK, response_model=SDeletedChat, ) -async def delete_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)): - chat = await ChatDAO.find_one_or_none(id=chat_id) - if user.id == chat.created_by: - return {"deleted_chat": await ChatDAO.delete_chat(chat_id)} - raise UserDontHavePermissionException +async def delete_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + async with uow: + chat = await uow.chat.find_one_or_none(id=chat_id) + if user.id == chat.created_by: + return {"deleted_chat": await uow.chat.delete_chat(chat_id)} + raise UserDontHavePermissionException @router.delete( @@ -167,11 +141,12 @@ async def delete_chat(chat_id: int, user: SUser = Depends(check_verificated_user status_code=status.HTTP_200_OK, response_model=SDeletedUser ) -async def delete_user_from_chat(chat_id: int, user_id: int, user: SUser = Depends(check_verificated_user_with_exc)): - chat = await ChatDAO.find_one_or_none(id=chat_id) - if user.id == chat.created_by: - return {"deleted_user": await ChatDAO.delete_user(chat_id=chat_id, user_id=user_id)} - raise UserDontHavePermissionException +async def delete_user_from_chat(chat_id: int, user_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + async with uow: + chat = await uow.chat.find_one_or_none(id=chat_id) + if user.id == chat.created_by: + return {"deleted_user": await uow.chat.delete_user(chat_id=chat_id, user_id=user_id)} + raise UserDontHavePermissionException @router.post( @@ -179,10 +154,11 @@ async def delete_user_from_chat(chat_id: int, user_id: int, user: SUser = Depend status_code=status.HTTP_200_OK, response_model=SPinnedChat ) -async def pinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)): - await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id) - await ChatDAO.pin_chat(chat_id=chat_id, user_id=user.id) - return {"chat_id": chat_id, "user_id": user.id} +async def pinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id) + async with uow: + await uow.chat.pin_chat(chat_id=chat_id, user_id=user.id) + return {"chat_id": chat_id, "user_id": user.id} @router.delete( @@ -190,10 +166,11 @@ async def pinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_w status_code=status.HTTP_200_OK, response_model=SPinnedChat ) -async def unpinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)): - await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id) - await ChatDAO.unpin_chat(chat_id=chat_id, user_id=user.id) - return {"chat_id": chat_id, "user_id": user.id} +async def unpinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id) + async with uow: + await uow.chat.unpin_chat(chat_id=chat_id, user_id=user.id) + return {"chat_id": chat_id, "user_id": user.id} @router.get( @@ -201,8 +178,9 @@ async def unpinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user status_code=status.HTTP_200_OK, response_model=list[SChat] ) -async def get_pinned_chats(user: SUser = Depends(check_verificated_user_with_exc)): - return await ChatDAO.get_pinned_chats(user_id=user.id) +async def get_pinned_chats(user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + async with uow: + return await uow.chat.get_pinned_chats(user_id=user.id) @router.get( @@ -210,10 +188,11 @@ async def get_pinned_chats(user: SUser = Depends(check_verificated_user_with_exc status_code=status.HTTP_200_OK, response_model=list[SMessage] | None ) -async def pinned_messages(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)): - await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id) - messages = await ChatDAO.get_pinned_messages(chat_id=chat_id) - if messages: - for mes in messages: - mes["created_at"] = mes["created_at"].isoformat() - return messages +async def pinned_messages(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)): + await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id) + async with uow: + messages = await uow.chat.get_pinned_messages(chat_id=chat_id) + if messages: + for mes in messages: + mes["created_at"] = mes["created_at"].isoformat() + return messages diff --git a/app/chat/shemas.py b/app/chat/shemas.py index 9a9d0c5..a0ab21d 100644 --- a/app/chat/shemas.py +++ b/app/chat/shemas.py @@ -24,12 +24,6 @@ class SLastMessages(BaseModel): messages_to_get: int -class SPinnedMessage(BaseModel): - message_id: int - user_id: int - chat_id: int - - class SPinnedChat(BaseModel): user_id: int chat_id: int diff --git a/app/chat/websocket.py b/app/chat/websocket.py index 1946287..b20ab31 100644 --- a/app/chat/websocket.py +++ b/app/chat/websocket.py @@ -2,6 +2,7 @@ from fastapi import WebSocket, WebSocketDisconnect, Depends from app.exceptions import IncorrectDataException, UserDontHavePermissionException from app.services.message_service import MessageService +from app.unit_of_work import UnitOfWork from app.users.auth import AuthService from app.chat.router import router from app.chat.shemas import SSendMessage, SMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage @@ -22,14 +23,14 @@ class ConnectionManager: def disconnect(self, chat_id: int, websocket: WebSocket): self.active_connections[chat_id].remove(websocket) - async def broadcast(self, user_id: int, chat_id: int, message: dict): + async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict): if "flag" not in message: raise IncorrectDataException if message["flag"] == "send": message = SSendMessage.model_validate(message) - new_message = await self.add_message_to_database(user_id=user_id, chat_id=chat_id, message=message) + new_message = await self.add_message_to_database(uow=uow, user_id=user_id, chat_id=chat_id, message=message) new_message = new_message.model_dump() new_message["created_at"] = new_message["created_at"].isoformat() new_message["flag"] = "send" @@ -40,7 +41,7 @@ class ConnectionManager: if message.user_id != user_id: raise UserDontHavePermissionException - deleted_message = await self.delete_message(message.id) + deleted_message = await self.delete_message(uow=uow, message_id=message.id) new_message = {"deleted_message": deleted_message, "id": message.id, "flag": "delete"} elif message["flag"] == "edit": @@ -49,7 +50,9 @@ class ConnectionManager: if message.user_id != user_id: raise UserDontHavePermissionException - edited_message = await self.edit_message(message.id, message.new_message, message.new_image_url) + edited_message = await self.edit_message( + uow=uow, message_id=message.id, new_message=message.new_message, image_url=message.new_image_url + ) new_message = { "flag": "edit", "id": message.id, @@ -60,14 +63,14 @@ class ConnectionManager: elif message["flag"] == "pin": message = SPinMessage.model_validate(message) - pinned_message = await self.pin_message(chat_id=chat_id, user_id=message.user_id, message_id=message.id) + pinned_message = await self.pin_message(uow=uow, chat_id=chat_id, user_id=message.user_id, message_id=message.id) new_message = pinned_message.model_dump() new_message["created_at"] = new_message["created_at"].isoformat() new_message["flag"] = "pin" elif message["flag"] == "unpin": message = SUnpinMessage.model_validate(message) - unpinned_message = await self.unpin_message(chat_id=chat_id, message_id=message.id) + unpinned_message = await self.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id) new_message = {"flag": "pin", "id": unpinned_message} else: @@ -77,34 +80,34 @@ class ConnectionManager: await websocket.send_json(new_message) @staticmethod - async def add_message_to_database(user_id: int, chat_id: int, message: SSendMessage) -> SMessage: + async def add_message_to_database(uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage) -> SMessage: new_message = await MessageService.send_message( - user_id=user_id, chat_id=chat_id, message=message.message, image_url=message.image_url + uow=uow, user_id=user_id, chat_id=chat_id, message=message.message, image_url=message.image_url ) if message.answer: - new_message = await MessageService.add_answer(self_id=new_message.id, answer_id=message.answer) + new_message = await MessageService.add_answer(uow=uow, self_id=new_message.id, answer_id=message.answer) return new_message @staticmethod - async def delete_message(message_id: int) -> bool: - new_message = await MessageService.delete_message(message_id) + async def delete_message(uow: UnitOfWork, message_id: int) -> bool: + new_message = await MessageService.delete_message(uow=uow, message_id=message_id) return new_message @staticmethod - async def edit_message(message_id: int, new_message: str, image_url: str) -> bool: + async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, image_url: str) -> bool: new_message = await MessageService.edit_message( - message_id=message_id, new_message=new_message, new_image_url=image_url + uow=uow, message_id=message_id, new_message=new_message, new_image_url=image_url ) return new_message @staticmethod - async def pin_message(chat_id: int, user_id: int, message_id: int) -> SMessage: - pinned_message = await MessageService.pin_message(chat_id=chat_id, user_id=user_id, message_id=message_id) + async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage: + pinned_message = await MessageService.pin_message(uow=uow, chat_id=chat_id, user_id=user_id, message_id=message_id) return pinned_message @staticmethod - async def unpin_message(chat_id: int, message_id: int) -> int: - unpinned_message_id = await MessageService.unpin_message(chat_id=chat_id, message_id=message_id) + async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> int: + unpinned_message_id = await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message_id) return unpinned_message_id @@ -114,14 +117,14 @@ manager = ConnectionManager() @router.websocket( "/ws/{chat_id}", ) -async def websocket_endpoint(chat_id: int, websocket: WebSocket, user: SUser = Depends(get_current_user_ws)): - await AuthService.check_verificated_user_with_exc(user_id=user.id) - await AuthService.validate_user_access_to_chat(user_id=user.id, chat_id=chat_id) +async def websocket_endpoint(chat_id: int, websocket: WebSocket, user: SUser = Depends(get_current_user_ws), uow=Depends(UnitOfWork)): + await AuthService.check_verificated_user_with_exc(uow=uow, user_id=user.id) + await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id) await manager.connect(chat_id, websocket) try: while True: data = await websocket.receive_json() - await manager.broadcast(user_id=user.id, chat_id=chat_id, message=data) + await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data) except WebSocketDisconnect: manager.disconnect(chat_id, websocket) diff --git a/app/config.py b/app/config.py index a0cc05d..f6d2349 100644 --- a/app/config.py +++ b/app/config.py @@ -25,6 +25,7 @@ class Settings(BaseSettings): REDIS_HOST: str REDIS_PORT: int + REDIS_DB: int SMTP_HOST: str SMTP_PORT: int diff --git a/app/dao/base.py b/app/dao/base.py index 0f3632f..990b9fb 100644 --- a/app/dao/base.py +++ b/app/dao/base.py @@ -1,8 +1,6 @@ from sqlalchemy import select, insert from sqlalchemy.ext.asyncio import AsyncSession -from app.database import async_session_maker - class BaseDAO: model = None @@ -11,20 +9,16 @@ class BaseDAO: self.session = session async def add(self, **data): # Метод добавляет данные в БД - async with async_session_maker() as session: - query = insert(self.model).values(**data).returning(self.model.id) - result = await session.execute(query) - await session.commit() - return result.scalar() + stmt = insert(self.model).values(**data).returning(self.model.id) + result = await self.session.execute(stmt) + return result.scalar() async def find_one_or_none(self, **filter_by): # Метод проверяет наличие строки с заданными параметрами - async with async_session_maker() as session: - query = select(self.model).filter_by(**filter_by) - result = await session.execute(query) - return result.scalar_one_or_none() + query = select(self.model).filter_by(**filter_by) + result = await self.session.execute(query) + return result.scalar_one_or_none() async def find_all(self, **filter_by): # Метод возвращает все строки таблицы или те, которые соответствуют отбору - async with async_session_maker() as session: - query = select(self.model.__table__.columns).filter_by(**filter_by) - result = await session.execute(query) - return result.mappings().all() + query = select(self.model.__table__.columns).filter_by(**filter_by) + result = await self.session.execute(query) + return result.mappings().all() diff --git a/app/services/message_service.py b/app/services/message_service.py index 5b4b1d3..9718a56 100644 --- a/app/services/message_service.py +++ b/app/services/message_service.py @@ -1,34 +1,42 @@ -from app.chat.dao import ChatDAO from app.chat.shemas import SMessage +from app.unit_of_work import UnitOfWork class MessageService: @staticmethod - async def send_message(user_id: int, chat_id: int, message: str, image_url: str | None = None) -> SMessage: - new_message = await ChatDAO.send_message(user_id=user_id, chat_id=chat_id, message=message, image_url=image_url) - return new_message + async def send_message(uow: UnitOfWork, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> SMessage: + async with uow: + new_message = await uow.chat.send_message(user_id=user_id, chat_id=chat_id, message=message, image_url=image_url) + return new_message @staticmethod - async def add_answer(self_id: int, answer_id: int) -> SMessage: - new_message = await ChatDAO.add_answer(self_id=self_id, answer_id=answer_id) - return new_message + async def add_answer(uow: UnitOfWork, self_id: int, answer_id: int) -> SMessage: + async with uow: + new_message = await uow.chat.add_answer(self_id=self_id, answer_id=answer_id) + return new_message @staticmethod - async def delete_message(message_id: int) -> bool: - new_message = await ChatDAO.delete_message(message_id=message_id) - return new_message + async def delete_message(uow: UnitOfWork, message_id: int) -> bool: + async with uow: + new_message = await uow.chat.delete_message(message_id=message_id) + return new_message @staticmethod - async def edit_message(message_id: int, new_message: str, new_image_url: str) -> bool: - new_message = await ChatDAO.edit_message(message_id=message_id, new_message=new_message, new_image_url=new_image_url) - return new_message + async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, new_image_url: str) -> bool: + async with uow: + new_message = await uow.chat.edit_message(message_id=message_id, new_message=new_message, new_image_url=new_image_url) + return new_message @staticmethod - async def pin_message(chat_id: int, user_id: int, message_id: int) -> SMessage: - pinned_message = await ChatDAO.pin_message(chat_id=chat_id, message_id=message_id, user_id=user_id) - return pinned_message + async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage: + async with uow: + pinned_message = await uow.chat.pin_message(chat_id=chat_id, message_id=message_id, user_id=user_id) + await uow.commit() + return pinned_message @staticmethod - async def unpin_message(chat_id: int, message_id: int) -> int: - unpinned_message = await ChatDAO.unpin_message(chat_id=chat_id, message_id=message_id) - return unpinned_message + async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> int: + async with uow: + unpinned_message = await uow.chat.unpin_message(chat_id=chat_id, message_id=message_id) + await uow.commit() + return unpinned_message diff --git a/app/services/redis_service.py b/app/services/redis_service.py new file mode 100644 index 0000000..753eaa3 --- /dev/null +++ b/app/services/redis_service.py @@ -0,0 +1,24 @@ +from redis.asyncio.client import Redis + +from app.config import settings + + +def get_redis_session() -> Redis: + return Redis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB) + + +class RedisService: + @staticmethod + async def set_verification_code(redis: Redis, user_id: int, verification_code: str): + await redis.setex(f"user_verification_code: {user_id}", 1800, verification_code) + + @staticmethod + async def get_verification_code(redis: Redis, user_id: int) -> str: + verification_code = await redis.get(f"user_verification_code: {user_id}") + return verification_code + + @staticmethod + async def delete_verification_code(redis: Redis, user_id: int): + await redis.delete(f"user_verification_code: {user_id}") + + diff --git a/app/services/user_service.py b/app/services/user_service.py index 0f45484..b7fc587 100644 --- a/app/services/user_service.py +++ b/app/services/user_service.py @@ -1,9 +1,10 @@ -from app.users.dao import UserDAO +from app.unit_of_work import UnitOfWork from app.users.schemas import SUser class UserService: @staticmethod - async def find_one_or_none(user_id: int) -> SUser | None: - user = await UserDAO.find_one_or_none(id=user_id) - return user + async def find_one_or_none(uow: UnitOfWork, user_id: int) -> SUser | None: + async with uow: + user = await uow.user.find_one_or_none(id=user_id) + return user diff --git a/app/tasks/tasks.py b/app/tasks/tasks.py index a73951e..f591ed4 100644 --- a/app/tasks/tasks.py +++ b/app/tasks/tasks.py @@ -72,3 +72,13 @@ def send_password_recover_email(username: str, email_to: EmailStr, MODE: str): server.send_message(msg_content) return confirmation_code + + +@celery.task +def send_data_change_confirmation_email(username: str, email_to: EmailStr, MODE: str): + pass + + +@celery.task +def send_data_change_email(username: str, email_to: EmailStr, MODE: str): + pass diff --git a/app/unit_of_work.py b/app/unit_of_work.py new file mode 100644 index 0000000..be2f8d7 --- /dev/null +++ b/app/unit_of_work.py @@ -0,0 +1,24 @@ +from app.chat.dao import ChatDAO +from app.database import async_session_maker +from app.users.dao import UserDAO + + +class UnitOfWork: + def __init__(self): + self.session_factory = async_session_maker + + async def __aenter__(self): + self.session = self.session_factory() + + self.user = UserDAO(self.session) + self.chat = ChatDAO(self.session) + + async def __aexit__(self, *args): + await self.rollback() + await self.session.close() + + async def commit(self): + await self.session.commit() + + async def rollback(self): + await self.session.rollback() diff --git a/app/users/auth.py b/app/users/auth.py index c4b99ab..f1d0412 100644 --- a/app/users/auth.py +++ b/app/users/auth.py @@ -12,9 +12,8 @@ from app.exceptions import ( UserNotFoundException, UserMustConfirmEmailException, ) -from app.users.dao import UserDAO -from app.models.users import Users -from app.users.schemas import SUserRegister +from app.unit_of_work import UnitOfWork +from app.users.schemas import SUserRegister, SUser pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -41,68 +40,72 @@ def create_access_token(data: dict[str, str | datetime]) -> str: class AuthService: - # Функция проверки наличия юзера по мейлу @staticmethod - async def authenticate_user_by_email(email: EmailStr, password: str) -> Users | None: - user = await UserDAO.find_one_or_none(email=email) - if not user or not verify_password(password, user.hashed_password): - return None - return user + async def authenticate_user_by_email(uow: UnitOfWork, email: EmailStr, password: str) -> SUser | None: + async with uow: + user = await uow.user.find_one_or_none(email=email) + if not user or not verify_password(password, user.hashed_password): + return None + return user - # Функция проверки наличия юзера по нику @staticmethod - async def authenticate_user_by_username(username: str, password: str) -> Users | None: - user = await UserDAO.find_one_or_none(username=username) - if not user or not verify_password(password, user.hashed_password): - return None - return user + async def authenticate_user_by_username(uow: UnitOfWork, username: str, password: str) -> SUser | None: + async with uow: + user = await uow.user.find_one_or_none(username=username) + if not user or not verify_password(password, user.hashed_password): + return None + return user @classmethod - async def authenticate_user(cls, email_or_username: str, password: str) -> Users: - user = await cls.authenticate_user_by_email(email_or_username, password) + async def authenticate_user(cls, uow: UnitOfWork, email_or_username: str, password: str) -> SUser: + user = await cls.authenticate_user_by_email(uow, email_or_username, password) if not user: - user = await cls.authenticate_user_by_username(email_or_username, password) + user = await cls.authenticate_user_by_username(uow, email_or_username, password) if not user: raise IncorrectAuthDataException return user @staticmethod - async def check_existing_user(user_data: SUserRegister) -> None: - existing_user = await UserDAO.find_one_or_none(email=user_data.email) - if existing_user: - raise UserAlreadyExistsException - existing_user = await UserDAO.find_one_or_none(username=user_data.username) - if existing_user: - raise UserAlreadyExistsException + async def check_existing_user(uow: UnitOfWork, user_data: SUserRegister) -> None: + async with uow: + existing_user = await uow.user.find_one_or_none(email=user_data.email) + if existing_user: + raise UserAlreadyExistsException + existing_user = await uow.user.find_one_or_none(username=user_data.username) + if existing_user: + raise UserAlreadyExistsException @staticmethod - async def check_verificated_user(user_id: int) -> bool: - user = await UserDAO.find_one_or_none(id=user_id) - if not user: - raise UserNotFoundException - return user.role >= VERIFICATED_USER + async def check_verificated_user(uow: UnitOfWork, user_id: int) -> bool: + async with uow: + user = await uow.user.find_one_or_none(id=user_id) + if not user: + raise UserNotFoundException + return user.role >= VERIFICATED_USER @classmethod - async def check_verificated_user_with_exc(cls, user_id: int): - if not await cls.check_verificated_user(user_id=user_id): + async def check_verificated_user_with_exc(cls, uow: UnitOfWork, user_id: int): + if not await cls.check_verificated_user(uow=uow, user_id=user_id): raise UserMustConfirmEmailException @staticmethod - async def get_user_allowed_chats_id(user_id: int) -> list[int]: - user_allowed_chats = await UserDAO.get_user_allowed_chats(user_id) - user_allowed_chats_id = [chat["chat_id"] for chat in user_allowed_chats] - return user_allowed_chats_id + async def get_user_allowed_chats_id(uow: UnitOfWork, user_id: int) -> list[int]: + async with uow: + user_allowed_chats = await uow.user.get_user_allowed_chats(user_id) + user_allowed_chats_id = [chat["chat_id"] for chat in user_allowed_chats] + return user_allowed_chats_id @classmethod - async def validate_user_access_to_chat(cls, user_id: int, chat_id: int) -> bool: - user_allowed_chats = await cls.get_user_allowed_chats_id(user_id=user_id) + async def validate_user_access_to_chat(cls, uow: UnitOfWork, user_id: int, chat_id: int) -> bool: + user_allowed_chats = await cls.get_user_allowed_chats_id(uow=uow, user_id=user_id) if chat_id not in user_allowed_chats: raise UserDontHavePermissionException return True @staticmethod - async def validate_user_admin(user_id: int) -> bool: - user_role = await UserDAO.get_user_role(user_id=user_id) - if user_role == ADMIN_USER: - return True - return False + async def validate_user_admin(uow: UnitOfWork, user_id: int) -> bool: + async with uow: + user_role = await uow.user.get_user_role(user_id=user_id) + if user_role == ADMIN_USER: + return True + return False diff --git a/app/users/dao.py b/app/users/dao.py index 1960199..2178339 100644 --- a/app/users/dao.py +++ b/app/users/dao.py @@ -1,11 +1,11 @@ -from sqlalchemy import update, select, insert, and_, func, text, delete +from pydantic import HttpUrl +from sqlalchemy import update, select, insert, func from app.dao.base import BaseDAO -from app.database import async_session_maker, engine # noqa +from app.database import engine # noqa from app.models.chat import Chats from app.models.user_avatar import UserAvatar from app.models.users import Users -from app.models.user_verification_code import UserVerificationCode from app.models.user_chat import UserChat from app.users.schemas import SUser, SUserAvatars @@ -14,32 +14,25 @@ class UserDAO(BaseDAO): model = Users async def find_one_or_none(self, **filter_by) -> SUser | None: - async with async_session_maker() as session: - query = select(Users).filter_by(**filter_by) - result = await session.execute(query) - result = result.scalar_one_or_none() - if result: - return SUser.model_validate(result, from_attributes=True) + query = select(Users).filter_by(**filter_by) + result = await self.session.execute(query) + result = result.scalar_one_or_none() + if result: + return SUser.model_validate(result, from_attributes=True) async def find_all(self, **filter_by): - async with async_session_maker() as session: - query = select(Users.__table__.columns).filter_by(**filter_by).where(Users.role != 100) - result = await session.execute(query) - return result.mappings().all() + query = select(Users.__table__.columns).filter_by(**filter_by).where(Users.role != 100) + result = await self.session.execute(query) + return result.mappings().all() async def change_data(self, user_id: int, **data_to_change) -> str: - query = update(Users).where(Users.id == user_id).values(**data_to_change) - async with async_session_maker() as session: - await session.execute(query) - await session.commit() - query = select(Users.username).where(Users.id == user_id) - result = await session.execute(query) - return result.scalar() + query = update(Users).where(Users.id == user_id).values(**data_to_change).returning(Users.username) + result = await self.session.execute(query) + return result.scalar() async def get_user_role(self, user_id: int) -> int: query = select(Users.role).where(Users.id == user_id) - async with async_session_maker() as session: - result = await session.execute(query) + result = await self.session.execute(query) return result.scalar() async def get_user_allowed_chats(self, user_id: int): @@ -89,26 +82,18 @@ class UserDAO(BaseDAO): chats_with_avatars.c.avatar_hex, ) .select_from(chats_with_avatars) - .where(and_(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True)) # noqa: E712 + .where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712 ) - async with async_session_maker() as session: - result = await session.execute(query) - result = result.mappings().all() + result = await self.session.execute(query) + result = result.mappings().all() return result - async def get_user_avatar(self, user_id: int) -> str: - query = select(Users.avatar_image).where(Users.id == user_id) - async with async_session_maker() as session: - result = await session.execute(query) - return result.scalar() - - async def add_user_avatar(self, user_id: int, avatar: str) -> bool: + async def add_user_avatar(self, user_id: int, avatar: HttpUrl) -> bool: query = insert(UserAvatar).values(user_id=user_id, avatar_image=avatar) - async with async_session_maker() as session: - await session.execute(query) - await session.commit() - return True + await self.session.execute(query) + await self.session.commit() + return True async def get_user_avatars(self, user_id: int) -> SUserAvatars: query = select( @@ -124,54 +109,6 @@ class UserDAO(BaseDAO): .scalar_subquery() ) ) - async with async_session_maker() as session: - result = await session.execute(query) - result = result.scalar() - return SUserAvatars.model_validate(result) - - async def delete_user_avatar(self, avatar_id: int, user_id: int) -> bool: - query = delete(UserAvatar).where(and_(UserAvatar.id == avatar_id, UserAvatar.user_id == user_id)) - async with async_session_maker() as session: - await session.execute(query) - await session.commit() - return True - - -class UserCodesDAO(BaseDAO): - model = UserVerificationCode - - async def set_user_codes(self, cls, user_id: int, code: str, description: str): - query = ( - insert(UserVerificationCode) - .values(user_id=user_id, code=code, description=description) - .returning(cls.model.code) - ) - async with async_session_maker() as session: - result = await session.execute(query) - await session.commit() - return result.scalar() - - async def get_user_codes(self, **filter_by) -> list[dict | None]: - """ - SELECT - usersverificationcodes.id, - usersverificationcodes.user_id, - usersverificationcodes.code, - usersverificationcodes.description, - usersverificationcodes.date_of_creation - FROM usersverificationcodes - WHERE - usersverificationcodes.user_id = 20 - AND now() - usersverificationcodes.date_of_creation < INTERVAL '30 minutes' - """ - query = ( - select(UserVerificationCode.__table__.columns) - .where((func.now() - UserVerificationCode.date_of_creation) < text("INTERVAL '10 minutes'")) - .filter_by(**filter_by) - ) - - async with async_session_maker() as session: - # print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса - result = await session.execute(query) - result = result.mappings().all() - return result + result = await self.session.execute(query) + result = result.scalar() + return SUserAvatars.model_validate(result) diff --git a/app/users/dependencies.py b/app/users/dependencies.py index 1ad7079..c2b2d03 100644 --- a/app/users/dependencies.py +++ b/app/users/dependencies.py @@ -10,6 +10,7 @@ from app.exceptions import ( UserMustConfirmEmailException, ) from app.services.user_service import UserService +from app.unit_of_work import UnitOfWork from app.users.auth import create_access_token, VERIFICATED_USER from app.users.schemas import SUser @@ -21,7 +22,7 @@ def get_token(request: Request) -> str: return token -async def get_current_user(response: Response, token: str = Depends(get_token)) -> SUser: +async def get_current_user(response: Response, token: str = Depends(get_token), uow=Depends(UnitOfWork)) -> SUser: try: payload = jwt.decode(token, settings.SECRET_KEY, settings.ALGORITHM) except ExpiredSignatureError: @@ -33,7 +34,7 @@ async def get_current_user(response: Response, token: str = Depends(get_token)) if not user_id: raise UserIsNotPresentException - user = await UserService.find_one_or_none(user_id=int(user_id)) + user = await UserService.find_one_or_none(uow=uow, user_id=int(user_id)) if not user: raise UserIsNotPresentException @@ -55,7 +56,7 @@ def get_token_ws(websocket: WebSocket) -> str: return token -async def get_current_user_ws(token: str = Depends(get_token_ws)): +async def get_current_user_ws(token: str = Depends(get_token_ws), uow=Depends(UnitOfWork)): try: payload = jwt.decode(token, settings.SECRET_KEY, settings.ALGORITHM) except ExpiredSignatureError: @@ -67,7 +68,7 @@ async def get_current_user_ws(token: str = Depends(get_token_ws)): if not user_id: raise UserIsNotPresentException - user = await UserService.find_one_or_none(user_id=int(user_id)) + user = await UserService.find_one_or_none(uow=uow, user_id=int(user_id)) if not user: raise UserIsNotPresentException diff --git a/app/users/router.py b/app/users/router.py index 0cad89e..c657520 100644 --- a/app/users/router.py +++ b/app/users/router.py @@ -7,31 +7,30 @@ from app.config import settings from app.exceptions import ( PasswordsMismatchException, WrongCodeException, - UserNotFoundException, - SomethingWentWrongException, UserAlreadyExistsException, + IncorrectAuthDataException, ) -from app.users.auth import get_password_hash, create_access_token, VERIFICATED_USER, AuthService -from app.users.dao import UserDAO, UserCodesDAO +from app.services.redis_service import RedisService, get_redis_session +from app.unit_of_work import UnitOfWork +from app.users.auth import get_password_hash, create_access_token, VERIFICATED_USER, AuthService, verify_password from app.users.dependencies import get_current_user from app.users.schemas import ( SUserLogin, SUserRegister, SUserResponse, - SUserPasswordRecover, - SUserCode, - SUserPasswordChange, - SRecoverEmailSent, SUserToken, SEmailVerification, - SConfirmPasswordRecovery, - SPasswordRecovered, SUserAvatars, SUsername, - SEmail, SUser, + SEmail, + SUser, + SUserChangeData, + SUserSendConfirmationCode, +) +from app.tasks.tasks import ( + send_registration_confirmation_email, + send_data_change_confirmation_email ) -from app.tasks.tasks import send_registration_confirmation_email, send_password_change_email, \ - send_password_recover_email router = APIRouter(prefix="/users", tags=["Пользователи"]) @@ -41,9 +40,10 @@ router = APIRouter(prefix="/users", tags=["Пользователи"]) status_code=status.HTTP_200_OK, response_model=list[SUserResponse], ) -async def get_all_users(): - users = await UserDAO.find_all() - return users +async def get_all_users(uow=Depends(UnitOfWork)): + async with uow: + users = await uow.user.find_all() + return users @router.post( @@ -51,10 +51,11 @@ async def get_all_users(): status_code=status.HTTP_200_OK, response_model=None, ) -async def check_existing_username(username: SUsername): - user = await UserDAO.find_one_or_none(username=username.username) - if user: - raise UserAlreadyExistsException +async def check_existing_username(username: SUsername, uow=Depends(UnitOfWork)): + async with uow: + user = await uow.user.find_one_or_none(username=username.username) + if user: + raise UserAlreadyExistsException @router.post( @@ -62,10 +63,11 @@ async def check_existing_username(username: SUsername): status_code=status.HTTP_200_OK, response_model=None, ) -async def check_existing_email(email: SEmail): - user = await UserDAO.find_one_or_none(email=email.email) - if user: - raise UserAlreadyExistsException +async def check_existing_email(email: SEmail, uow=Depends(UnitOfWork)): + async with uow: + user = await uow.user.find_one_or_none(email=email.email) + if user: + raise UserAlreadyExistsException @router.post( @@ -73,29 +75,30 @@ async def check_existing_email(email: SEmail): status_code=status.HTTP_201_CREATED, response_model=SUserToken, ) -async def register_user(response: Response, user_data: SUserRegister): +async def register_user(response: Response, user_data: SUserRegister, uow=Depends(UnitOfWork)): if user_data.password != user_data.password2: raise PasswordsMismatchException - await AuthService.check_existing_user(user_data) + await AuthService.check_existing_user(uow, user_data) hashed_password = get_password_hash(user_data.password) - user_id = await UserDAO.add( - email=user_data.email, - hashed_password=hashed_password, - username=user_data.username, - date_of_birth=user_data.date_of_birth, - ) + async with uow: + user_id = await uow.user.add( + email=user_data.email, + hashed_password=hashed_password, + username=user_data.username, + date_of_birth=user_data.date_of_birth, + ) result = send_registration_confirmation_email.delay( user_id=user_id, username=user_data.username, email_to=user_data.email, MODE=settings.MODE ) - result = result.get() - - if await UserCodesDAO.set_user_codes(user_id=user_id, code=result, description="Код подтверждения почты") == result: - user = await AuthService.authenticate_user_by_email(user_data.email, user_data.password) - access_token = create_access_token({"sub": str(user.id)}) - response.set_cookie(key="black_phoenix_access_token", value=access_token, httponly=True, secure=True) - return {"access_token": access_token} + user_code = result.get() + redis_session = get_redis_session() + await RedisService.set_verification_code(redis=redis_session, user_id=user_id, verification_code=user_code) + user = await AuthService.authenticate_user_by_email(uow, user_data.email, user_data.password) + access_token = create_access_token({"sub": str(user.id)}) + response.set_cookie(key="black_phoenix_access_token", value=access_token, httponly=True, secure=True) + return {"access_token": access_token} @router.get( @@ -103,17 +106,20 @@ async def register_user(response: Response, user_data: SUserRegister): status_code=status.HTTP_200_OK, response_model=SEmailVerification, ) -async def email_verification(user_code: str): +async def email_verification(user_code: str, uow=Depends(UnitOfWork)): invitation_token = user_code.encode() cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY) user_data = cipher_suite.decrypt(invitation_token).decode("utf-8") user_data = json.loads(user_data) - user_codes = await UserCodesDAO.get_user_codes( - user_id=user_data["user_id"], description="Код подтверждения почты", code=user_data["confirmation_code"] - ) - if not user_codes or not await UserDAO.change_data(user_id=user_data["user_id"], role=VERIFICATED_USER): - raise WrongCodeException - return {"email_verification": True} + redis_session = get_redis_session() + async with uow: + verification_code = await RedisService.get_verification_code(redis=redis_session, user_id=user_data["user_id"]) + if verification_code != user_data["confirmation_code"]: + raise WrongCodeException + + await uow.user.change_data(user_id=user_data["user_id"], role=VERIFICATED_USER) + await uow.commit() + return {"email_verification": True} @router.post( @@ -121,8 +127,8 @@ async def email_verification(user_code: str): status_code=status.HTTP_200_OK, response_model=SUserToken, ) -async def login_user(response: Response, user_data: SUserLogin): - user = await AuthService.authenticate_user(user_data.email_or_username, user_data.password) +async def login_user(response: Response, user_data: SUserLogin, uow=Depends(UnitOfWork)): + user = await AuthService.authenticate_user(uow, user_data.email_or_username, user_data.password) access_token = create_access_token({"sub": str(user.id)}) response.set_cookie("black_phoenix_access_token", access_token, httponly=True) return {"access_token": access_token} @@ -146,61 +152,56 @@ async def get_user(current_user: SUser = Depends(get_current_user)): return current_user -@router.patch( - "/send_recovery_email", - status_code=status.HTTP_200_OK, - response_model=SRecoverEmailSent, -) -async def send_recovery_email(email: SUserPasswordRecover): - existing_user = await UserDAO.find_one_or_none(email=email.email) - if not existing_user: - raise UserNotFoundException - result = send_password_recover_email.delay(existing_user.username, existing_user.email, MODE=settings.MODE) - result = result.get() - - if ( - await UserCodesDAO.set_user_codes( - user_id=existing_user.user_id, code=result, description="Код восстановления пароля" - ) - == result - ): - return {"recover_email_sent": True} - raise SomethingWentWrongException - - -@router.post( - "/confirm_password_recovery", - status_code=status.HTTP_200_OK, - response_model=SConfirmPasswordRecovery, -) -async def confirm_password_recovery(user_code: SUserCode): - user_codes = await UserCodesDAO.get_user_codes(description="Код восстановления пароля", code=user_code.user_code) - if not user_codes: - raise WrongCodeException - return {"user_id": user_codes[0]["user_id"]} - - -@router.post( - "/password_recovery", - status_code=status.HTTP_200_OK, - response_model=SPasswordRecovered, -) -async def password_recovery(passwords: SUserPasswordChange): - if passwords.password1 != passwords.password2: - raise PasswordsMismatchException - hashed_password = get_password_hash(passwords.password1) - username = await UserDAO.change_data(passwords.user_id, hashed_password=hashed_password) - user = await UserDAO.find_one_or_none(username=username, id=passwords.user_id) - if not user: - raise UserNotFoundException - send_password_change_email.delay(user.username, user.email, MODE=settings.MODE) - return {"username": username} - - @router.get( "/avatars", status_code=status.HTTP_200_OK, response_model=SUserAvatars, ) -async def get_user_avatars_history(user: SUser = Depends(get_current_user)): - return await UserDAO.get_user_avatars(user_id=user.id) +async def get_user_avatars_history(user=Depends(get_current_user), uow=Depends(UnitOfWork)): + async with uow: + return await uow.user.get_user_avatars(user_id=user.id) + + +@router.post( + "/send_confirmation_code", + status_code=status.HTTP_200_OK, + response_model=None, +) +async def send_confirmation_code(user_data: SUserSendConfirmationCode, user: SUser = Depends(get_current_user)): + redis_session = get_redis_session() + if verify_password(user_data.current_password, user.hashed_password): + verification_code = send_data_change_confirmation_email.delay( + username=user.username, email_to=user_data.email, MODE=settings.MODE + ) + verification_code = verification_code.get() + await RedisService.set_verification_code(redis=redis_session, user_id=user.id, verification_code=verification_code) + raise IncorrectAuthDataException + + +@router.post( + "/change_data", + status_code=status.HTTP_200_OK, + response_model=None, +) +async def change_user_data(user_data: SUserChangeData, user=Depends(get_current_user), uow=Depends(UnitOfWork)): + redis_session = get_redis_session() + verification_code = await RedisService.get_verification_code(redis=redis_session, user_id=user.id) + if verification_code != user_data.verification_code: + raise WrongCodeException + if user_data.new_password: + hashed_password = get_password_hash(user_data.new_password) + else: + hashed_password = user.hashed_password + async with uow: + await uow.user.change_data( + user_id=user.id, + email=user_data.email, + username=user_data.username, + avatar_url=user_data.avatar_url, + hashed_password=hashed_password + ) + await uow.user.add_user_avatar(user_id=user.id, avatar=user_data.avatar_url) + await uow.commit() + await RedisService.delete_verification_code(redis=redis_session, user_id=user.id) + + diff --git a/app/users/schemas.py b/app/users/schemas.py index 5c8cfdf..9bd4c0b 100644 --- a/app/users/schemas.py +++ b/app/users/schemas.py @@ -59,23 +59,6 @@ class SUser(BaseModel): date_of_registration: date -class SUserRename(BaseModel): - username: str = Query(None, min_length=2, max_length=30) - password: str - - -class SUserAvatar(BaseModel): - password: str - new_avatar_image: HttpUrl - avatar_hex: str - - -class SUserPassword(BaseModel): - current_password: str = Query(None, min_length=8) - new_password: str = Query(None, min_length=8) - new_password2: str = Query(None, min_length=8) - - class SUserPasswordRecover(BaseModel): email: EmailStr @@ -90,6 +73,19 @@ class SUserPasswordChange(BaseModel): password2: str = Query(None, min_length=8) +class SUserSendConfirmationCode(BaseModel): + email: EmailStr + current_password: str = Query(None, min_length=8) + + +class SUserChangeData(BaseModel): + verification_code: str + email: EmailStr + username: str = Query(None, min_length=2, max_length=30) + new_password: str | None = Query(None, min_length=8) + avatar_url: HttpUrl + + class SRecoverEmailSent(BaseModel): recover_email_sent: bool @@ -102,15 +98,6 @@ class SUserToken(BaseModel): access_token: str -class SUserName(BaseModel): - username: str - - -class SNewAvatar(BaseModel): - new_avatar_image: HttpUrl - avatar_hex: str - - class SConfirmPasswordRecovery(BaseModel): user_id: int