Переделал бд

This commit is contained in:
urec56 2024-06-05 13:05:30 +05:00
parent 8cbd025395
commit 4deefeab4e
31 changed files with 294 additions and 471 deletions

View file

@ -30,3 +30,8 @@ INVITATION_LINK_HOST=
INVITATION_LINK_TOKEN_KEY= INVITATION_LINK_TOKEN_KEY=
SENTRY_DSN= SENTRY_DSN=
ADMIN_USER=
ADMIN_USER_ID=
REGISTRATED_USER=
VERIFICATED_USER=

View file

@ -1,5 +1,3 @@
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, status from fastapi import APIRouter, Depends, status
from app.config import settings from app.config import settings
@ -16,10 +14,9 @@ from app.chat.shemas import (
SPinnedMessages SPinnedMessages
) )
from app.unit_of_work import UnitOfWork from app.unit_of_work import UnitOfWork
from app.users.dependencies import check_verificated_user_with_exc from app.users.dependencies import check_verificated_user_with_exc
from app.users.auth import ADMIN_USER_ID, AuthService from app.utils.auth import AuthService, encode_invitation_token, decode_invitation_token
from app.users.schemas import SCreateInvitationLink, SUserAddedToChat, SUser from app.users.schemas import SCreateInvitationLink, SUserAddedToChat, SUser, SInvitationData
router = APIRouter(prefix="/chat", tags=["Чат"]) router = APIRouter(prefix="/chat", tags=["Чат"])
@ -51,7 +48,7 @@ async def create_chat(
raise UserCanNotReadThisChatException raise UserCanNotReadThisChatException
chat_id = await uow.chat.create(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id) chat_id = await uow.chat.create(user_id=user_to_exclude, chat_name=chat_name, created_by=user.id)
user_added_to_chat = await uow.chat.add_user_to_chat(user.id, chat_id) user_added_to_chat = await uow.chat.add_user_to_chat(user.id, chat_id)
await uow.chat.add_user_to_chat(ADMIN_USER_ID, chat_id) await uow.chat.add_user_to_chat(settings.ADMIN_USER_ID, chat_id)
return user_added_to_chat return user_added_to_chat
@ -128,9 +125,9 @@ async def create_invitation_link(
uow=Depends(UnitOfWork) uow=Depends(UnitOfWork)
): ):
await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id) await AuthService.validate_user_access_to_chat(uow=uow, chat_id=chat_id, user_id=user.id)
cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY) invitation_data = SInvitationData.model_validate({"chat_id": chat_id})
invitation_token = cipher_suite.encrypt(str(chat_id).encode()) invitation_token = encode_invitation_token(invitation_data)
invitation_link = settings.INVITATION_LINK_HOST + "/api/chat/invite_to_chat/" + str(invitation_token).split("'")[1] invitation_link = settings.INVITATION_LINK_HOST + "/api/chat/invite_to_chat/" + invitation_token
return {"invitation_link": invitation_link} return {"invitation_link": invitation_link}
@ -144,14 +141,12 @@ async def invite_to_chat(
user: SUser = Depends(check_verificated_user_with_exc), user: SUser = Depends(check_verificated_user_with_exc),
uow=Depends(UnitOfWork) uow=Depends(UnitOfWork)
): ):
invitation_data = decode_invitation_token(invitation_token)
async with uow: async with uow:
invitation_token = invitation_token.encode() chat = await uow.chat.find_one_or_none(id=invitation_data.chat_id)
cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY)
chat_id = int(cipher_suite.decrypt(invitation_token))
chat = await uow.chat.find_one_or_none(id=chat_id)
if user.id == chat.chat_for: if user.id == chat.chat_for:
raise UserCanNotReadThisChatException raise UserCanNotReadThisChatException
return {"user_added_to_chat": await uow.chat.add_user_to_chat(chat_id=chat_id, user_id=user.id)} return {"user_added_to_chat": await uow.chat.add_user_to_chat(chat_id=invitation_data.chat_id, user_id=user.id)}
@router.delete( @router.delete(

View file

@ -3,7 +3,7 @@ from fastapi import WebSocket, WebSocketDisconnect, Depends
from app.exceptions import IncorrectDataException, UserDontHavePermissionException from app.exceptions import IncorrectDataException, UserDontHavePermissionException
from app.services.message_service import MessageService from app.services.message_service import MessageService
from app.unit_of_work import UnitOfWork from app.unit_of_work import UnitOfWork
from app.users.auth import AuthService from app.utils.auth import AuthService
from app.chat.router import router from app.chat.router import router
from app.chat.shemas import SSendMessage, SMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage from app.chat.shemas import SSendMessage, SMessage, SDeleteMessage, SEditMessage, SPinMessage, SUnpinMessage
from app.users.dependencies import get_current_user_ws from app.users.dependencies import get_current_user_ws

View file

@ -39,5 +39,10 @@ class Settings(BaseSettings):
SENTRY_DSN: str SENTRY_DSN: str
ADMIN_USER: int
ADMIN_USER_ID: int
REGISTRATED_USER: int
VERIFICATED_USER: int
settings = Settings() settings = Settings()

View file

@ -1,3 +0,0 @@
import os
os.environ["MODE"] = "TEST"

View file

@ -5,21 +5,26 @@ from app.database import engine # noqa
from app.exceptions import UserAlreadyInChatException, UserAlreadyPinnedChatException from app.exceptions import UserAlreadyInChatException, UserAlreadyPinnedChatException
from app.chat.shemas import SMessage from app.chat.shemas import SMessage
from app.models.users import Users from app.models.users import Users
from app.models.answer import Answer from app.models.message_answer import MessageAnswer
from app.models.chat import Chats from app.models.chat import Chats
from app.models.message import Message from app.models.message import Message
from app.models.pinned_chat import PinnedChats from app.models.pinned_chat import PinnedChat
from app.models.pinned_message import PinnedMessages from app.models.pinned_message import PinnedMessage
from app.models.user_chat import UserChat from app.models.user_chat import UserChat
class ChatDAO(BaseDAO): class ChatDAO(BaseDAO):
model = Chats model = Chats
async def create(self, user_id: int, chat_name: str, created_by: int) -> int: async def find_one_or_none(self, **filter_by):
query = insert(Chats).values(chat_for=user_id, chat_name=chat_name, created_by=created_by).returning(Chats.id) query = select(Chats.__table__.columns).filter_by(**filter_by)
result = await self.session.execute(query) result = await self.session.execute(query)
result = result.scalar_one_or_none()
async def create(self, user_id: int, chat_name: str, created_by: int) -> int:
stmt = insert(Chats).values(chat_for=user_id, chat_name=chat_name, created_by=created_by).returning(Chats.id)
result = await self.session.execute(stmt)
await self.session.commit() await self.session.commit()
result = result.scalar() result = result.scalar()
return result return result
@ -30,8 +35,8 @@ class ChatDAO(BaseDAO):
result = result.scalars().all() result = result.scalars().all()
if user_id in result: if user_id in result:
raise UserAlreadyInChatException raise UserAlreadyInChatException
query = insert(UserChat).values(user_id=user_id, chat_id=chat_id) stmt = insert(UserChat).values(user_id=user_id, chat_id=chat_id)
await self.session.execute(query) await self.session.execute(stmt)
await self.session.commit() await self.session.commit()
return True return True
@ -56,12 +61,12 @@ class ChatDAO(BaseDAO):
Users.avatar_image, Users.avatar_image,
Users.username, Users.username,
Users.avatar_hex, Users.avatar_hex,
Answer.self_id, MessageAnswer.self_id,
Answer.answer_id, MessageAnswer.answer_id,
) )
.select_from(inserted_image) .select_from(inserted_image)
.join(Users, Users.id == inserted_image.c.user_id) .join(Users, Users.id == inserted_image.c.user_id)
.join(Answer, Answer.self_id == inserted_image.c.id, isouter=True) .join(MessageAnswer, MessageAnswer.self_id == inserted_image.c.id, isouter=True)
) )
result = await self.session.execute(query) result = await self.session.execute(query)
@ -81,12 +86,12 @@ class ChatDAO(BaseDAO):
Users.avatar_image, Users.avatar_image,
Users.username, Users.username,
Users.avatar_hex, Users.avatar_hex,
Answer.self_id, MessageAnswer.self_id,
Answer.answer_id, MessageAnswer.answer_id,
) )
.select_from(Message) .select_from(Message)
.join(Users, Users.id == Message.user_id) .join(Users, Users.id == Message.user_id)
.join(Answer, Answer.self_id == Message.id, isouter=True) .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
.where(Message.id == message_id, Message.visibility == True) # noqa: E712 .where(Message.id == message_id, Message.visibility == True) # noqa: E712
) )
result = await self.session.execute(query) result = await self.session.execute(query)
@ -115,10 +120,10 @@ class ChatDAO(BaseDAO):
LIMIT 15 OFFSET 0; LIMIT 15 OFFSET 0;
""" """
messages_with_users = ( messages_with_users = (
select(Message.__table__.columns, Users.__table__.columns, Answer.__table__.columns) select(Message.__table__.columns, Users.__table__.columns, MessageAnswer.__table__.columns)
.select_from(Message) .select_from(Message)
.join(Users, Message.user_id == Users.id) .join(Users, Message.user_id == Users.id)
.join(Answer, Answer.self_id == Message.id, isouter=True) .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
.cte("messages_with_users") .cte("messages_with_users")
) )
@ -155,9 +160,9 @@ class ChatDAO(BaseDAO):
async def add_answer(self, self_id: int, answer_id: int) -> SMessage: async def add_answer(self, self_id: int, answer_id: int) -> SMessage:
answer = ( answer = (
insert(Answer) insert(MessageAnswer)
.values(self_id=self_id, answer_id=answer_id) .values(self_id=self_id, answer_id=answer_id)
.returning(Answer.self_id, Answer.answer_id) .returning(MessageAnswer.self_id, MessageAnswer.answer_id)
.cte("answer") .cte("answer")
) )
@ -199,18 +204,18 @@ class ChatDAO(BaseDAO):
return True return True
async def pin_chat(self, chat_id: int, user_id: int) -> bool: async def pin_chat(self, chat_id: int, user_id: int) -> bool:
query = select(PinnedChats.chat_id).where(PinnedChats.user_id == user_id) query = select(PinnedChat.chat_id).where(PinnedChat.user_id == user_id)
result = await self.session.execute(query) result = await self.session.execute(query)
result = result.scalars().all() result = result.scalars().all()
if chat_id in result: if chat_id in result:
raise UserAlreadyPinnedChatException raise UserAlreadyPinnedChatException
query = insert(PinnedChats).values(chat_id=chat_id, user_id=user_id) stmt = insert(PinnedChat).values(chat_id=chat_id, user_id=user_id)
await self.session.execute(query) await self.session.execute(stmt)
await self.session.commit() await self.session.commit()
return True return True
async def unpin_chat(self, chat_id: int, user_id: int) -> bool: async def unpin_chat(self, chat_id: int, user_id: int) -> bool:
query = delete(PinnedChats).where(PinnedChats.chat_id == chat_id, PinnedChats.user_id == user_id) query = delete(PinnedChat).where(PinnedChat.chat_id == chat_id, PinnedChat.user_id == user_id)
await self.session.execute(query) await self.session.execute(query)
await self.session.commit() await self.session.commit()
return True return True
@ -247,8 +252,8 @@ class ChatDAO(BaseDAO):
chats_with_avatars.c.avatar_hex, chats_with_avatars.c.avatar_hex,
) )
.distinct() .distinct()
.select_from(PinnedChats) .select_from(PinnedChat)
.join(chats_with_avatars, PinnedChats.chat_id == chats_with_avatars.c.chat_id) .join(chats_with_avatars, PinnedChat.chat_id == chats_with_avatars.c.chat_id)
.where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712 .where(chats_with_avatars.c.id == user_id, chats_with_avatars.c.visibility == True) # noqa: E712
) )
# print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса # print(query.compile(engine, compile_kwargs={"literal_binds": True})) # Проверка SQL запроса
@ -257,12 +262,12 @@ class ChatDAO(BaseDAO):
return result return result
async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> bool: async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> bool:
query = insert(PinnedMessages).values(chat_id=chat_id, message_id=message_id, user_id=user_id) query = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
await self.session.execute(query) await self.session.execute(query)
return True return True
async def unpin_message(self, chat_id: int, message_id: int) -> bool: async def unpin_message(self, chat_id: int, message_id: int) -> bool:
query = delete(PinnedMessages).where(PinnedMessages.chat_id == chat_id, PinnedMessages.message_id == message_id) query = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id)
await self.session.execute(query) await self.session.execute(query)
return True return True
@ -278,14 +283,14 @@ class ChatDAO(BaseDAO):
Users.avatar_image, Users.avatar_image,
Users.username, Users.username,
Users.avatar_hex, Users.avatar_hex,
Answer.self_id, MessageAnswer.self_id,
Answer.answer_id, MessageAnswer.answer_id,
) )
.select_from(PinnedMessages) .select_from(PinnedMessage)
.join(Message, PinnedMessages.message_id == Message.id, isouter=True) .join(Message, PinnedMessage.message_id == Message.id, isouter=True)
.join(Users, PinnedMessages.user_id == Users.id, isouter=True) .join(Users, PinnedMessage.user_id == Users.id, isouter=True)
.join(Answer, Answer.self_id == Message.id, isouter=True) .join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
.where(PinnedMessages.chat_id == chat_id, Message.visibility == True) # noqa: E712 .where(PinnedMessage.chat_id == chat_id, Message.visibility == True) # noqa: E712
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
) )

View file

@ -22,13 +22,14 @@ app.include_router(pages_router)
app.include_router(image_router) app.include_router(image_router)
origins = ["http://localhost:5173"] origins = ["http://localhost:5173"]
headers = ["Content-Type", "Set-Cookie", "Access-Control-Allow-Headers", "Authorization", "Accept", "Access-Control-Allow-Origin"]
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"], allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
allow_headers=["*"], allow_headers=headers,
) )
app.mount("/static", StaticFiles(directory="app/static"), name="static") app.mount("/static", StaticFiles(directory="app/static"), name="static")

View file

@ -6,14 +6,13 @@ from alembic import context
from app.database import DATABASE_URL, Base from app.database import DATABASE_URL, Base
from app.models.users import Users # noqa from app.models.users import Users # noqa
from app.models.answer import Answer # noqa from app.models.message_answer import MessageAnswer # noqa
from app.models.chat import Chats # noqa from app.models.chat import Chats # noqa
from app.models.message import Message # noqa from app.models.message import Message # noqa
from app.models.pinned_chat import PinnedChats # noqa from app.models.pinned_chat import PinnedChat # noqa
from app.models.pinned_message import PinnedMessages # noqa from app.models.pinned_message import PinnedMessage # noqa
from app.models.user_chat import UserChat # noqa from app.models.user_chat import UserChat # noqa
from app.models.user_avatar import UserAvatar # noqa from app.models.user_avatar import UserAvatar # noqa
from app.models.user_verification_code import UserVerificationCode # noqa
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.

View file

@ -1,181 +0,0 @@
"""Database Creation
Revision ID: 66b93ccf9063
Revises:
Create Date: 2024-05-06 17:29:33.815613
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "66b93ccf9063"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"users",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("email", sa.String(), nullable=False),
sa.Column("username", sa.String(), nullable=False),
sa.Column("hashed_password", sa.String(), nullable=False),
sa.Column("role", sa.Integer(), server_default="0", nullable=False),
sa.Column("black_phoenix", sa.Boolean(), server_default="false", nullable=False),
sa.Column(
"avatar_image",
sa.String(),
server_default="https://images.black-phoenix.ru/static/images/%D1%82%D1%8B%20%D1%83%D0%B6%D0%B5%20%D0%BF%D0%B5%D1%88%D0%BA%D0%B0%20BP.png",
nullable=True,
),
sa.Column("avatar_hex", sa.String(), server_default="#30293f", nullable=True),
sa.Column("date_of_birth", sa.Date(), nullable=False),
sa.Column("date_of_registration", sa.Date(), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"chats",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_by", sa.Integer(), nullable=False),
sa.Column("chat_for", sa.Integer(), nullable=True),
sa.Column("chat_name", sa.String(), nullable=False),
sa.Column("visibility", sa.Boolean(), server_default="true", nullable=False),
sa.ForeignKeyConstraint(
["chat_for"],
["users.id"],
),
sa.ForeignKeyConstraint(
["created_by"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"usersavatars",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("avatar_image", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"usersverificationcodes",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("code", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=False),
sa.Column("date_of_creation", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"messages",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("message", sa.String(), nullable=True),
sa.Column("image_url", sa.String(), nullable=True),
sa.Column("chat_id", sa.Integer(), nullable=True),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("visibility", sa.Boolean(), server_default="true", nullable=False),
sa.ForeignKeyConstraint(
["chat_id"],
["chats.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"pinnedchats",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("chat_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["chat_id"],
["chats.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"usersxchats",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("chat_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["chat_id"],
["chats.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"answers",
sa.Column("self_id", sa.Integer(), nullable=False),
sa.Column("answer_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["answer_id"],
["messages.id"],
),
sa.ForeignKeyConstraint(
["self_id"],
["messages.id"],
),
sa.PrimaryKeyConstraint("self_id"),
)
op.create_table(
"pinnedmessages",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("chat_id", sa.Integer(), nullable=True),
sa.Column("message_id", sa.Integer(), nullable=True),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["chat_id"],
["chats.id"],
),
sa.ForeignKeyConstraint(
["message_id"],
["messages.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("pinnedmessages")
op.drop_table("answers")
op.drop_table("usersxchats")
op.drop_table("pinnedchats")
op.drop_table("messages")
op.drop_table("usersverificationcodes")
op.drop_table("usersavatars")
op.drop_table("chats")
op.drop_table("users")
# ### end Alembic commands ###

View file

@ -0,0 +1,112 @@
"""Database Creation
Revision ID: 00acc3992d64
Revises:
Create Date: 2024-06-05 12:56:38.627620
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '00acc3992d64'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('users',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('email', sa.String(), nullable=False),
sa.Column('username', sa.String(), nullable=False),
sa.Column('hashed_password', sa.String(), nullable=False),
sa.Column('role', sa.Integer(), server_default='0', nullable=False),
sa.Column('black_phoenix', sa.Boolean(), server_default='false', nullable=False),
sa.Column('avatar_image', sa.String(), server_default='https://images.black-phoenix.ru/static/images/%D1%82%D1%8B%20%D1%83%D0%B6%D0%B5%20%D0%BF%D0%B5%D1%88%D0%BA%D0%B0%20BP.png', nullable=False),
sa.Column('date_of_birth', sa.Date(), nullable=False),
sa.Column('date_of_registration', sa.Date(), server_default=sa.text('now()'), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('email'),
sa.UniqueConstraint('username')
)
op.create_table('chat',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_by', sa.Integer(), nullable=False),
sa.Column('chat_for', sa.Integer(), nullable=True),
sa.Column('chat_name', sa.String(), nullable=False),
sa.Column('visibility', sa.Boolean(), server_default='true', nullable=False),
sa.ForeignKeyConstraint(['chat_for'], ['users.id'], ),
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('user_avatar',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('avatar_image', sa.String(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('message',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('message', sa.String(), nullable=True),
sa.Column('image_url', sa.String(), nullable=True),
sa.Column('chat_id', sa.Integer(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('visibility', sa.Boolean(), server_default='true', nullable=False),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('pinned_chat',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('chat_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('user_chat',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('chat_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('message_answer',
sa.Column('self_id', sa.Integer(), nullable=False),
sa.Column('answer_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['answer_id'], ['message.id'], ),
sa.ForeignKeyConstraint(['self_id'], ['message.id'], ),
sa.PrimaryKeyConstraint('self_id')
)
op.create_table('pinned_message',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('chat_id', sa.Integer(), nullable=True),
sa.Column('message_id', sa.Integer(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['message_id'], ['message.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('pinned_message')
op.drop_table('message_answer')
op.drop_table('user_chat')
op.drop_table('pinned_chat')
op.drop_table('message')
op.drop_table('user_avatar')
op.drop_table('chat')
op.drop_table('users')
# ### end Alembic commands ###

View file

@ -1,11 +0,0 @@
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class Answer(Base):
__tablename__ = "answers"
self_id: Mapped[int] = mapped_column(ForeignKey("messages.id"), primary_key=True)
answer_id: Mapped[int] = mapped_column(ForeignKey("messages.id"))

View file

@ -5,7 +5,7 @@ from app.database import Base
class Chats(Base): class Chats(Base):
__tablename__ = "chats" __tablename__ = "chat"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
created_by: Mapped[int] = mapped_column(ForeignKey("users.id")) created_by: Mapped[int] = mapped_column(ForeignKey("users.id"))

View file

@ -7,10 +7,10 @@ from app.database import Base
class Message(Base): class Message(Base):
__tablename__ = "messages" __tablename__ = "message"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
chat_id = mapped_column(ForeignKey("chats.id")) chat_id = mapped_column(ForeignKey("chat.id"))
user_id = mapped_column(ForeignKey("users.id")) user_id = mapped_column(ForeignKey("users.id"))
message: Mapped[str | None] message: Mapped[str | None]
image_url: Mapped[str | None] image_url: Mapped[str | None]

View file

@ -0,0 +1,11 @@
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class MessageAnswer(Base):
__tablename__ = "message_answer"
self_id: Mapped[int] = mapped_column(ForeignKey("message.id"), primary_key=True)
answer_id: Mapped[int] = mapped_column(ForeignKey("message.id"))

View file

@ -4,9 +4,9 @@ from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base from app.database import Base
class PinnedChats(Base): class PinnedChat(Base):
__tablename__ = "pinnedchats" __tablename__ = "pinned_chat"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
user_id = mapped_column(ForeignKey("users.id")) user_id = mapped_column(ForeignKey("users.id"))
chat_id = mapped_column(ForeignKey("chats.id")) chat_id = mapped_column(ForeignKey("chat.id"))

View file

@ -4,10 +4,10 @@ from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base from app.database import Base
class PinnedMessages(Base): class PinnedMessage(Base):
__tablename__ = "pinnedmessages" __tablename__ = "pinned_message"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
chat_id = mapped_column(ForeignKey("chats.id")) chat_id = mapped_column(ForeignKey("chat.id"))
message_id = mapped_column(ForeignKey("messages.id")) message_id = mapped_column(ForeignKey("message.id"))
user_id = mapped_column(ForeignKey("users.id")) user_id = mapped_column(ForeignKey("users.id"))

View file

@ -5,7 +5,7 @@ from app.database import Base
class UserAvatar(Base): class UserAvatar(Base):
__tablename__ = "usersavatars" __tablename__ = "user_avatar"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id")) user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))

View file

@ -5,8 +5,8 @@ from app.database import Base
class UserChat(Base): class UserChat(Base):
__tablename__ = "usersxchats" __tablename__ = "user_chat"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
user_id = mapped_column(ForeignKey("users.id")) user_id = mapped_column(ForeignKey("users.id"))
chat_id = mapped_column(ForeignKey("chats.id")) chat_id = mapped_column(ForeignKey("chat.id"))

View file

@ -1,16 +0,0 @@
from datetime import datetime
from sqlalchemy import func, ForeignKey, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class UserVerificationCode(Base):
__tablename__ = "usersverificationcodes"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
code: Mapped[str]
description: Mapped[str]
date_of_creation: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())

View file

@ -16,6 +16,5 @@ class Users(Base):
role: Mapped[int] = mapped_column(server_default="0") role: Mapped[int] = mapped_column(server_default="0")
black_phoenix: Mapped[bool] = mapped_column(server_default="false") black_phoenix: Mapped[bool] = mapped_column(server_default="false")
avatar_image: Mapped[str] = mapped_column(server_default="https://images.black-phoenix.ru/static/images/%D1%82%D1%8B%20%D1%83%D0%B6%D0%B5%20%D0%BF%D0%B5%D1%88%D0%BA%D0%B0%20BP.png") # noqa: E501 avatar_image: Mapped[str] = mapped_column(server_default="https://images.black-phoenix.ru/static/images/%D1%82%D1%8B%20%D1%83%D0%B6%D0%B5%20%D0%BF%D0%B5%D1%88%D0%BA%D0%B0%20BP.png") # noqa: E501
avatar_hex: Mapped[str] = mapped_column(server_default="#30293f")
date_of_birth: Mapped[date] date_of_birth: Mapped[date]
date_of_registration: Mapped[date] = mapped_column(server_default=func.now()) date_of_registration: Mapped[date] = mapped_column(server_default=func.now())

View file

@ -1,17 +0,0 @@
from fastapi import APIRouter, Request
from fastapi.templating import Jinja2Templates
router = APIRouter(prefix="/pages", tags=["Страницы"])
templates = Jinja2Templates(directory="app/templates")
@router.get("/base")
async def base(request: Request):
return templates.TemplateResponse("base.html", {"request": request})
@router.get("/chat")
async def get_chat_page(request: Request):
return templates.TemplateResponse("chat.html", {"request": request})

View file

@ -12,8 +12,8 @@ from app.tasks.email_templates import (
create_password_change_confirmation_template, create_password_change_confirmation_template,
create_password_recover_template, create_password_recover_template,
) )
from app.users.auth import encode_invitation_token from app.utils.auth import encode_confirmation_token
from app.users.schemas import SInvitationData from app.users.schemas import SConfirmationData
def generate_confirmation_code(length=6) -> str: def generate_confirmation_code(length=6) -> str:
@ -23,8 +23,8 @@ def generate_confirmation_code(length=6) -> str:
@celery.task @celery.task
def send_registration_confirmation_email(user_data: SInvitationData): def send_registration_confirmation_email(user_data: SConfirmationData):
invitation_token = encode_invitation_token(user_data) invitation_token = encode_confirmation_token(user_data)
confirmation_link = settings.INVITATION_LINK_HOST + "/api/users/email_verification/" + invitation_token confirmation_link = settings.INVITATION_LINK_HOST + "/api/users/email_verification/" + invitation_token

View file

@ -1,36 +0,0 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport"
content="width=device-width, user-scalable=no, initial-scale=1.0, maximum-scale=1.0, minimum-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="ie=edge">
{% block head %}{% endblock %}
<title>BlackPhoenix</title>
</head>
<body>
<nav class="flex justify-between text-3xl my-3">
<ul>
<li class="BpMainButton"><button>Black Phoenix</button></li>
<li class="BpMyAccount"><button>Мой аккаунт</button></li>
</ul>
<style>
ul,li{
display: inline;
}
.BpMainButton{
float: left;
}
.BpMyAccount{
float: right;
}
</style>
</nav>
<hr>
<div id="content">
{% block content %}{% endblock %}
</div>
</body>
</html>

View file

@ -1,56 +0,0 @@
{% extends "base.html" %}
{% block content %}
<div class="flex flex-col items-center">
<h1>WebSocket Chat</h1>
<h2>Your ID: <span id="ws-id"></span></h2>
<form action="" onsubmit="sendMessage(event)">
<input class="bg-green-300" type="text" id="messageText" autocomplete="off"/>
<button>Send</button>
</form>
<ul id='messages'>
</ul>
</div>
<script>
async function getLastMessages() {
const url = 'http://localhost:8000/api/chat/get_last_message/2'
const response = await fetch(url, {
method: 'GET'
})
return response.json()
}
getLastMessages()
.then(messages => {
appendMessage("Предыдущие 5 сообщений:")
messages.forEach(msg => {
appendMessage(msg.message)
})
appendMessage("\nНовые сообщения:")
})
function appendMessage(msg) {
let messages = document.getElementById('messages')
let message = document.createElement('li')
let content = document.createTextNode(msg)
message.appendChild(content)
messages.appendChild(message)
}
let chat_id = 2
document.querySelector("#ws-id").textContent = chat_id;
let ws = new WebSocket(`ws://localhost:8000/chat/ws/${chat_id}?user_id=1`);
ws.onmessage = function (event) {
const data = JSON.parse(event.data)
appendMessage(data.message);
console.log(data)
};
function sendMessage(event) {
let input = document.getElementById("messageText")
ws.send(JSON.stringify({'message':input.value, "image_url": 'https://images.black-phoenix.ru/static/images/avatars/0qQOzJcY5lOuuA1u_avatar.png'}))
input.value = ''
event.preventDefault()
}
</script>
{% endblock %}

View file

@ -14,7 +14,6 @@ from app.exceptions import (
) )
from app.services.user_service import UserService from app.services.user_service import UserService
from app.unit_of_work import UnitOfWork from app.unit_of_work import UnitOfWork
from app.users.auth import VERIFICATED_USER
from app.users.schemas import SUser from app.users.schemas import SUser
auth_schema = HTTPBearer() auth_schema = HTTPBearer()
@ -46,7 +45,7 @@ async def get_current_user(token: str = Depends(get_token), uow=Depends(UnitOfWo
async def check_verificated_user_with_exc(user: SUser = Depends(get_current_user)) -> SUser: async def check_verificated_user_with_exc(user: SUser = Depends(get_current_user)) -> SUser:
if not user.role >= VERIFICATED_USER: if not user.role >= settings.VERIFICATED_USER:
raise UserMustConfirmEmailException raise UserMustConfirmEmailException
return user return user

View file

@ -9,8 +9,7 @@ from app.exceptions import (
) )
from app.services.redis_service import RedisService, get_redis_session from app.services.redis_service import RedisService, get_redis_session
from app.unit_of_work import UnitOfWork from app.unit_of_work import UnitOfWork
from app.users.auth import get_password_hash, create_access_token, VERIFICATED_USER, AuthService, verify_password, \ from app.utils.auth import get_password_hash, create_access_token, AuthService, verify_password, decode_confirmation_token
decode_invitation_token
from app.users.dependencies import get_current_user from app.users.dependencies import get_current_user
from app.users.schemas import ( from app.users.schemas import (
SUserLogin, SUserLogin,
@ -24,7 +23,7 @@ from app.users.schemas import (
SUserSendConfirmationCode, SUserSendConfirmationCode,
STokenLogin, STokenLogin,
SUsers, SUsers,
SInvitationData, SConfirmationData,
) )
from app.tasks.tasks import ( from app.tasks.tasks import (
send_registration_confirmation_email, send_registration_confirmation_email,
@ -77,7 +76,7 @@ async def register_user(user_data: SUserRegister, uow=Depends(UnitOfWork)):
await uow.commit() await uow.commit()
user_code = generate_confirmation_code() user_code = generate_confirmation_code()
user_mail_data = SInvitationData.model_validate( user_mail_data = SConfirmationData.model_validate(
{"user_id": user_id, "username": user_data.username, "email_to": user_data.email, "confirmation_code": user_code} {"user_id": user_id, "username": user_data.username, "email_to": user_data.email, "confirmation_code": user_code}
) )
send_registration_confirmation_email.delay(user_mail_data) send_registration_confirmation_email.delay(user_mail_data)
@ -93,14 +92,14 @@ async def register_user(user_data: SUserRegister, uow=Depends(UnitOfWork)):
response_model=SEmailVerification, response_model=SEmailVerification,
) )
async def email_verification(user_code: str, uow=Depends(UnitOfWork)): async def email_verification(user_code: str, uow=Depends(UnitOfWork)):
user_data = decode_invitation_token(user_code) user_data = decode_confirmation_token(user_code)
redis_session = get_redis_session() redis_session = get_redis_session()
async with uow: async with uow:
verification_code = await RedisService.get_verification_code(redis=redis_session, user_id=user_data.user_id) verification_code = await RedisService.get_verification_code(redis=redis_session, user_id=user_data.user_id)
if verification_code != user_data.confirmation_code: if verification_code != user_data.confirmation_code:
raise WrongCodeException raise WrongCodeException
await uow.user.change_data(user_id=user_data.user_id, role=VERIFICATED_USER) await uow.user.change_data(user_id=user_data.user_id, role=settings.VERIFICATED_USER)
await uow.commit() await uow.commit()
return {"email_verification": True} return {"email_verification": True}

View file

@ -104,8 +104,12 @@ class SUserFilter(BaseModel):
email: EmailStr | None = None email: EmailStr | None = None
class SInvitationData(BaseModel): class SConfirmationData(BaseModel):
user_id: int user_id: int
username: str username: str
email_to: EmailStr email_to: EmailStr
confirmation_code: str confirmation_code: str
class SInvitationData(BaseModel):
chat_id: int

View file

@ -13,15 +13,10 @@ from app.exceptions import (
UserMustConfirmEmailException, UserMustConfirmEmailException,
) )
from app.unit_of_work import UnitOfWork from app.unit_of_work import UnitOfWork
from app.users.schemas import SUser, SInvitationData from app.users.schemas import SUser, SConfirmationData, SInvitationData
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
ADMIN_USER = 100
ADMIN_USER_ID = 3
REGISTRATED_USER = 0
VERIFICATED_USER = 1
def get_password_hash(password: str) -> str: def get_password_hash(password: str) -> str:
return pwd_context.hash(password) return pwd_context.hash(password)
@ -52,6 +47,19 @@ def decode_invitation_token(invitation_token: str) -> SInvitationData:
return SInvitationData.model_validate_json(user_data) return SInvitationData.model_validate_json(user_data)
def encode_confirmation_token(user_data: SConfirmationData) -> str:
cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY)
invitation_token = cipher_suite.encrypt(user_data.model_dump_json().encode())
return invitation_token.decode()
def decode_confirmation_token(invitation_token: str) -> SConfirmationData:
user_code = invitation_token.encode()
cipher_suite = Fernet(settings.INVITATION_LINK_TOKEN_KEY)
user_data = cipher_suite.decrypt(user_code)
return SConfirmationData.model_validate_json(user_data)
class AuthService: class AuthService:
@staticmethod @staticmethod
async def authenticate_user_by_email(uow: UnitOfWork, email: EmailStr, password: str) -> SUser | None: async def authenticate_user_by_email(uow: UnitOfWork, email: EmailStr, password: str) -> SUser | None:
@ -84,7 +92,7 @@ class AuthService:
user = await uow.user.find_one_or_none(id=user_id) user = await uow.user.find_one_or_none(id=user_id)
if not user: if not user:
raise UserNotFoundException raise UserNotFoundException
return user.role >= VERIFICATED_USER return user.role >= settings.VERIFICATED_USER
@classmethod @classmethod
async def check_verificated_user_with_exc(cls, uow: UnitOfWork, user_id: int): async def check_verificated_user_with_exc(cls, uow: UnitOfWork, user_id: int):
@ -109,6 +117,6 @@ class AuthService:
async def validate_user_admin(uow: UnitOfWork, user_id: int) -> bool: async def validate_user_admin(uow: UnitOfWork, user_id: int) -> bool:
async with uow: async with uow:
user_role = await uow.user.get_user_role(user_id=user_id) user_role = await uow.user.get_user_role(user_id=user_id)
if user_role == ADMIN_USER: if user_role == settings.ADMIN_USER:
return True return True
return False return False

View file

@ -1,63 +1,3 @@
import json import os
from datetime import datetime
import pytest os.environ["MODE"] = "TEST"
from sqlalchemy import insert, update
from httpx import AsyncClient
from app.config import settings
from app.database import Base, async_session_maker, engine
from app.models.users import Users
from app.models.user_verification_code import UserVerificationCode
from app.models.chat import Chats
from app.models.message import Message
from app.models.user_chat import UserChat
from app.main import app as fastapi_app
@pytest.fixture(autouse=True, scope="module")
async def prepare_database():
assert settings.MODE == "TEST"
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
raise
await conn.run_sync(Base.metadata.create_all)
def open_mock_json(model: str):
with open(f"app/tests/mock_{model}.json", "r", encoding="utf8") as file:
return json.load(file)
users = open_mock_json("users")
users_verification_codes = open_mock_json("verification_codes")
chats = open_mock_json("chats")
users_x_chats = open_mock_json("x_chats")
messages = open_mock_json("messages")
new_users = []
for i in users:
i["date_of_birth"] = datetime.strptime(i["date_of_birth"], "%Y-%m-%d")
new_users.append(i)
async with async_session_maker() as session:
add_users = insert(Users).values(new_users)
add_users_verification_codes = insert(UserVerificationCode).values(users_verification_codes)
add_chats = insert(Chats).values(chats)
add_users_x_chats = insert(UserChat).values(users_x_chats)
add_messages = insert(Message).values(messages)
set_verified_user = update(Users).values(role=1).where(Users.id == 3)
await session.execute(add_users)
await session.execute(add_users_verification_codes)
await session.execute(add_chats)
await session.execute(add_users_x_chats)
await session.execute(add_messages)
await session.execute(set_verified_user)
await session.commit()
@pytest.fixture(scope="function")
async def ac():
async with AsyncClient(app=fastapi_app, base_url="http://test") as ac:
yield ac

View file

@ -0,0 +1,61 @@
import json
from datetime import datetime
import pytest
from sqlalchemy import insert, update
from httpx import AsyncClient
from app.config import settings
from app.database import Base, async_session_maker, engine
from app.models.users import Users
from app.models.chat import Chats
from app.models.message import Message
from app.models.user_chat import UserChat
from app.main import app as fastapi_app
@pytest.fixture(autouse=True, scope="module")
async def prepare_database():
assert settings.MODE == "TEST"
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
raise
await conn.run_sync(Base.metadata.create_all)
def open_mock_json(model: str):
with open(f"app/tests/mock_{model}.json", "r", encoding="utf8") as file:
return json.load(file)
users = open_mock_json("users")
users_verification_codes = open_mock_json("verification_codes")
chats = open_mock_json("chats")
users_x_chats = open_mock_json("x_chats")
messages = open_mock_json("messages")
new_users = []
for i in users:
i["date_of_birth"] = datetime.strptime(i["date_of_birth"], "%Y-%m-%d")
new_users.append(i)
async with async_session_maker() as session:
add_users = insert(Users).values(new_users)
add_chats = insert(Chats).values(chats)
add_users_x_chats = insert(UserChat).values(users_x_chats)
add_messages = insert(Message).values(messages)
set_verified_user = update(Users).values(role=1).where(Users.id == 3)
await session.execute(add_users)
await session.execute(add_users_verification_codes)
await session.execute(add_chats)
await session.execute(add_users_x_chats)
await session.execute(add_messages)
await session.execute(set_verified_user)
await session.commit()
@pytest.fixture(scope="function")
async def ac():
async with AsyncClient(app=fastapi_app, base_url="http://test") as ac:
yield ac

View file

@ -2,7 +2,7 @@ import pytest
from httpx import AsyncClient from httpx import AsyncClient
from app.services.user_service import UserService from app.services.user_service import UserService
from app.users.auth import verify_password from app.utils.auth import verify_password
async def test_get_users(ac: AsyncClient): async def test_get_users(ac: AsyncClient):