Переделал дао, закрепы сообщений, uow

This commit is contained in:
urec56 2024-06-02 18:12:54 +05:00
parent 9ad60c2a8e
commit 1786e0fb3d
17 changed files with 458 additions and 506 deletions

View file

@ -17,6 +17,7 @@ ALGORITHM=
REDIS_HOST=
REDIS_PORT=
REDIS_DB=
SMTP_HOST=
SMTP_PORT=
@ -28,4 +29,4 @@ IMAGE_UPLOAD_SERVER=
INVITATION_LINK_HOST=
INVITATION_LINK_TOKEN_KEY=
SENTRY_DSN=
SENTRY_DSN=

View file

@ -1,7 +1,7 @@
from sqlalchemy import insert, select, update, delete
from app.dao.base import BaseDAO
from app.database import async_session_maker, engine # noqa
from app.database import engine # noqa
from app.exceptions import UserAlreadyInChatException, UserAlreadyPinnedChatException
from app.chat.shemas import SMessage
from app.models.users import Users
@ -18,22 +18,21 @@ class ChatDAO(BaseDAO):
async def create(self, user_id: int, chat_name: str, created_by: int) -> int:
query = insert(Chats).values(chat_for=user_id, chat_name=chat_name, created_by=created_by).returning(Chats.id)
async with async_session_maker() as session:
result = await session.execute(query)
await session.commit()
result = await self.session.execute(query)
await self.session.commit()
result = result.scalar()
return result
async def add_user_to_chat(self, user_id: int, chat_id: int) -> bool:
query = select(UserChat.user_id).where(UserChat.chat_id == chat_id)
async with async_session_maker() as session:
result = await session.execute(query)
result = result.scalars().all()
if user_id in result:
raise UserAlreadyInChatException
query = insert(UserChat).values(user_id=user_id, chat_id=chat_id)
await session.execute(query)
await session.commit()
result = await self.session.execute(query)
result = result.scalars().all()
if user_id in result:
raise UserAlreadyInChatException
query = insert(UserChat).values(user_id=user_id, chat_id=chat_id)
await self.session.execute(query)
await self.session.commit()
return True
async def send_message(self, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> SMessage:
@ -65,10 +64,9 @@ class ChatDAO(BaseDAO):
.join(Answer, Answer.self_id == inserted_image.c.id, isouter=True)
)
async with async_session_maker() as session:
result = await session.execute(query)
await session.commit()
result = result.mappings().one()
result = await self.session.execute(query)
await self.session.commit()
result = result.mappings().one()
return SMessage.model_validate(result, from_attributes=True)
async def get_message_by_id(self, message_id: int):
@ -91,17 +89,15 @@ class ChatDAO(BaseDAO):
.join(Answer, Answer.self_id == Message.id, isouter=True)
.where(Message.id == message_id, Message.visibility == True) # noqa: E712
)
async with async_session_maker() as session:
result = await session.execute(query)
result = await self.session.execute(query)
result = result.mappings().all()
if result:
return result[0]
async def delete_message(self, message_id: int) -> bool:
query = update(Message).where(Message.id == message_id).values(visibility=False)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
await self.session.execute(query)
await self.session.commit()
return True
async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> list[dict]:
@ -145,8 +141,7 @@ class ChatDAO(BaseDAO):
.limit(messages_to_get)
.offset(message_number_from)
)
async with async_session_maker() as session:
result = await session.execute(messages)
result = await self.session.execute(messages)
result = result.mappings().all()
if result:
result = [dict(res) for res in result]
@ -154,10 +149,9 @@ class ChatDAO(BaseDAO):
async def edit_message(self, message_id: int, new_message: str, new_image_url: str) -> bool:
query = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
await self.session.execute(query)
await self.session.commit()
return True
async def add_answer(self, self_id: int, answer_id: int) -> SMessage:
answer = (
@ -187,44 +181,39 @@ class ChatDAO(BaseDAO):
.where(Message.id == self_id)
)
async with async_session_maker() as session:
result = await session.execute(query)
await session.commit()
result = await self.session.execute(query)
await self.session.commit()
result = result.mappings().one()
return SMessage.model_validate(result)
async def delete_chat(self, chat_id: int) -> bool:
query = update(Chats).where(Chats.id == chat_id).values(visibility=False)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
await self.session.execute(query)
await self.session.commit()
return True
async def delete_user(self, chat_id: int, user_id: int) -> bool:
query = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
await self.session.execute(query)
await self.session.commit()
return True
async def pin_chat(self, chat_id: int, user_id: int) -> bool:
query = select(PinnedChats.chat_id).where(PinnedChats.user_id == user_id)
async with async_session_maker() as session:
result = await session.execute(query)
result = result.scalars().all()
if chat_id in result:
raise UserAlreadyPinnedChatException
query = insert(PinnedChats).values(chat_id=chat_id, user_id=user_id)
await session.execute(query)
await session.commit()
return True
result = await self.session.execute(query)
result = result.scalars().all()
if chat_id in result:
raise UserAlreadyPinnedChatException
query = insert(PinnedChats).values(chat_id=chat_id, user_id=user_id)
await self.session.execute(query)
await self.session.commit()
return True
async def unpin_chat(self, chat_id: int, user_id: int) -> bool:
query = delete(PinnedChats).where(PinnedChats.chat_id == chat_id, PinnedChats.user_id == user_id)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
await self.session.execute(query)
await self.session.commit()
return True
async def get_pinned_chats(self, user_id: int):
chats_with_descriptions = (
@ -263,24 +252,19 @@ class ChatDAO(BaseDAO):
.where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712
)
# print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса
async with async_session_maker() as session:
result = await session.execute(query)
result = result.mappings().all()
return result
result = await self.session.execute(query)
result = result.mappings().all()
return result
async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> bool:
query = insert(PinnedMessages).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
await self.session.execute(query)
return True
async def unpin_message(self, chat_id: int, message_id: int) -> bool:
query = delete(PinnedMessages).where(PinnedMessages.chat_id == chat_id, PinnedMessages.message_id == message_id)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
await self.session.execute(query)
return True
async def get_pinned_messages(self, chat_id: int) -> list[dict]:
query = (
@ -304,9 +288,9 @@ class ChatDAO(BaseDAO):
.where(PinnedMessages.chat_id == chat_id, Message.visibility == True) # noqa: E712
.order_by(Message.created_at.desc())
)
async with async_session_maker() as session:
result = await session.execute(query)
result = result.mappings().all()
if result:
result = [dict(res) for res in result]
return result
result = await self.session.execute(query)
result = result.mappings().all()
if result:
result = [dict(res) for res in result]
return result

View file

@ -4,10 +4,9 @@ from fastapi import APIRouter, Depends, status
from app.config import settings
from app.exceptions import UserDontHavePermissionException, MessageNotFoundException, UserCanNotReadThisChatException
from app.chat.dao import ChatDAO
from app.chat.shemas import SMessage, SLastMessages, SPinnedChat, SDeletedUser, SChat, SDeletedChat
from app.unit_of_work import UnitOfWork
from app.users.dao import UserDAO
from app.users.dependencies import check_verificated_user_with_exc
from app.users.auth import ADMIN_USER_ID, AuthService
from app.users.schemas import SCreateInvitationLink, SUserAddedToChat, SUser
@ -20,42 +19,10 @@ router = APIRouter(prefix="/chat", tags=["Чат"])
status_code=status.HTTP_200_OK,
response_model=list[SChat]
)
async def get_all_chats(user: SUser = Depends(check_verificated_user_with_exc)):
result = await UserDAO.get_user_allowed_chats(user.id)
return result
@router.post(
"",
status_code=status.HTTP_201_CREATED,
response_model=None,
)
async def add_message_to_chat(chat_id: int, message: str, user: SUser = Depends(check_verificated_user_with_exc)):
chats = await AuthService.get_user_allowed_chats_id(user.id)
if chat_id not in chats:
raise UserDontHavePermissionException
send_message_to_chat = await ChatDAO.send_message(
user_id=user.id,
chat_id=chat_id,
message=message,
)
return send_message_to_chat
@router.delete(
"/delete_message",
status_code=status.HTTP_200_OK,
response_model=None,
)
async def delete_message_from_chat(message_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
get_message_sender = await ChatDAO.get_message_by_id(message_id=message_id)
if get_message_sender is None:
raise MessageNotFoundException
if get_message_sender["user_id"] != user.id:
if not await AuthService.validate_user_admin(user_id=user.id):
raise UserDontHavePermissionException
deleted_message = await ChatDAO.delete_message(message_id=message_id)
return deleted_message
async def get_all_chats(user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
async with uow:
result = await uow.user.get_user_allowed_chats(user.id)
return result
@router.post(
@ -63,13 +30,14 @@ async def delete_message_from_chat(message_id: int, user: SUser = Depends(check_
status_code=status.HTTP_201_CREATED,
response_model=None,
)
async def create_chat(user_to_exclude: int, chat_name: str, user: SUser = Depends(check_verificated_user_with_exc)):
if user.id == user_to_exclude:
raise UserCanNotReadThisChatException
chat_id = await ChatDAO.create(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id)
user_added_to_chat = await ChatDAO.add_user_to_chat(user.id, chat_id)
await ChatDAO.add_user_to_chat(ADMIN_USER_ID, chat_id)
return user_added_to_chat
async def create_chat(user_to_exclude: int, chat_name: str, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
async with uow:
if user.id == user_to_exclude:
raise UserCanNotReadThisChatException
chat_id = await uow.chat.create(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id)
user_added_to_chat = await uow.chat.add_user_to_chat(user.id, chat_id)
await uow.chat.add_user_to_chat(ADMIN_USER_ID, chat_id)
return user_added_to_chat
@router.get(
@ -77,14 +45,15 @@ async def create_chat(user_to_exclude: int, chat_name: str, user: SUser = Depend
status_code=status.HTTP_200_OK,
response_model=list[SMessage]
)
async def get_last_message(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
message = await ChatDAO.get_some_messages(chat_id=chat_id, message_number_from=0, messages_to_get=1)
if message is None:
raise MessageNotFoundException
for mes in message:
mes["created_at"] = mes["created_at"].isoformat()
return message
async def get_last_message(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id)
async with uow:
message = await uow.chat.get_some_messages(chat_id=chat_id, message_number_from=0, messages_to_get=1)
if message is None:
raise MessageNotFoundException
for mes in message:
mes["created_at"] = mes["created_at"].isoformat()
return message
@router.get(
@ -93,18 +62,20 @@ async def get_last_message(chat_id: int, user: SUser = Depends(check_verificated
response_model=list[SMessage]
)
async def get_some_messages(
chat_id: int, last_messages: SLastMessages = Depends(), user: SUser = Depends(check_verificated_user_with_exc)
chat_id: int, last_messages: SLastMessages = Depends(), user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)
):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
messages = await ChatDAO.get_some_messages(
chat_id=chat_id, message_number_from=last_messages.messages_loaded,
messages_to_get=last_messages.messages_to_get
)
if not messages:
raise MessageNotFoundException
for mes in messages:
mes["created_at"] = mes["created_at"].isoformat()
return messages
await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id)
async with uow:
messages = await uow.chat.get_some_messages(
chat_id=chat_id,
message_number_from=last_messages.messages_loaded,
messages_to_get=last_messages.messages_to_get
)
if not messages:
raise MessageNotFoundException
for mes in messages:
mes["created_at"] = mes["created_at"].isoformat()
return messages
@router.get(
@ -112,14 +83,15 @@ async def get_some_messages(
status_code=status.HTTP_200_OK,
response_model=SMessage
)
async def get_message_by_id(chat_id: int, message_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
message = await ChatDAO.get_message_by_id(message_id=message_id)
if not message:
raise MessageNotFoundException
message = dict(message)
message["created_at"] = message["created_at"].isoformat()
return message
async def get_message_by_id(chat_id: int, message_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id)
async with uow:
message = await uow.chat.get_message_by_id(message_id=message_id)
if not message:
raise MessageNotFoundException
message = dict(message)
message["created_at"] = message["created_at"].isoformat()
return message
@router.get(
@ -127,8 +99,8 @@ async def get_message_by_id(chat_id: int, message_id: int, user: SUser = Depends
status_code=status.HTTP_201_CREATED,
response_model=SCreateInvitationLink,
)
async def create_invitation_link(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
async def create_invitation_link(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id)
cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY)
invitation_token = cipher_suite.encrypt(str(chat_id).encode())
invitation_link = settings.INVITATION_LINK_HOST + "/api/chat/invite_to_chat/" + str(invitation_token).split("'")[1]
@ -140,14 +112,15 @@ async def create_invitation_link(chat_id: int, user: SUser = Depends(check_verif
status_code=status.HTTP_200_OK,
response_model=SUserAddedToChat,
)
async def invite_to_chat(invitation_token: str, user: SUser = Depends(check_verificated_user_with_exc)):
invitation_token = invitation_token.encode()
cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY)
chat_id = int(cipher_suite.decrypt(invitation_token))
chat = await ChatDAO.find_one_or_none(id=chat_id)
if user.id == chat.chat_for:
raise UserCanNotReadThisChatException
return {"user_added_to_chat": await ChatDAO.add_user_to_chat(chat_id=chat_id, user_id=user.id)}
async def invite_to_chat(invitation_token: str, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
async with uow:
invitation_token = invitation_token.encode()
cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY)
chat_id = int(cipher_suite.decrypt(invitation_token))
chat = await uow.chat.find_one_or_none(id=chat_id)
if user.id == chat.chat_for:
raise UserCanNotReadThisChatException
return {"user_added_to_chat": await uow.chat.add_user_to_chat(chat_id=chat_id, user_id=user.id)}
@router.delete(
@ -155,11 +128,12 @@ async def invite_to_chat(invitation_token: str, user: SUser = Depends(check_veri
status_code=status.HTTP_200_OK,
response_model=SDeletedChat,
)
async def delete_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
chat = await ChatDAO.find_one_or_none(id=chat_id)
if user.id == chat.created_by:
return {"deleted_chat": await ChatDAO.delete_chat(chat_id)}
raise UserDontHavePermissionException
async def delete_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
async with uow:
chat = await uow.chat.find_one_or_none(id=chat_id)
if user.id == chat.created_by:
return {"deleted_chat": await uow.chat.delete_chat(chat_id)}
raise UserDontHavePermissionException
@router.delete(
@ -167,11 +141,12 @@ async def delete_chat(chat_id: int, user: SUser = Depends(check_verificated_user
status_code=status.HTTP_200_OK,
response_model=SDeletedUser
)
async def delete_user_from_chat(chat_id: int, user_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
chat = await ChatDAO.find_one_or_none(id=chat_id)
if user.id == chat.created_by:
return {"deleted_user": await ChatDAO.delete_user(chat_id=chat_id, user_id=user_id)}
raise UserDontHavePermissionException
async def delete_user_from_chat(chat_id: int, user_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
async with uow:
chat = await uow.chat.find_one_or_none(id=chat_id)
if user.id == chat.created_by:
return {"deleted_user": await uow.chat.delete_user(chat_id=chat_id, user_id=user_id)}
raise UserDontHavePermissionException
@router.post(
@ -179,10 +154,11 @@ async def delete_user_from_chat(chat_id: int, user_id: int, user: SUser = Depend
status_code=status.HTTP_200_OK,
response_model=SPinnedChat
)
async def pinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
await ChatDAO.pin_chat(chat_id=chat_id, user_id=user.id)
return {"chat_id": chat_id, "user_id": user.id}
async def pinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id)
async with uow:
await uow.chat.pin_chat(chat_id=chat_id, user_id=user.id)
return {"chat_id": chat_id, "user_id": user.id}
@router.delete(
@ -190,10 +166,11 @@ async def pinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_w
status_code=status.HTTP_200_OK,
response_model=SPinnedChat
)
async def unpinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
await ChatDAO.unpin_chat(chat_id=chat_id, user_id=user.id)
return {"chat_id": chat_id, "user_id": user.id}
async def unpinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id)
async with uow:
await uow.chat.unpin_chat(chat_id=chat_id, user_id=user.id)
return {"chat_id": chat_id, "user_id": user.id}
@router.get(
@ -201,8 +178,9 @@ async def unpinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user
status_code=status.HTTP_200_OK,
response_model=list[SChat]
)
async def get_pinned_chats(user: SUser = Depends(check_verificated_user_with_exc)):
return await ChatDAO.get_pinned_chats(user_id=user.id)
async def get_pinned_chats(user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
async with uow:
return await uow.chat.get_pinned_chats(user_id=user.id)
@router.get(
@ -210,10 +188,11 @@ async def get_pinned_chats(user: SUser = Depends(check_verificated_user_with_exc
status_code=status.HTTP_200_OK,
response_model=list[SMessage] | None
)
async def pinned_messages(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
messages = await ChatDAO.get_pinned_messages(chat_id=chat_id)
if messages:
for mes in messages:
mes["created_at"] = mes["created_at"].isoformat()
return messages
async def pinned_messages(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id)
async with uow:
messages = await uow.chat.get_pinned_messages(chat_id=chat_id)
if messages:
for mes in messages:
mes["created_at"] = mes["created_at"].isoformat()
return messages

View file

@ -24,12 +24,6 @@ class SLastMessages(BaseModel):
messages_to_get: int
class SPinnedMessage(BaseModel):
message_id: int
user_id: int
chat_id: int
class SPinnedChat(BaseModel):
user_id: int
chat_id: int

View file

@ -2,6 +2,7 @@ from fastapi import WebSocket, WebSocketDisconnect, Depends
from app.exceptions import IncorrectDataException, UserDontHavePermissionException
from app.services.message_service import MessageService
from app.unit_of_work import UnitOfWork
from app.users.auth import AuthService
from app.chat.router import router
from app.chat.shemas import SSendMessage, SMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage
@ -22,14 +23,14 @@ class ConnectionManager:
def disconnect(self, chat_id: int, websocket: WebSocket):
self.active_connections[chat_id].remove(websocket)
async def broadcast(self, user_id: int, chat_id: int, message: dict):
async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict):
if "flag" not in message:
raise IncorrectDataException
if message["flag"] == "send":
message = SSendMessage.model_validate(message)
new_message = await self.add_message_to_database(user_id=user_id, chat_id=chat_id, message=message)
new_message = await self.add_message_to_database(uow=uow, user_id=user_id, chat_id=chat_id, message=message)
new_message = new_message.model_dump()
new_message["created_at"] = new_message["created_at"].isoformat()
new_message["flag"] = "send"
@ -40,7 +41,7 @@ class ConnectionManager:
if message.user_id != user_id:
raise UserDontHavePermissionException
deleted_message = await self.delete_message(message.id)
deleted_message = await self.delete_message(uow=uow, message_id=message.id)
new_message = {"deleted_message": deleted_message, "id": message.id, "flag": "delete"}
elif message["flag"] == "edit":
@ -49,7 +50,9 @@ class ConnectionManager:
if message.user_id != user_id:
raise UserDontHavePermissionException
edited_message = await self.edit_message(message.id, message.new_message, message.new_image_url)
edited_message = await self.edit_message(
uow=uow, message_id=message.id, new_message=message.new_message, image_url=message.new_image_url
)
new_message = {
"flag": "edit",
"id": message.id,
@ -60,14 +63,14 @@ class ConnectionManager:
elif message["flag"] == "pin":
message = SPinMessage.model_validate(message)
pinned_message = await self.pin_message(chat_id=chat_id, user_id=message.user_id, message_id=message.id)
pinned_message = await self.pin_message(uow=uow, chat_id=chat_id, user_id=message.user_id, message_id=message.id)
new_message = pinned_message.model_dump()
new_message["created_at"] = new_message["created_at"].isoformat()
new_message["flag"] = "pin"
elif message["flag"] == "unpin":
message = SUnpinMessage.model_validate(message)
unpinned_message = await self.unpin_message(chat_id=chat_id, message_id=message.id)
unpinned_message = await self.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id)
new_message = {"flag": "pin", "id": unpinned_message}
else:
@ -77,34 +80,34 @@ class ConnectionManager:
await websocket.send_json(new_message)
@staticmethod
async def add_message_to_database(user_id: int, chat_id: int, message: SSendMessage) -> SMessage:
async def add_message_to_database(uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage) -> SMessage:
new_message = await MessageService.send_message(
user_id=user_id, chat_id=chat_id, message=message.message, image_url=message.image_url
uow=uow, user_id=user_id, chat_id=chat_id, message=message.message, image_url=message.image_url
)
if message.answer:
new_message = await MessageService.add_answer(self_id=new_message.id, answer_id=message.answer)
new_message = await MessageService.add_answer(uow=uow, self_id=new_message.id, answer_id=message.answer)
return new_message
@staticmethod
async def delete_message(message_id: int) -> bool:
new_message = await MessageService.delete_message(message_id)
async def delete_message(uow: UnitOfWork, message_id: int) -> bool:
new_message = await MessageService.delete_message(uow=uow, message_id=message_id)
return new_message
@staticmethod
async def edit_message(message_id: int, new_message: str, image_url: str) -> bool:
async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, image_url: str) -> bool:
new_message = await MessageService.edit_message(
message_id=message_id, new_message=new_message, new_image_url=image_url
uow=uow, message_id=message_id, new_message=new_message, new_image_url=image_url
)
return new_message
@staticmethod
async def pin_message(chat_id: int, user_id: int, message_id: int) -> SMessage:
pinned_message = await MessageService.pin_message(chat_id=chat_id, user_id=user_id, message_id=message_id)
async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage:
pinned_message = await MessageService.pin_message(uow=uow, chat_id=chat_id, user_id=user_id, message_id=message_id)
return pinned_message
@staticmethod
async def unpin_message(chat_id: int, message_id: int) -> int:
unpinned_message_id = await MessageService.unpin_message(chat_id=chat_id, message_id=message_id)
async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> int:
unpinned_message_id = await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message_id)
return unpinned_message_id
@ -114,14 +117,14 @@ manager = ConnectionManager()
@router.websocket(
"/ws/{chat_id}",
)
async def websocket_endpoint(chat_id: int, websocket: WebSocket, user: SUser = Depends(get_current_user_ws)):
await AuthService.check_verificated_user_with_exc(user_id=user.id)
await AuthService.validate_user_access_to_chat(user_id=user.id, chat_id=chat_id)
async def websocket_endpoint(chat_id: int, websocket: WebSocket, user: SUser = Depends(get_current_user_ws), uow=Depends(UnitOfWork)):
await AuthService.check_verificated_user_with_exc(uow=uow, user_id=user.id)
await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id)
await manager.connect(chat_id, websocket)
try:
while True:
data = await websocket.receive_json()
await manager.broadcast(user_id=user.id, chat_id=chat_id, message=data)
await manager.broadcast(uow=uow, user_id=user.id, chat_id=chat_id, message=data)
except WebSocketDisconnect:
manager.disconnect(chat_id, websocket)

View file

@ -25,6 +25,7 @@ class Settings(BaseSettings):
REDIS_HOST: str
REDIS_PORT: int
REDIS_DB: int
SMTP_HOST: str
SMTP_PORT: int

View file

@ -1,8 +1,6 @@
from sqlalchemy import select, insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import async_session_maker
class BaseDAO:
model = None
@ -11,20 +9,16 @@ class BaseDAO:
self.session = session
async def add(self, **data): # Метод добавляет данные в БД
async with async_session_maker() as session:
query = insert(self.model).values(**data).returning(self.model.id)
result = await session.execute(query)
await session.commit()
return result.scalar()
stmt = insert(self.model).values(**data).returning(self.model.id)
result = await self.session.execute(stmt)
return result.scalar()
async def find_one_or_none(self, **filter_by): # Метод проверяет наличие строки с заданными параметрами
async with async_session_maker() as session:
query = select(self.model).filter_by(**filter_by)
result = await session.execute(query)
return result.scalar_one_or_none()
query = select(self.model).filter_by(**filter_by)
result = await self.session.execute(query)
return result.scalar_one_or_none()
async def find_all(self, **filter_by): # Метод возвращает все строки таблицы или те, которые соответствуют отбору
async with async_session_maker() as session:
query = select(self.model.__table__.columns).filter_by(**filter_by)
result = await session.execute(query)
return result.mappings().all()
query = select(self.model.__table__.columns).filter_by(**filter_by)
result = await self.session.execute(query)
return result.mappings().all()

View file

@ -1,34 +1,42 @@
from app.chat.dao import ChatDAO
from app.chat.shemas import SMessage
from app.unit_of_work import UnitOfWork
class MessageService:
@staticmethod
async def send_message(user_id: int, chat_id: int, message: str, image_url: str | None = None) -> SMessage:
new_message = await ChatDAO.send_message(user_id=user_id, chat_id=chat_id, message=message, image_url=image_url)
return new_message
async def send_message(uow: UnitOfWork, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> SMessage:
async with uow:
new_message = await uow.chat.send_message(user_id=user_id, chat_id=chat_id, message=message, image_url=image_url)
return new_message
@staticmethod
async def add_answer(self_id: int, answer_id: int) -> SMessage:
new_message = await ChatDAO.add_answer(self_id=self_id, answer_id=answer_id)
return new_message
async def add_answer(uow: UnitOfWork, self_id: int, answer_id: int) -> SMessage:
async with uow:
new_message = await uow.chat.add_answer(self_id=self_id, answer_id=answer_id)
return new_message
@staticmethod
async def delete_message(message_id: int) -> bool:
new_message = await ChatDAO.delete_message(message_id=message_id)
return new_message
async def delete_message(uow: UnitOfWork, message_id: int) -> bool:
async with uow:
new_message = await uow.chat.delete_message(message_id=message_id)
return new_message
@staticmethod
async def edit_message(message_id: int, new_message: str, new_image_url: str) -> bool:
new_message = await ChatDAO.edit_message(message_id=message_id, new_message=new_message, new_image_url=new_image_url)
return new_message
async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, new_image_url: str) -> bool:
async with uow:
new_message = await uow.chat.edit_message(message_id=message_id, new_message=new_message, new_image_url=new_image_url)
return new_message
@staticmethod
async def pin_message(chat_id: int, user_id: int, message_id: int) -> SMessage:
pinned_message = await ChatDAO.pin_message(chat_id=chat_id, message_id=message_id, user_id=user_id)
return pinned_message
async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage:
async with uow:
pinned_message = await uow.chat.pin_message(chat_id=chat_id, message_id=message_id, user_id=user_id)
await uow.commit()
return pinned_message
@staticmethod
async def unpin_message(chat_id: int, message_id: int) -> int:
unpinned_message = await ChatDAO.unpin_message(chat_id=chat_id, message_id=message_id)
return unpinned_message
async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> int:
async with uow:
unpinned_message = await uow.chat.unpin_message(chat_id=chat_id, message_id=message_id)
await uow.commit()
return unpinned_message

View file

@ -0,0 +1,24 @@
from redis.asyncio.client import Redis
from app.config import settings
def get_redis_session() -> Redis:
return Redis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB)
class RedisService:
@staticmethod
async def set_verification_code(redis: Redis, user_id: int, verification_code: str):
await redis.setex(f"user_verification_code: {user_id}", 1800, verification_code)
@staticmethod
async def get_verification_code(redis: Redis, user_id: int) -> str:
verification_code = await redis.get(f"user_verification_code: {user_id}")
return verification_code
@staticmethod
async def delete_verification_code(redis: Redis, user_id: int):
await redis.delete(f"user_verification_code: {user_id}")

View file

@ -1,9 +1,10 @@
from app.users.dao import UserDAO
from app.unit_of_work import UnitOfWork
from app.users.schemas import SUser
class UserService:
@staticmethod
async def find_one_or_none(user_id: int) -> SUser | None:
user = await UserDAO.find_one_or_none(id=user_id)
return user
async def find_one_or_none(uow: UnitOfWork, user_id: int) -> SUser | None:
async with uow:
user = await uow.user.find_one_or_none(id=user_id)
return user

View file

@ -72,3 +72,13 @@ def send_password_recover_email(username: str, email_to: EmailStr, MODE: str):
server.send_message(msg_content)
return confirmation_code
@celery.task
def send_data_change_confirmation_email(username: str, email_to: EmailStr, MODE: str):
pass
@celery.task
def send_data_change_email(username: str, email_to: EmailStr, MODE: str):
pass

24
app/unit_of_work.py Normal file
View file

@ -0,0 +1,24 @@
from app.chat.dao import ChatDAO
from app.database import async_session_maker
from app.users.dao import UserDAO
class UnitOfWork:
def __init__(self):
self.session_factory = async_session_maker
async def __aenter__(self):
self.session = self.session_factory()
self.user = UserDAO(self.session)
self.chat = ChatDAO(self.session)
async def __aexit__(self, *args):
await self.rollback()
await self.session.close()
async def commit(self):
await self.session.commit()
async def rollback(self):
await self.session.rollback()

View file

@ -12,9 +12,8 @@ from app.exceptions import (
UserNotFoundException,
UserMustConfirmEmailException,
)
from app.users.dao import UserDAO
from app.models.users import Users
from app.users.schemas import SUserRegister
from app.unit_of_work import UnitOfWork
from app.users.schemas import SUserRegister, SUser
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@ -41,68 +40,72 @@ def create_access_token(data: dict[str, str | datetime]) -> str:
class AuthService:
# Функция проверки наличия юзера по мейлу
@staticmethod
async def authenticate_user_by_email(email: EmailStr, password: str) -> Users | None:
user = await UserDAO.find_one_or_none(email=email)
if not user or not verify_password(password, user.hashed_password):
return None
return user
async def authenticate_user_by_email(uow: UnitOfWork, email: EmailStr, password: str) -> SUser | None:
async with uow:
user = await uow.user.find_one_or_none(email=email)
if not user or not verify_password(password, user.hashed_password):
return None
return user
# Функция проверки наличия юзера по нику
@staticmethod
async def authenticate_user_by_username(username: str, password: str) -> Users | None:
user = await UserDAO.find_one_or_none(username=username)
if not user or not verify_password(password, user.hashed_password):
return None
return user
async def authenticate_user_by_username(uow: UnitOfWork, username: str, password: str) -> SUser | None:
async with uow:
user = await uow.user.find_one_or_none(username=username)
if not user or not verify_password(password, user.hashed_password):
return None
return user
@classmethod
async def authenticate_user(cls, email_or_username: str, password: str) -> Users:
user = await cls.authenticate_user_by_email(email_or_username, password)
async def authenticate_user(cls, uow: UnitOfWork, email_or_username: str, password: str) -> SUser:
user = await cls.authenticate_user_by_email(uow, email_or_username, password)
if not user:
user = await cls.authenticate_user_by_username(email_or_username, password)
user = await cls.authenticate_user_by_username(uow, email_or_username, password)
if not user:
raise IncorrectAuthDataException
return user
@staticmethod
async def check_existing_user(user_data: SUserRegister) -> None:
existing_user = await UserDAO.find_one_or_none(email=user_data.email)
if existing_user:
raise UserAlreadyExistsException
existing_user = await UserDAO.find_one_or_none(username=user_data.username)
if existing_user:
raise UserAlreadyExistsException
async def check_existing_user(uow: UnitOfWork, user_data: SUserRegister) -> None:
async with uow:
existing_user = await uow.user.find_one_or_none(email=user_data.email)
if existing_user:
raise UserAlreadyExistsException
existing_user = await uow.user.find_one_or_none(username=user_data.username)
if existing_user:
raise UserAlreadyExistsException
@staticmethod
async def check_verificated_user(user_id: int) -> bool:
user = await UserDAO.find_one_or_none(id=user_id)
if not user:
raise UserNotFoundException
return user.role >= VERIFICATED_USER
async def check_verificated_user(uow: UnitOfWork, user_id: int) -> bool:
async with uow:
user = await uow.user.find_one_or_none(id=user_id)
if not user:
raise UserNotFoundException
return user.role >= VERIFICATED_USER
@classmethod
async def check_verificated_user_with_exc(cls, user_id: int):
if not await cls.check_verificated_user(user_id=user_id):
async def check_verificated_user_with_exc(cls, uow: UnitOfWork, user_id: int):
if not await cls.check_verificated_user(uow=uow, user_id=user_id):
raise UserMustConfirmEmailException
@staticmethod
async def get_user_allowed_chats_id(user_id: int) -> list[int]:
user_allowed_chats = await UserDAO.get_user_allowed_chats(user_id)
user_allowed_chats_id = [chat["chat_id"] for chat in user_allowed_chats]
return user_allowed_chats_id
async def get_user_allowed_chats_id(uow: UnitOfWork, user_id: int) -> list[int]:
async with uow:
user_allowed_chats = await uow.user.get_user_allowed_chats(user_id)
user_allowed_chats_id = [chat["chat_id"] for chat in user_allowed_chats]
return user_allowed_chats_id
@classmethod
async def validate_user_access_to_chat(cls, user_id: int, chat_id: int) -> bool:
user_allowed_chats = await cls.get_user_allowed_chats_id(user_id=user_id)
async def validate_user_access_to_chat(cls, uow: UnitOfWork, user_id: int, chat_id: int) -> bool:
user_allowed_chats = await cls.get_user_allowed_chats_id(uow=uow, user_id=user_id)
if chat_id not in user_allowed_chats:
raise UserDontHavePermissionException
return True
@staticmethod
async def validate_user_admin(user_id: int) -> bool:
user_role = await UserDAO.get_user_role(user_id=user_id)
if user_role == ADMIN_USER:
return True
return False
async def validate_user_admin(uow: UnitOfWork, user_id: int) -> bool:
async with uow:
user_role = await uow.user.get_user_role(user_id=user_id)
if user_role == ADMIN_USER:
return True
return False

View file

@ -1,11 +1,11 @@
from sqlalchemy import update, select, insert, and_, func, text, delete
from pydantic import HttpUrl
from sqlalchemy import update, select, insert, func
from app.dao.base import BaseDAO
from app.database import async_session_maker, engine # noqa
from app.database import engine # noqa
from app.models.chat import Chats
from app.models.user_avatar import UserAvatar
from app.models.users import Users
from app.models.user_verification_code import UserVerificationCode
from app.models.user_chat import UserChat
from app.users.schemas import SUser, SUserAvatars
@ -14,32 +14,25 @@ class UserDAO(BaseDAO):
model = Users
async def find_one_or_none(self, **filter_by) -> SUser | None:
async with async_session_maker() as session:
query = select(Users).filter_by(**filter_by)
result = await session.execute(query)
result = result.scalar_one_or_none()
if result:
return SUser.model_validate(result, from_attributes=True)
query = select(Users).filter_by(**filter_by)
result = await self.session.execute(query)
result = result.scalar_one_or_none()
if result:
return SUser.model_validate(result, from_attributes=True)
async def find_all(self, **filter_by):
async with async_session_maker() as session:
query = select(Users.__table__.columns).filter_by(**filter_by).where(Users.role != 100)
result = await session.execute(query)
return result.mappings().all()
query = select(Users.__table__.columns).filter_by(**filter_by).where(Users.role != 100)
result = await self.session.execute(query)
return result.mappings().all()
async def change_data(self, user_id: int, **data_to_change) -> str:
query = update(Users).where(Users.id == user_id).values(**data_to_change)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
query = select(Users.username).where(Users.id == user_id)
result = await session.execute(query)
return result.scalar()
query = update(Users).where(Users.id == user_id).values(**data_to_change).returning(Users.username)
result = await self.session.execute(query)
return result.scalar()
async def get_user_role(self, user_id: int) -> int:
query = select(Users.role).where(Users.id == user_id)
async with async_session_maker() as session:
result = await session.execute(query)
result = await self.session.execute(query)
return result.scalar()
async def get_user_allowed_chats(self, user_id: int):
@ -89,26 +82,18 @@ class UserDAO(BaseDAO):
chats_with_avatars.c.avatar_hex,
)
.select_from(chats_with_avatars)
.where(and_(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True)) # noqa: E712
.where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712
)
async with async_session_maker() as session:
result = await session.execute(query)
result = result.mappings().all()
result = await self.session.execute(query)
result = result.mappings().all()
return result
async def get_user_avatar(self, user_id: int) -> str:
query = select(Users.avatar_image).where(Users.id == user_id)
async with async_session_maker() as session:
result = await session.execute(query)
return result.scalar()
async def add_user_avatar(self, user_id: int, avatar: str) -> bool:
async def add_user_avatar(self, user_id: int, avatar: HttpUrl) -> bool:
query = insert(UserAvatar).values(user_id=user_id, avatar_image=avatar)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
await self.session.execute(query)
await self.session.commit()
return True
async def get_user_avatars(self, user_id: int) -> SUserAvatars:
query = select(
@ -124,54 +109,6 @@ class UserDAO(BaseDAO):
.scalar_subquery()
)
)
async with async_session_maker() as session:
result = await session.execute(query)
result = result.scalar()
return SUserAvatars.model_validate(result)
async def delete_user_avatar(self, avatar_id: int, user_id: int) -> bool:
query = delete(UserAvatar).where(and_(UserAvatar.id == avatar_id, UserAvatar.user_id == user_id))
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
class UserCodesDAO(BaseDAO):
model = UserVerificationCode
async def set_user_codes(self, cls, user_id: int, code: str, description: str):
query = (
insert(UserVerificationCode)
.values(user_id=user_id, code=code, description=description)
.returning(cls.model.code)
)
async with async_session_maker() as session:
result = await session.execute(query)
await session.commit()
return result.scalar()
async def get_user_codes(self, **filter_by) -> list[dict | None]:
"""
SELECT
usersverificationcodes.id,
usersverificationcodes.user_id,
usersverificationcodes.code,
usersverificationcodes.description,
usersverificationcodes.date_of_creation
FROM usersverificationcodes
WHERE
usersverificationcodes.user_id = 20
AND now() - usersverificationcodes.date_of_creation < INTERVAL '30 minutes'
"""
query = (
select(UserVerificationCode.__table__.columns)
.where((func.now() - UserVerificationCode.date_of_creation) < text("INTERVAL '10 minutes'"))
.filter_by(**filter_by)
)
async with async_session_maker() as session:
# print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса
result = await session.execute(query)
result = result.mappings().all()
return result
result = await self.session.execute(query)
result = result.scalar()
return SUserAvatars.model_validate(result)

View file

@ -10,6 +10,7 @@ from app.exceptions import (
UserMustConfirmEmailException,
)
from app.services.user_service import UserService
from app.unit_of_work import UnitOfWork
from app.users.auth import create_access_token, VERIFICATED_USER
from app.users.schemas import SUser
@ -21,7 +22,7 @@ def get_token(request: Request) -> str:
return token
async def get_current_user(response: Response, token: str = Depends(get_token)) -> SUser:
async def get_current_user(response: Response, token: str = Depends(get_token), uow=Depends(UnitOfWork)) -> SUser:
try:
payload = jwt.decode(token, settings.SECRET_KEY, settings.ALGORITHM)
except ExpiredSignatureError:
@ -33,7 +34,7 @@ async def get_current_user(response: Response, token: str = Depends(get_token))
if not user_id:
raise UserIsNotPresentException
user = await UserService.find_one_or_none(user_id=int(user_id))
user = await UserService.find_one_or_none(uow=uow, user_id=int(user_id))
if not user:
raise UserIsNotPresentException
@ -55,7 +56,7 @@ def get_token_ws(websocket: WebSocket) -> str:
return token
async def get_current_user_ws(token: str = Depends(get_token_ws)):
async def get_current_user_ws(token: str = Depends(get_token_ws), uow=Depends(UnitOfWork)):
try:
payload = jwt.decode(token, settings.SECRET_KEY, settings.ALGORITHM)
except ExpiredSignatureError:
@ -67,7 +68,7 @@ async def get_current_user_ws(token: str = Depends(get_token_ws)):
if not user_id:
raise UserIsNotPresentException
user = await UserService.find_one_or_none(user_id=int(user_id))
user = await UserService.find_one_or_none(uow=uow, user_id=int(user_id))
if not user:
raise UserIsNotPresentException

View file

@ -7,31 +7,30 @@ from app.config import settings
from app.exceptions import (
PasswordsMismatchException,
WrongCodeException,
UserNotFoundException,
SomethingWentWrongException,
UserAlreadyExistsException,
IncorrectAuthDataException,
)
from app.users.auth import get_password_hash, create_access_token, VERIFICATED_USER, AuthService
from app.users.dao import UserDAO, UserCodesDAO
from app.services.redis_service import RedisService, get_redis_session
from app.unit_of_work import UnitOfWork
from app.users.auth import get_password_hash, create_access_token, VERIFICATED_USER, AuthService, verify_password
from app.users.dependencies import get_current_user
from app.users.schemas import (
SUserLogin,
SUserRegister,
SUserResponse,
SUserPasswordRecover,
SUserCode,
SUserPasswordChange,
SRecoverEmailSent,
SUserToken,
SEmailVerification,
SConfirmPasswordRecovery,
SPasswordRecovered,
SUserAvatars,
SUsername,
SEmail, SUser,
SEmail,
SUser,
SUserChangeData,
SUserSendConfirmationCode,
)
from app.tasks.tasks import (
send_registration_confirmation_email,
send_data_change_confirmation_email
)
from app.tasks.tasks import send_registration_confirmation_email, send_password_change_email, \
send_password_recover_email
router = APIRouter(prefix="/users", tags=["Пользователи"])
@ -41,9 +40,10 @@ router = APIRouter(prefix="/users", tags=["Пользователи"])
status_code=status.HTTP_200_OK,
response_model=list[SUserResponse],
)
async def get_all_users():
users = await UserDAO.find_all()
return users
async def get_all_users(uow=Depends(UnitOfWork)):
async with uow:
users = await uow.user.find_all()
return users
@router.post(
@ -51,10 +51,11 @@ async def get_all_users():
status_code=status.HTTP_200_OK,
response_model=None,
)
async def check_existing_username(username: SUsername):
user = await UserDAO.find_one_or_none(username=username.username)
if user:
raise UserAlreadyExistsException
async def check_existing_username(username: SUsername, uow=Depends(UnitOfWork)):
async with uow:
user = await uow.user.find_one_or_none(username=username.username)
if user:
raise UserAlreadyExistsException
@router.post(
@ -62,10 +63,11 @@ async def check_existing_username(username: SUsername):
status_code=status.HTTP_200_OK,
response_model=None,
)
async def check_existing_email(email: SEmail):
user = await UserDAO.find_one_or_none(email=email.email)
if user:
raise UserAlreadyExistsException
async def check_existing_email(email: SEmail, uow=Depends(UnitOfWork)):
async with uow:
user = await uow.user.find_one_or_none(email=email.email)
if user:
raise UserAlreadyExistsException
@router.post(
@ -73,29 +75,30 @@ async def check_existing_email(email: SEmail):
status_code=status.HTTP_201_CREATED,
response_model=SUserToken,
)
async def register_user(response: Response, user_data: SUserRegister):
async def register_user(response: Response, user_data: SUserRegister, uow=Depends(UnitOfWork)):
if user_data.password != user_data.password2:
raise PasswordsMismatchException
await AuthService.check_existing_user(user_data)
await AuthService.check_existing_user(uow, user_data)
hashed_password = get_password_hash(user_data.password)
user_id = await UserDAO.add(
email=user_data.email,
hashed_password=hashed_password,
username=user_data.username,
date_of_birth=user_data.date_of_birth,
)
async with uow:
user_id = await uow.user.add(
email=user_data.email,
hashed_password=hashed_password,
username=user_data.username,
date_of_birth=user_data.date_of_birth,
)
result = send_registration_confirmation_email.delay(
user_id=user_id, username=user_data.username, email_to=user_data.email, MODE=settings.MODE
)
result = result.get()
if await UserCodesDAO.set_user_codes(user_id=user_id, code=result, description="Код подтверждения почты") == result:
user = await AuthService.authenticate_user_by_email(user_data.email, user_data.password)
access_token = create_access_token({"sub": str(user.id)})
response.set_cookie(key="black_phoenix_access_token", value=access_token, httponly=True, secure=True)
return {"access_token": access_token}
user_code = result.get()
redis_session = get_redis_session()
await RedisService.set_verification_code(redis=redis_session, user_id=user_id, verification_code=user_code)
user = await AuthService.authenticate_user_by_email(uow, user_data.email, user_data.password)
access_token = create_access_token({"sub": str(user.id)})
response.set_cookie(key="black_phoenix_access_token", value=access_token, httponly=True, secure=True)
return {"access_token": access_token}
@router.get(
@ -103,17 +106,20 @@ async def register_user(response: Response, user_data: SUserRegister):
status_code=status.HTTP_200_OK,
response_model=SEmailVerification,
)
async def email_verification(user_code: str):
async def email_verification(user_code: str, uow=Depends(UnitOfWork)):
invitation_token = user_code.encode()
cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY)
user_data = cipher_suite.decrypt(invitation_token).decode("utf-8")
user_data = json.loads(user_data)
user_codes = await UserCodesDAO.get_user_codes(
user_id=user_data["user_id"], description="Код подтверждения почты", code=user_data["confirmation_code"]
)
if not user_codes or not await UserDAO.change_data(user_id=user_data["user_id"], role=VERIFICATED_USER):
raise WrongCodeException
return {"email_verification": True}
redis_session = get_redis_session()
async with uow:
verification_code = await RedisService.get_verification_code(redis=redis_session, user_id=user_data["user_id"])
if verification_code != user_data["confirmation_code"]:
raise WrongCodeException
await uow.user.change_data(user_id=user_data["user_id"], role=VERIFICATED_USER)
await uow.commit()
return {"email_verification": True}
@router.post(
@ -121,8 +127,8 @@ async def email_verification(user_code: str):
status_code=status.HTTP_200_OK,
response_model=SUserToken,
)
async def login_user(response: Response, user_data: SUserLogin):
user = await AuthService.authenticate_user(user_data.email_or_username, user_data.password)
async def login_user(response: Response, user_data: SUserLogin, uow=Depends(UnitOfWork)):
user = await AuthService.authenticate_user(uow, user_data.email_or_username, user_data.password)
access_token = create_access_token({"sub": str(user.id)})
response.set_cookie("black_phoenix_access_token", access_token, httponly=True)
return {"access_token": access_token}
@ -146,61 +152,56 @@ async def get_user(current_user: SUser = Depends(get_current_user)):
return current_user
@router.patch(
"/send_recovery_email",
status_code=status.HTTP_200_OK,
response_model=SRecoverEmailSent,
)
async def send_recovery_email(email: SUserPasswordRecover):
existing_user = await UserDAO.find_one_or_none(email=email.email)
if not existing_user:
raise UserNotFoundException
result = send_password_recover_email.delay(existing_user.username, existing_user.email, MODE=settings.MODE)
result = result.get()
if (
await UserCodesDAO.set_user_codes(
user_id=existing_user.user_id, code=result, description="Код восстановления пароля"
)
== result
):
return {"recover_email_sent": True}
raise SomethingWentWrongException
@router.post(
"/confirm_password_recovery",
status_code=status.HTTP_200_OK,
response_model=SConfirmPasswordRecovery,
)
async def confirm_password_recovery(user_code: SUserCode):
user_codes = await UserCodesDAO.get_user_codes(description="Код восстановления пароля", code=user_code.user_code)
if not user_codes:
raise WrongCodeException
return {"user_id": user_codes[0]["user_id"]}
@router.post(
"/password_recovery",
status_code=status.HTTP_200_OK,
response_model=SPasswordRecovered,
)
async def password_recovery(passwords: SUserPasswordChange):
if passwords.password1 != passwords.password2:
raise PasswordsMismatchException
hashed_password = get_password_hash(passwords.password1)
username = await UserDAO.change_data(passwords.user_id, hashed_password=hashed_password)
user = await UserDAO.find_one_or_none(username=username, id=passwords.user_id)
if not user:
raise UserNotFoundException
send_password_change_email.delay(user.username, user.email, MODE=settings.MODE)
return {"username": username}
@router.get(
"/avatars",
status_code=status.HTTP_200_OK,
response_model=SUserAvatars,
)
async def get_user_avatars_history(user: SUser = Depends(get_current_user)):
return await UserDAO.get_user_avatars(user_id=user.id)
async def get_user_avatars_history(user=Depends(get_current_user), uow=Depends(UnitOfWork)):
async with uow:
return await uow.user.get_user_avatars(user_id=user.id)
@router.post(
"/send_confirmation_code",
status_code=status.HTTP_200_OK,
response_model=None,
)
async def send_confirmation_code(user_data: SUserSendConfirmationCode, user: SUser = Depends(get_current_user)):
redis_session = get_redis_session()
if verify_password(user_data.current_password, user.hashed_password):
verification_code = send_data_change_confirmation_email.delay(
username=user.username, email_to=user_data.email, MODE=settings.MODE
)
verification_code = verification_code.get()
await RedisService.set_verification_code(redis=redis_session, user_id=user.id, verification_code=verification_code)
raise IncorrectAuthDataException
@router.post(
"/change_data",
status_code=status.HTTP_200_OK,
response_model=None,
)
async def change_user_data(user_data: SUserChangeData, user=Depends(get_current_user), uow=Depends(UnitOfWork)):
redis_session = get_redis_session()
verification_code = await RedisService.get_verification_code(redis=redis_session, user_id=user.id)
if verification_code != user_data.verification_code:
raise WrongCodeException
if user_data.new_password:
hashed_password = get_password_hash(user_data.new_password)
else:
hashed_password = user.hashed_password
async with uow:
await uow.user.change_data(
user_id=user.id,
email=user_data.email,
username=user_data.username,
avatar_url=user_data.avatar_url,
hashed_password=hashed_password
)
await uow.user.add_user_avatar(user_id=user.id, avatar=user_data.avatar_url)
await uow.commit()
await RedisService.delete_verification_code(redis=redis_session, user_id=user.id)

View file

@ -59,23 +59,6 @@ class SUser(BaseModel):
date_of_registration: date
class SUserRename(BaseModel):
username: str = Query(None, min_length=2, max_length=30)
password: str
class SUserAvatar(BaseModel):
password: str
new_avatar_image: HttpUrl
avatar_hex: str
class SUserPassword(BaseModel):
current_password: str = Query(None, min_length=8)
new_password: str = Query(None, min_length=8)
new_password2: str = Query(None, min_length=8)
class SUserPasswordRecover(BaseModel):
email: EmailStr
@ -90,6 +73,19 @@ class SUserPasswordChange(BaseModel):
password2: str = Query(None, min_length=8)
class SUserSendConfirmationCode(BaseModel):
email: EmailStr
current_password: str = Query(None, min_length=8)
class SUserChangeData(BaseModel):
verification_code: str
email: EmailStr
username: str = Query(None, min_length=2, max_length=30)
new_password: str | None = Query(None, min_length=8)
avatar_url: HttpUrl
class SRecoverEmailSent(BaseModel):
recover_email_sent: bool
@ -102,15 +98,6 @@ class SUserToken(BaseModel):
access_token: str
class SUserName(BaseModel):
username: str
class SNewAvatar(BaseModel):
new_avatar_image: HttpUrl
avatar_hex: str
class SConfirmPasswordRecovery(BaseModel):
user_id: int