Переделал дао, переместил тесты

This commit is contained in:
urec56 2024-06-01 21:38:14 +05:00
parent ccd1e209f1
commit 9ad60c2a8e
16 changed files with 54 additions and 114 deletions

View file

@ -1,6 +1,6 @@
[flake8]
max-line-length = 125
extend-ignore = W191, W391, E712
extend-ignore = W191, W391
python_version = 3.12
exclude =
.git,

View file

@ -16,8 +16,7 @@ from app.models.user_chat import UserChat
class ChatDAO(BaseDAO):
model = Chats
@staticmethod
async def create(user_id: int, chat_name: str, created_by: int) -> int:
async def create(self, user_id: int, chat_name: str, created_by: int) -> int:
query = insert(Chats).values(chat_for=user_id, chat_name=chat_name, created_by=created_by).returning(Chats.id)
async with async_session_maker() as session:
result = await session.execute(query)
@ -25,8 +24,7 @@ class ChatDAO(BaseDAO):
result = result.scalar()
return result
@staticmethod
async def add_user_to_chat(user_id: int, chat_id: int) -> bool:
async def add_user_to_chat(self, user_id: int, chat_id: int) -> bool:
query = select(UserChat.user_id).where(UserChat.chat_id == chat_id)
async with async_session_maker() as session:
result = await session.execute(query)
@ -38,8 +36,7 @@ class ChatDAO(BaseDAO):
await session.commit()
return True
@staticmethod
async def send_message(user_id: int, chat_id: int, message: str, image_url: str | None = None) -> SMessage:
async def send_message(self, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> SMessage:
inserted_image = (
insert(Message)
.values(chat_id=chat_id, user_id=user_id, message=message, image_url=image_url)
@ -74,8 +71,7 @@ class ChatDAO(BaseDAO):
result = result.mappings().one()
return SMessage.model_validate(result, from_attributes=True)
@staticmethod
async def get_message_by_id(message_id: int):
async def get_message_by_id(self, message_id: int):
query = (
select(
Message.id,
@ -93,7 +89,7 @@ class ChatDAO(BaseDAO):
.select_from(Message)
.join(Users, Users.id == Message.user_id)
.join(Answer, Answer.self_id == Message.id, isouter=True)
.where(Message.id == message_id, Message.visibility == True) #
.where(Message.id == message_id, Message.visibility == True) # noqa: E712
)
async with async_session_maker() as session:
result = await session.execute(query)
@ -101,16 +97,14 @@ class ChatDAO(BaseDAO):
if result:
return result[0]
@staticmethod
async def delete_message(message_id: int) -> bool:
async def delete_message(self, message_id: int) -> bool:
query = update(Message).where(Message.id == message_id).values(visibility=False)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
@staticmethod
async def get_some_messages(chat_id: int, message_number_from: int, messages_to_get: int) -> list[dict]:
async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> list[dict]:
"""
WITH messages_with_users AS (
SELECT *
@ -146,7 +140,7 @@ class ChatDAO(BaseDAO):
messages_with_users.c.self_id,
messages_with_users.c.answer_id,
)
.where(messages_with_users.c.chat_id == chat_id, messages_with_users.c.visibility == True)
.where(messages_with_users.c.chat_id == chat_id, messages_with_users.c.visibility == True) # noqa: E712
.order_by(messages_with_users.c.created_at.desc())
.limit(messages_to_get)
.offset(message_number_from)
@ -158,16 +152,14 @@ class ChatDAO(BaseDAO):
result = [dict(res) for res in result]
return result
@staticmethod
async def edit_message(message_id: int, new_message: str, new_image_url: str) -> bool:
async def edit_message(self, message_id: int, new_message: str, new_image_url: str) -> bool:
query = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
@staticmethod
async def add_answer(self_id: int, answer_id: int) -> SMessage:
async def add_answer(self, self_id: int, answer_id: int) -> SMessage:
answer = (
insert(Answer)
.values(self_id=self_id, answer_id=answer_id)
@ -201,24 +193,21 @@ class ChatDAO(BaseDAO):
result = result.mappings().one()
return SMessage.model_validate(result)
@staticmethod
async def delete_chat(chat_id: int) -> bool:
async def delete_chat(self, chat_id: int) -> bool:
query = update(Chats).where(Chats.id == chat_id).values(visibility=False)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
@staticmethod
async def delete_user(chat_id: int, user_id: int) -> bool:
async def delete_user(self, chat_id: int, user_id: int) -> bool:
query = delete(UserChat).where(UserChat.chat_id == chat_id, UserChat.user_id == user_id)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
@staticmethod
async def pinn_chat(chat_id: int, user_id: int) -> bool:
async def pin_chat(self, chat_id: int, user_id: int) -> bool:
query = select(PinnedChats.chat_id).where(PinnedChats.user_id == user_id)
async with async_session_maker() as session:
result = await session.execute(query)
@ -230,16 +219,14 @@ class ChatDAO(BaseDAO):
await session.commit()
return True
@staticmethod
async def unpinn_chat(chat_id: int, user_id: int) -> bool:
async def unpin_chat(self, chat_id: int, user_id: int) -> bool:
query = delete(PinnedChats).where(PinnedChats.chat_id == chat_id, PinnedChats.user_id == user_id)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
@staticmethod
async def get_pinned_chats(user_id: int):
async def get_pinned_chats(self, user_id: int):
chats_with_descriptions = (
select(UserChat.__table__.columns, Chats.__table__.columns)
.select_from(UserChat)
@ -273,7 +260,7 @@ class ChatDAO(BaseDAO):
.distinct()
.select_from(PinnedChats)
.join(chats_with_avatars, PinnedChats.chat_id == chats_with_avatars.c.chat_id)
.where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True)
.where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712
)
# print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса
async with async_session_maker() as session:
@ -281,34 +268,21 @@ class ChatDAO(BaseDAO):
result = result.mappings().all()
return result
@staticmethod
async def pinn_message(chat_id: int, message_id: int, user_id: int) -> bool:
async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> bool:
query = insert(PinnedMessages).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
@staticmethod
async def get_message_pinner(chat_id: int, message_id: int) -> bool:
query = select(PinnedMessages.user_id).where(
PinnedMessages.chat_id == chat_id, PinnedMessages.message_id == message_id
)
async with async_session_maker() as session:
result = await session.execute(query)
result = result.scalar()
return result
@staticmethod
async def unpinn_message(chat_id: int, message_id: int) -> bool:
async def unpin_message(self, chat_id: int, message_id: int) -> bool:
query = delete(PinnedMessages).where(PinnedMessages.chat_id == chat_id, PinnedMessages.message_id == message_id)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
@staticmethod
async def get_pinned_messages(chat_id: int) -> list[dict]:
async def get_pinned_messages(self, chat_id: int) -> list[dict]:
query = (
select(
Message.id,
@ -327,7 +301,7 @@ class ChatDAO(BaseDAO):
.join(Message, PinnedMessages.message_id == Message.id, isouter=True)
.join(Users, PinnedMessages.user_id == Users.id, isouter=True)
.join(Answer, Answer.self_id == Message.id, isouter=True)
.where(PinnedMessages.chat_id == chat_id, Message.visibility == True)
.where(PinnedMessages.chat_id == chat_id, Message.visibility == True) # noqa: E712
.order_by(Message.created_at.desc())
)
async with async_session_maker() as session:

View file

@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, status
from app.config import settings
from app.exceptions import UserDontHavePermissionException, MessageNotFoundException, UserCanNotReadThisChatException
from app.chat.dao import ChatDAO
from app.chat.shemas import SMessage, SLastMessages, SPinnedMessage, SPinnedChat, SDeletedUser, SChat, SDeletedChat
from app.chat.shemas import SMessage, SLastMessages, SPinnedChat, SDeletedUser, SChat, SDeletedChat
from app.users.dao import UserDAO
from app.users.dependencies import check_verificated_user_with_exc
@ -181,7 +181,7 @@ async def delete_user_from_chat(chat_id: int, user_id: int, user: SUser = Depend
)
async def pinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
await ChatDAO.pinn_chat(chat_id=chat_id, user_id=user.id)
await ChatDAO.pin_chat(chat_id=chat_id, user_id=user.id)
return {"chat_id": chat_id, "user_id": user.id}
@ -192,7 +192,7 @@ async def pinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_w
)
async def unpinn_chat(chat_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
await ChatDAO.unpinn_chat(chat_id=chat_id, user_id=user.id)
await ChatDAO.unpin_chat(chat_id=chat_id, user_id=user.id)
return {"chat_id": chat_id, "user_id": user.id}
@ -205,31 +205,6 @@ async def get_pinned_chats(user: SUser = Depends(check_verificated_user_with_exc
return await ChatDAO.get_pinned_chats(user_id=user.id)
@router.post(
"/pin_message",
status_code=status.HTTP_200_OK,
response_model=SPinnedMessage
)
async def pinn_message(chat_id: int, message_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
await ChatDAO.pinn_message(chat_id=chat_id, message_id=message_id, user_id=user.id)
return {"message_id": message_id, "user_id": user.id, "chat_id": chat_id}
@router.delete(
"/unpin_message",
status_code=status.HTTP_200_OK,
response_model=SPinnedMessage
)
async def unpinn_message(chat_id: int, message_id: int, user: SUser = Depends(check_verificated_user_with_exc)):
await AuthService.validate_user_access_to_chat(chat_id=chat_id, user_id=user.id)
message_pinner = await ChatDAO.get_message_pinner(chat_id=chat_id, message_id=message_id)
if message_pinner == user.id:
await ChatDAO.unpinn_message(chat_id=chat_id, message_id=message_id)
return {"message_id": message_id, "user_id": user.id, "chat_id": chat_id}
raise UserDontHavePermissionException
@router.get(
"/pinned_messages/{chat_id}",
status_code=status.HTTP_200_OK,

View file

@ -1,4 +1,5 @@
from sqlalchemy import select, insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import async_session_maker
@ -6,24 +7,24 @@ from app.database import async_session_maker
class BaseDAO:
model = None
@classmethod
async def add(cls, **data): # Метод добавляет данные в БД
def __init__(self, session: AsyncSession):
self.session = session
async def add(self, **data): # Метод добавляет данные в БД
async with async_session_maker() as session:
query = insert(cls.model).values(**data).returning(cls.model.id)
query = insert(self.model).values(**data).returning(self.model.id)
result = await session.execute(query)
await session.commit()
return result.scalar()
@classmethod
async def find_one_or_none(cls, **filter_by): # Метод проверяет наличие строки с заданными параметрами
async def find_one_or_none(self, **filter_by): # Метод проверяет наличие строки с заданными параметрами
async with async_session_maker() as session:
query = select(cls.model).filter_by(**filter_by)
query = select(self.model).filter_by(**filter_by)
result = await session.execute(query)
return result.scalar_one_or_none()
@classmethod
async def find_all(cls, **filter_by): # Метод возвращает все строки таблицы или те, которые соответствуют отбору
async def find_all(self, **filter_by): # Метод возвращает все строки таблицы или те, которые соответствуют отбору
async with async_session_maker() as session:
query = select(cls.model.__table__.columns).filter_by(**filter_by)
query = select(self.model.__table__.columns).filter_by(**filter_by)
result = await session.execute(query)
return result.mappings().all()

View file

@ -17,7 +17,6 @@ from app.images.router import router as image_router
app = FastAPI(title="Чат BP", root_path="/api")
app.include_router(websocket_router)
# app.include_router(chat_router)
app.include_router(user_router)
app.include_router(pages_router)
app.include_router(image_router)

View file

@ -24,9 +24,11 @@ class MessageService:
return new_message
@staticmethod
async def pin_message(chat_id: int, user_id: int, message_id: int):
pass
async def pin_message(chat_id: int, user_id: int, message_id: int) -> SMessage:
pinned_message = await ChatDAO.pin_message(chat_id=chat_id, message_id=message_id, user_id=user_id)
return pinned_message
@staticmethod
async def unpin_message(chat_id: int, message_id: int):
pass
async def unpin_message(chat_id: int, message_id: int) -> int:
unpinned_message = await ChatDAO.unpin_message(chat_id=chat_id, message_id=message_id)
return unpinned_message

View file

@ -13,8 +13,7 @@ from app.users.schemas import SUser, SUserAvatars
class UserDAO(BaseDAO):
model = Users
@staticmethod
async def find_one_or_none(**filter_by) -> SUser | None:
async def find_one_or_none(self, **filter_by) -> SUser | None:
async with async_session_maker() as session:
query = select(Users).filter_by(**filter_by)
result = await session.execute(query)
@ -22,15 +21,13 @@ class UserDAO(BaseDAO):
if result:
return SUser.model_validate(result, from_attributes=True)
@staticmethod
async def find_all(**filter_by):
async def find_all(self, **filter_by):
async with async_session_maker() as session:
query = select(Users.__table__.columns).filter_by(**filter_by).where(Users.role != 100)
result = await session.execute(query)
return result.mappings().all()
@staticmethod
async def change_data(user_id: int, **data_to_change) -> str:
async def change_data(self, user_id: int, **data_to_change) -> str:
query = update(Users).where(Users.id == user_id).values(**data_to_change)
async with async_session_maker() as session:
await session.execute(query)
@ -39,15 +36,13 @@ class UserDAO(BaseDAO):
result = await session.execute(query)
return result.scalar()
@staticmethod
async def get_user_role(user_id: int) -> int:
async def get_user_role(self, user_id: int) -> int:
query = select(Users.role).where(Users.id == user_id)
async with async_session_maker() as session:
result = await session.execute(query)
return result.scalar()
@staticmethod
async def get_user_allowed_chats(user_id: int):
async def get_user_allowed_chats(self, user_id: int):
"""
WITH chats_with_descriptions AS (
SELECT *
@ -94,7 +89,7 @@ class UserDAO(BaseDAO):
chats_with_avatars.c.avatar_hex,
)
.select_from(chats_with_avatars)
.where(and_(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True))
.where(and_(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True)) # noqa: E712
)
async with async_session_maker() as session:
@ -102,23 +97,20 @@ class UserDAO(BaseDAO):
result = result.mappings().all()
return result
@staticmethod
async def get_user_avatar(user_id: int) -> str:
async def get_user_avatar(self, user_id: int) -> str:
query = select(Users.avatar_image).where(Users.id == user_id)
async with async_session_maker() as session:
result = await session.execute(query)
return result.scalar()
@staticmethod
async def add_user_avatar(user_id: int, avatar: str) -> bool:
async def add_user_avatar(self, user_id: int, avatar: str) -> bool:
query = insert(UserAvatar).values(user_id=user_id, avatar_image=avatar)
async with async_session_maker() as session:
await session.execute(query)
await session.commit()
return True
@staticmethod
async def get_user_avatars(user_id: int) -> SUserAvatars:
async def get_user_avatars(self, user_id: int) -> SUserAvatars:
query = select(
func.json_build_object(
"user_avatars", select(
@ -137,8 +129,7 @@ class UserDAO(BaseDAO):
result = result.scalar()
return SUserAvatars.model_validate(result)
@staticmethod
async def delete_user_avatar(avatar_id: int, user_id: int) -> bool:
async def delete_user_avatar(self, avatar_id: int, user_id: int) -> bool:
query = delete(UserAvatar).where(and_(UserAvatar.id == avatar_id, UserAvatar.user_id == user_id))
async with async_session_maker() as session:
await session.execute(query)
@ -149,8 +140,7 @@ class UserDAO(BaseDAO):
class UserCodesDAO(BaseDAO):
model = UserVerificationCode
@classmethod
async def set_user_codes(cls, user_id: int, code: str, description: str):
async def set_user_codes(self, cls, user_id: int, code: str, description: str):
query = (
insert(UserVerificationCode)
.values(user_id=user_id, code=code, description=description)
@ -161,8 +151,7 @@ class UserCodesDAO(BaseDAO):
await session.commit()
return result.scalar()
@staticmethod
async def get_user_codes(**filter_by) -> list[dict | None]:
async def get_user_codes(self, **filter_by) -> list[dict | None]:
"""
SELECT
usersverificationcodes.id,

View file

@ -42,7 +42,7 @@ target-version = "py312"
[lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
select = ["F", "E", "W", "C"]
ignore = ["W191", "W391", "C901", "E712"]
ignore = ["W191", "W391", "C901"]
# Allow fix for all enabled rules (when `--fix`) is provided.
fixable = ["ALL"]

View file

@ -62,7 +62,7 @@ async def test_get_user(ac: AsyncClient):
response = await ac.get("/users/me")
assert response.status_code == 200
assert response.json()["email"] == "urec@urec.com"
assert response.json()["black_phoenix"] == False
assert response.json()["black_phoenix"] == False # noqa: E712
@pytest.mark.parametrize(