Почистил код
This commit is contained in:
parent
47ae518794
commit
f4448a5402
8 changed files with 99 additions and 111 deletions
|
@ -29,8 +29,8 @@ router = APIRouter(prefix="/chat", tags=["Чат"])
|
|||
)
|
||||
async def get_all_chats(user: SUser = Depends(check_verificated_user_with_exc), uow=Depends(UnitOfWork)):
|
||||
async with uow:
|
||||
result = await uow.user.get_user_allowed_chats(user.id)
|
||||
return {"allowed_chats": result}
|
||||
allowed_chats = await uow.user.get_user_allowed_chats(user.id)
|
||||
return allowed_chats
|
||||
|
||||
|
||||
@router.post(
|
||||
|
|
|
@ -38,7 +38,7 @@ class SChat(BaseModel):
|
|||
|
||||
|
||||
class SAllowedChats(BaseModel):
|
||||
allowed_chats: list[SChat | None]
|
||||
allowed_chats: list[SChat] | None
|
||||
|
||||
|
||||
class SPinnedChats(BaseModel):
|
||||
|
|
|
@ -95,7 +95,7 @@ class ChatDAO(BaseDAO):
|
|||
raise MessageNotFoundException
|
||||
|
||||
async def delete_message(self, message_id: int) -> None:
|
||||
stmt = update(Message).where(Message.id == message_id).values(visibility=False)
|
||||
stmt = update(Message).values(visibility=False).where(Message.id == message_id)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageList:
|
||||
|
|
135
app/dao/user.py
135
app/dao/user.py
|
@ -1,10 +1,11 @@
|
|||
from pydantic import HttpUrl
|
||||
from sqlalchemy import update, select, insert, func
|
||||
from sqlalchemy.exc import MultipleResultsFound, IntegrityError
|
||||
from sqlalchemy.exc import MultipleResultsFound, IntegrityError, NoResultFound
|
||||
|
||||
from app.chat.shemas import SAllowedChats
|
||||
from app.dao.base import BaseDAO
|
||||
from app.database import engine # noqa
|
||||
from app.exceptions import IncorrectDataException
|
||||
from app.exceptions import IncorrectDataException, UserNotFoundException
|
||||
from app.users.exceptions import UserAlreadyExistsException
|
||||
from app.models.chat import Chat
|
||||
from app.models.user_avatar import UserAvatar
|
||||
|
@ -16,96 +17,86 @@ from app.users.schemas import SUser, SUserAvatars, SUsers
|
|||
class UserDAO(BaseDAO):
|
||||
model = Users
|
||||
|
||||
async def add(self, **data) -> int:
|
||||
@staticmethod
|
||||
def check_query_compile(query):
|
||||
print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса
|
||||
|
||||
async def add(self, **data) -> SUser:
|
||||
try:
|
||||
stmt = insert(self.model).values(**data).returning(Users.id)
|
||||
stmt = insert(self.model).values(**data).returning(Users.__table__.columns)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar()
|
||||
result = result.mappings().one()
|
||||
return SUser.model_validate(result)
|
||||
except IntegrityError:
|
||||
raise UserAlreadyExistsException
|
||||
|
||||
async def find_one_or_none(self, **filter_by) -> SUser | None:
|
||||
async def find_one_or_none(self, **filter_by) -> SUser:
|
||||
try:
|
||||
query = select(Users.__table__.columns).filter_by(**filter_by)
|
||||
result = await self.session.execute(query)
|
||||
result = result.mappings().one_or_none()
|
||||
if result:
|
||||
result = result.mappings().one()
|
||||
return SUser.model_validate(result)
|
||||
except MultipleResultsFound:
|
||||
raise IncorrectDataException
|
||||
except NoResultFound:
|
||||
raise UserNotFoundException
|
||||
|
||||
async def find_all(self) -> SUsers:
|
||||
query = select(Users.__table__.columns).where(Users.role != 100)
|
||||
result = await self.session.execute(query)
|
||||
return SUsers.model_validate({"users": result.mappings().all()})
|
||||
|
||||
async def change_data(self, user_id: int, **data_to_change) -> str:
|
||||
query = update(Users).where(Users.id == user_id).values(**data_to_change).returning(Users.username)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar()
|
||||
|
||||
async def get_user_role(self, user_id: int) -> int:
|
||||
query = select(Users.role).where(Users.id == user_id)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar()
|
||||
|
||||
async def get_user_allowed_chats(self, user_id: int):
|
||||
"""
|
||||
WITH chats_with_descriptions AS (
|
||||
SELECT *
|
||||
FROM usersxchats
|
||||
LEFT JOIN chats ON usersxchats.chat_id = chats.id
|
||||
),
|
||||
|
||||
chats_with_avatars as (
|
||||
SELECT *
|
||||
FROM chats_with_descriptions
|
||||
LEFT JOIN users ON chats_with_descriptions.user_id = users.id
|
||||
)
|
||||
|
||||
SELECT chat_id, chat_for, chat_name, avatar_image
|
||||
FROM chats_with_avatars
|
||||
WHERE user_id = 1
|
||||
"""
|
||||
chats_with_descriptions = (
|
||||
select(UserChat.__table__.columns, Chat.__table__.columns)
|
||||
.select_from(UserChat)
|
||||
.join(Chat, UserChat.chat_id == Chat.id)
|
||||
).cte("chats_with_descriptions")
|
||||
|
||||
chats_with_avatars = (
|
||||
select(
|
||||
chats_with_descriptions.c.chat_id,
|
||||
chats_with_descriptions.c.chat_for,
|
||||
chats_with_descriptions.c.chat_name,
|
||||
chats_with_descriptions.c.visibility,
|
||||
Users.id,
|
||||
Users.avatar_image,
|
||||
)
|
||||
.select_from(chats_with_descriptions)
|
||||
.join(Users, Users.id == chats_with_descriptions.c.user_id)
|
||||
.cte("chats_with_avatars")
|
||||
)
|
||||
query = (
|
||||
select(
|
||||
chats_with_avatars.c.chat_id,
|
||||
chats_with_avatars.c.chat_for,
|
||||
chats_with_avatars.c.chat_name,
|
||||
chats_with_avatars.c.avatar_image,
|
||||
func.json_build_object(
|
||||
"users", func.json_agg(
|
||||
func.json_build_object(
|
||||
"id", Users.id,
|
||||
"username", Users.username,
|
||||
"email", Users.email,
|
||||
"black_phoenix", Users.black_phoenix,
|
||||
"avatar_image", Users.avatar_image,
|
||||
"date_of_birth", Users.date_of_birth,
|
||||
"date_of_registration", Users.date_of_registration,
|
||||
)
|
||||
.select_from(chats_with_avatars)
|
||||
.where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
.select_from(Users)
|
||||
.where(Users.role != 100)
|
||||
)
|
||||
result = await self.session.execute(query)
|
||||
result = result.scalar_one()
|
||||
return SUsers.model_validate(result)
|
||||
|
||||
async def change_data(self, user_id: int, **data_to_change) -> None:
|
||||
stmt = update(Users).where(Users.id == user_id).values(**data_to_change)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def get_user_allowed_chats(self, user_id: int) -> SAllowedChats:
|
||||
|
||||
query = (
|
||||
select(
|
||||
func.json_build_object(
|
||||
"allowed_chats", func.json_agg(
|
||||
func.json_build_object(
|
||||
"chat_id", Chat.id,
|
||||
"chat_for", Chat.chat_for,
|
||||
"chat_name", Chat.chat_name,
|
||||
"avatar_image", Users.avatar_image,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
.select_from(UserChat)
|
||||
.join(Chat, Chat.id == UserChat.chat_id)
|
||||
.join(Users, Users.id == UserChat.user_id)
|
||||
.where(UserChat.user_id == user_id, Chat.visibility == True) # noqa: E712
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
result = result.mappings().all()
|
||||
return result
|
||||
result = result.scalar_one()
|
||||
return SAllowedChats.model_validate(result)
|
||||
|
||||
async def add_user_avatar(self, user_id: int, avatar: HttpUrl) -> bool:
|
||||
query = insert(UserAvatar).values(user_id=user_id, avatar_image=avatar)
|
||||
await self.session.execute(query)
|
||||
await self.session.commit()
|
||||
return True
|
||||
async def add_user_avatar(self, user_id: int, avatar: HttpUrl) -> None:
|
||||
stmt = insert(UserAvatar).values(user_id=user_id, avatar_image=avatar)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
async def get_user_avatars(self, user_id: int) -> SUserAvatars:
|
||||
query = select(
|
||||
|
|
|
@ -38,8 +38,6 @@ async def get_current_user(token: str = Depends(get_token), uow=Depends(UnitOfWo
|
|||
raise UserNotFoundException
|
||||
|
||||
user = await UserService.find_one_or_none(uow=uow, user_id=int(user_id))
|
||||
if not user:
|
||||
raise UserNotFoundException
|
||||
|
||||
return user
|
||||
|
||||
|
@ -69,8 +67,6 @@ async def get_current_user_ws(token: str = Depends(get_token_ws), uow=Depends(Un
|
|||
raise UserNotFoundException
|
||||
|
||||
user = await UserService.find_one_or_none(uow=uow, user_id=int(user_id))
|
||||
if not user:
|
||||
raise UserNotFoundException
|
||||
|
||||
return user
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
from app.config import settings
|
||||
from app.exceptions import UserNotFoundException
|
||||
from app.users.exceptions import (
|
||||
UserAlreadyExistsException,
|
||||
IncorrectAuthDataException,
|
||||
|
@ -58,9 +59,11 @@ async def get_all_users(uow=Depends(UnitOfWork)):
|
|||
)
|
||||
async def check_existing_user(user_filter: SUserFilter, uow=Depends(UnitOfWork)):
|
||||
async with uow:
|
||||
user = await uow.user.find_one_or_none(**user_filter.model_dump(exclude_none=True))
|
||||
if user:
|
||||
try:
|
||||
await uow.user.find_one_or_none(**user_filter.model_dump(exclude_none=True))
|
||||
raise UserAlreadyExistsException
|
||||
except UserNotFoundException:
|
||||
pass
|
||||
|
||||
|
||||
@router.post(
|
||||
|
@ -74,22 +77,23 @@ async def register_user(user_data: SUserRegister, uow=Depends(UnitOfWork)):
|
|||
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
async with uow:
|
||||
user_id = await uow.user.add(
|
||||
user = await uow.user.add(
|
||||
email=user_data.email,
|
||||
hashed_password=hashed_password,
|
||||
username=user_data.username,
|
||||
date_of_birth=user_data.date_of_birth,
|
||||
)
|
||||
await uow.user.add_user_avatar(user_id=user.id, avatar=str(user.avatar_image))
|
||||
await uow.commit()
|
||||
|
||||
user_code = generate_confirmation_code()
|
||||
user_mail_data = SConfirmationData.model_validate(
|
||||
{"user_id": user_id, "username": user_data.username, "email_to": user_data.email, "confirmation_code": user_code}
|
||||
{"user_id": user.id, "username": user_data.username, "email_to": user_data.email, "confirmation_code": user_code}
|
||||
)
|
||||
send_registration_confirmation_email.delay(user_mail_data.model_dump())
|
||||
redis_session = get_redis_session()
|
||||
await RedisService.set_verification_code(redis=redis_session, user_id=user_id, verification_code=user_code)
|
||||
access_token = create_access_token({"sub": str(user_id)})
|
||||
await RedisService.set_verification_code(redis=redis_session, user_id=user.id, verification_code=user_code)
|
||||
access_token = create_access_token({"sub": str(user.id)})
|
||||
return {"authorization": f"Bearer {access_token}"}
|
||||
|
||||
|
||||
|
|
|
@ -35,9 +35,9 @@ class SUserRegister(BaseModel):
|
|||
|
||||
|
||||
class SUserResponse(BaseModel):
|
||||
email: EmailStr
|
||||
id: int
|
||||
username: str
|
||||
email: EmailStr
|
||||
black_phoenix: bool
|
||||
avatar_image: HttpUrl
|
||||
date_of_birth: date
|
||||
|
@ -45,7 +45,7 @@ class SUserResponse(BaseModel):
|
|||
|
||||
|
||||
class SUsers(BaseModel):
|
||||
users: list[SUserResponse]
|
||||
users: list[SUserResponse] | None
|
||||
|
||||
|
||||
class SUser(BaseModel):
|
||||
|
|
|
@ -61,19 +61,25 @@ def decode_confirmation_token(invitation_token: str) -> SConfirmationData:
|
|||
class AuthService:
|
||||
@staticmethod
|
||||
async def authenticate_user_by_email(uow: UnitOfWork, email: EmailStr, password: str) -> SUser | None:
|
||||
try:
|
||||
async with uow:
|
||||
user = await uow.user.find_one_or_none(email=email)
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
return None
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return
|
||||
return user
|
||||
except UserNotFoundException:
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_user_by_username(uow: UnitOfWork, username: str, password: str) -> SUser | None:
|
||||
try:
|
||||
async with uow:
|
||||
user = await uow.user.find_one_or_none(username=username)
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
return None
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return
|
||||
return user
|
||||
except UserNotFoundException:
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def authenticate_user(cls, uow: UnitOfWork, email_or_username: str, password: str) -> SUser:
|
||||
|
@ -88,20 +94,18 @@ class AuthService:
|
|||
async def check_verificated_user(uow: UnitOfWork, user_id: int) -> bool:
|
||||
async with uow:
|
||||
user = await uow.user.find_one_or_none(id=user_id)
|
||||
if not user:
|
||||
raise UserNotFoundException
|
||||
return user.role >= settings.VERIFICATED_USER
|
||||
|
||||
@classmethod
|
||||
async def check_verificated_user_with_exc(cls, uow: UnitOfWork, user_id: int):
|
||||
async def check_verificated_user_with_exc(cls, uow: UnitOfWork, user_id: int) -> None:
|
||||
if not await cls.check_verificated_user(uow=uow, user_id=user_id):
|
||||
raise UserMustConfirmEmailException
|
||||
|
||||
@staticmethod
|
||||
async def get_user_allowed_chats_id(uow: UnitOfWork, user_id: int) -> list[int]:
|
||||
async def get_user_allowed_chats_id(uow: UnitOfWork, user_id: int) -> set[int]:
|
||||
async with uow:
|
||||
user_allowed_chats = await uow.user.get_user_allowed_chats(user_id)
|
||||
user_allowed_chats_id = [chat["chat_id"] for chat in user_allowed_chats]
|
||||
user_allowed_chats_id = {chat["chat_id"] for chat in user_allowed_chats.allowed_chats}
|
||||
return user_allowed_chats_id
|
||||
|
||||
@classmethod
|
||||
|
@ -111,10 +115,3 @@ class AuthService:
|
|||
raise UserDontHavePermissionException
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def validate_user_admin(uow: UnitOfWork, user_id: int) -> bool:
|
||||
async with uow:
|
||||
user_role = await uow.user.get_user_role(user_id=user_id)
|
||||
if user_role == settings.ADMIN_USER:
|
||||
return True
|
||||
return False
|
||||
|
|
Loading…
Add table
Reference in a new issue