Переделал чат на монгу

This commit is contained in:
urec56 2024-08-20 12:34:04 +04:00
parent 6e3ad02e6b
commit afc77b2000
17 changed files with 212 additions and 361 deletions

View file

@ -197,7 +197,7 @@ async def get_some_messages(
uow=uow, uow=uow,
chat_id=chat_id, chat_id=chat_id,
message_number_from=last_messages.messages_loaded, message_number_from=last_messages.messages_loaded,
messages_to_get=last_messages.messages_to_get messages_to_get=last_messages.messages_to_get,
) )
return messages return messages

View file

@ -1,10 +1,11 @@
from datetime import datetime from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, HttpUrl from pydantic import BaseModel, HttpUrl
class SMessage(BaseModel): class SMessage(BaseModel):
id: int id: int | UUID # TODO: Заменить на UUID
message: str | None = None message: str | None = None
image_url: str | None = None image_url: str | None = None
chat_id: int chat_id: int
@ -12,19 +13,19 @@ class SMessage(BaseModel):
username: str username: str
created_at: datetime created_at: datetime
avatar_image: str avatar_image: str
answer_id: int | None answer_id: int | None | UUID # TODO: Заменить на UUID
answer_message: str | None answer_message: str | None
answer_image_url: str | None answer_image_url: str | None
class SMessageRaw(BaseModel): class SMessageRaw(BaseModel):
_id: str id: UUID
message: str | None = None message: str | None = None
image_url: str | None = None image_url: str | None = None
chat_id: int chat_id: int
user_id: int user_id: int
created_at: datetime created_at: datetime
answer_id: int | None = None answer_id: UUID | None = None
answer_message: str | None = None answer_message: str | None = None
answer_image_url: str | None = None answer_image_url: str | None = None
@ -76,7 +77,7 @@ class SSendMessage(BaseModel):
flag: str flag: str
message: str message: str
image_url: str | None image_url: str | None
answer: int | None answer: UUID | None
class SDeleteMessage(BaseModel): class SDeleteMessage(BaseModel):

View file

@ -4,7 +4,7 @@ from collections import defaultdict
import websockets import websockets
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException, status from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException, status
from app.chat.exceptions import UseWSException from app.chat.exceptions import UseWSException, MessageNotFoundException, MessageAlreadyPinnedException
from app.exceptions import IncorrectDataException from app.exceptions import IncorrectDataException
from app.chat.exceptions import UserDontHavePermissionException from app.chat.exceptions import UserDontHavePermissionException
from app.services.message_service import MessageService from app.services.message_service import MessageService
@ -20,33 +20,37 @@ class ConnectionManager:
def __init__(self): def __init__(self):
self.active_connections: dict[int, list[WebSocket]] = defaultdict(list) self.active_connections: dict[int, list[WebSocket]] = defaultdict(list)
self.message_methods = { self.message_methods = {
"send": self.send, "send": self._send,
"delete": self.delete, "delete": self._delete,
"edit": self.edit, "edit": self._edit,
"pin": self.pin, "pin": self._pin,
"unpin": self.unpin, "unpin": self._unpin,
} }
async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None): async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None) -> None:
await websocket.accept(subprotocol=subprotocol) await websocket.accept(subprotocol=subprotocol)
self.active_connections[chat_id].append(websocket) self.active_connections[chat_id].append(websocket)
async def disconnect(self, chat_id: int, websocket: WebSocket, code_and_reason: tuple[int, str] | None = None): async def disconnect(
self, chat_id: int, websocket: WebSocket, code_and_reason: tuple[int, str] | None = None
) -> None:
self.active_connections[chat_id].remove(websocket) self.active_connections[chat_id].remove(websocket)
if code_and_reason: if code_and_reason:
await websocket.close(code=code_and_reason[0], reason=code_and_reason[1]) await websocket.close(code=code_and_reason[0], reason=code_and_reason[1])
async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict): async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> None:
try: try:
new_message = await self.message_methods[message["flag"]](uow, user_id, chat_id, message) new_message = await self.message_methods[message["flag"]](uow, user_id, chat_id, message)
for websocket in self.active_connections[chat_id]: for websocket in self.active_connections[chat_id]:
await websocket.send_json(new_message) await websocket.send_json(new_message)
await polling_manager.send(chat_id, new_message) await polling_manager.send(chat_id, new_message)
except (MessageNotFoundException, MessageAlreadyPinnedException):
pass
except KeyError: except KeyError:
raise IncorrectDataException raise IncorrectDataException
@staticmethod @staticmethod
async def send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict: async def _send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
message = SSendMessage.model_validate(message) message = SSendMessage.model_validate(message)
new_message = await MessageService.send_message( new_message = await MessageService.send_message(
@ -58,7 +62,7 @@ class ConnectionManager:
return new_message return new_message
@staticmethod @staticmethod
async def delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict: async def _delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
message = SDeleteMessage.model_validate(message) message = SDeleteMessage.model_validate(message)
if message.user_id != user_id: if message.user_id != user_id:
@ -69,7 +73,7 @@ class ConnectionManager:
return new_message return new_message
@staticmethod @staticmethod
async def edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict: async def _edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
message = SEditMessage.model_validate(message) message = SEditMessage.model_validate(message)
if message.user_id != user_id: if message.user_id != user_id:
@ -87,7 +91,7 @@ class ConnectionManager:
return new_message return new_message
@staticmethod @staticmethod
async def pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict: async def _pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
message = SPinMessage.model_validate(message) message = SPinMessage.model_validate(message)
pinned_message = await MessageService.pin_message( pinned_message = await MessageService.pin_message(
uow=uow, chat_id=chat_id, user_id=message.user_id, message_id=message.id uow=uow, chat_id=chat_id, user_id=message.user_id, message_id=message.id
@ -98,7 +102,7 @@ class ConnectionManager:
return new_message return new_message
@staticmethod @staticmethod
async def unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict: async def _unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
message = SUnpinMessage.model_validate(message) message = SUnpinMessage.model_validate(message)
await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id) await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id)
new_message = {"flag": "unpin", "id": message.id} new_message = {"flag": "unpin", "id": message.id}
@ -179,7 +183,7 @@ class PollingManager:
self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list) self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list)
self.messages: dict[int, list[dict]] = defaultdict(list) self.messages: dict[int, list[dict]] = defaultdict(list)
async def poll(self, chat_id: int): async def poll(self, chat_id: int) -> dict:
future = asyncio.Future() future = asyncio.Future()
self.waiters[chat_id].append(future) self.waiters[chat_id].append(future)
try: try:
@ -196,7 +200,7 @@ class PollingManager:
user_id: int, user_id: int,
chat_id: int, chat_id: int,
message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage, message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
): ) -> None:
message = message.model_dump() message = message.model_dump()
await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=message) await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=message)
@ -206,7 +210,7 @@ class PollingManager:
self, self,
chat_id: int, chat_id: int,
message: dict, message: dict,
): ) -> None:
self.messages[chat_id].append(message) self.messages[chat_id].append(message)
while self.waiters[chat_id]: while self.waiters[chat_id]:
waiter = self.waiters[chat_id].pop(0) waiter = self.waiters[chat_id].pop(0)

View file

@ -1,22 +1,18 @@
from uuid import UUID
from pydantic import HttpUrl from pydantic import HttpUrl
from sqlalchemy import insert, select, update, delete, func from sqlalchemy import insert, select, update, delete, func
from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import aliased
from app.dao.base import BaseDAO from app.dao.base import BaseDAO
from app.chat.exceptions import ( from app.chat.exceptions import (
UserAlreadyInChatException, UserAlreadyInChatException,
UserAlreadyPinnedChatException, UserAlreadyPinnedChatException,
MessageNotFoundException,
MessageAlreadyPinnedException, MessageAlreadyPinnedException,
ChatNotFoundException, ChatNotFoundException,
) )
from app.chat.shemas import SMessage, SMessageList, SPinnedMessages, SPinnedChats, SChat from app.chat.shemas import SPinnedChats, SChat
from app.database import db
from app.models.users import Users
from app.models.message_answer import MessageAnswer
from app.models.chat import Chat from app.models.chat import Chat
from app.models.message import Message
from app.models.pinned_chat import PinnedChat from app.models.pinned_chat import PinnedChat
from app.models.pinned_message import PinnedMessage from app.models.pinned_message import PinnedMessage
from app.models.user_chat import UserChat from app.models.user_chat import UserChat
@ -64,142 +60,6 @@ class ChatDAO(BaseDAO):
except IntegrityError: except IntegrityError:
raise UserAlreadyInChatException raise UserAlreadyInChatException
async def send_message(self, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> int:
stmt = (
insert(Message)
.values(chat_id=chat_id, user_id=user_id, message=message, image_url=image_url)
.returning(Message.id)
)
result = await self.session.execute(stmt)
return result.scalar()
async def get_message_by_id(self, message_id: int) -> SMessage:
try:
msg = aliased(Message, name="msg")
query = (
select(
func.json_build_object(
"id", Message.id,
"message", Message.message,
"image_url", Message.image_url,
"chat_id", Message.chat_id,
"user_id", Message.user_id,
"created_at", Message.created_at,
"avatar_image", Users.avatar_image,
"username", Users.username,
"answer_id", MessageAnswer.answer_id,
"answer_message", select(
msg.message
)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery(),
"answer_image_url", select(
msg.image_url
)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery(),
)
)
.select_from(Message)
.join(Users, Users.id == Message.user_id)
.join(MessageAnswer, Message.id == MessageAnswer.self_id, isouter=True)
.where(Message.id == message_id, Message.visibility == True) # noqa: E712
)
result = await self.session.execute(query)
result = result.scalar_one()
return SMessage.model_validate(result)
except NoResultFound:
raise MessageNotFoundException
async def delete_message(self, message_id: int) -> None:
stmt = update(Message).values(visibility=False).where(Message.id == message_id)
await self.session.execute(stmt)
async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageList:
msg = aliased(Message, name="msg")
messages_with_users = (
select(
Message.id,
Message.message,
Message.image_url,
Message.chat_id,
Message.user_id,
Message.created_at,
Users.username,
Users.avatar_image,
MessageAnswer.answer_id,
(
select(msg.message)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery()
).label("answer_message"),
(
select(msg.image_url)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery()
).label("answer_image_url"),
)
.select_from(Message)
.join(Users, Message.user_id == Users.id)
.join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
.where(Message.chat_id == chat_id, Message.visibility == True) # noqa: E712
.order_by(Message.created_at.desc())
.limit(messages_to_get)
.offset(message_number_from)
.cte("messages_with_users")
)
query = (
select(
func.json_build_object(
"messages",
func.json_agg(
func.json_build_object(
"id", messages_with_users.c.id,
"message", messages_with_users.c.message,
"image_url", messages_with_users.c.image_url,
"chat_id", messages_with_users.c.chat_id,
"user_id", messages_with_users.c.user_id,
"created_at", messages_with_users.c.created_at,
"avatar_image", messages_with_users.c.avatar_image,
"username", messages_with_users.c.username,
"answer_id", messages_with_users.c.answer_id,
"answer_message", messages_with_users.c.answer_message,
"answer_image_url", messages_with_users.c.answer_image_url,
)
)
)
).select_from(messages_with_users)
)
result = await self.session.execute(query)
result = result.scalar()
return SMessageList.model_validate(result)
async def edit_message(self, message_id: int, new_message: str, new_image_url: str) -> None:
stmt = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url)
await self.session.execute(stmt)
async def add_answer(self, self_id: int, answer_id: int) -> None:
stmt = (
insert(MessageAnswer)
.values(self_id=self_id, answer_id=answer_id)
)
await self.session.execute(stmt)
async def change_data(self, chat_id: int, chat_name: str, avatar_image: HttpUrl) -> None: async def change_data(self, chat_id: int, chat_name: str, avatar_image: HttpUrl) -> None:
stmt = ( stmt = (
update(Chat) update(Chat)
@ -252,94 +112,26 @@ class ChatDAO(BaseDAO):
result = result.scalar_one() result = result.scalar_one()
return SPinnedChats.model_validate(result) return SPinnedChats.model_validate(result)
async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> None: async def pin_message(self, chat_id: int, message_id: UUID, user_id: int) -> None:
try: try:
stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id) stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
await self.session.execute(stmt) await self.session.execute(stmt)
except IntegrityError: except IntegrityError:
raise MessageAlreadyPinnedException raise MessageAlreadyPinnedException
async def unpin_message(self, chat_id: int, message_id: int) -> None: async def unpin_message(self, chat_id: int, message_id: UUID) -> None:
stmt = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id) stmt = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id)
await self.session.execute(stmt) await self.session.execute(stmt)
async def get_pinned_messages(self, chat_id: int) -> SPinnedMessages: async def get_pinned_messages_ids(self, chat_id: int) -> list[UUID]:
msg = aliased(Message, name="msg")
messages_with_users = (
select(
Message.id,
Message.message,
Message.image_url,
Message.chat_id,
Message.user_id,
Message.created_at,
Message.visibility,
Users.username,
Users.avatar_image,
MessageAnswer.self_id,
MessageAnswer.answer_id,
(
select(msg.message)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True) # noqa
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery()
).label("answer_message"),
(
select(msg.image_url)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery()
).label("answer_image_url"),
)
.select_from(PinnedMessage)
.join(Message, PinnedMessage.message_id == Message.id)
.join(Users, Message.user_id == Users.id)
.join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
.where(PinnedMessage.chat_id == chat_id, Message.visibility == True) # noqa: E712
.order_by(Message.created_at.desc())
.cte("messages_with_users")
)
query = ( query = (
select( select(PinnedMessage.message_id)
func.json_build_object( .where(PinnedMessage.chat_id == chat_id)
"pinned_messages",
func.json_agg(
func.json_build_object(
"id", messages_with_users.c.id,
"message", messages_with_users.c.message,
"image_url", messages_with_users.c.image_url,
"chat_id", messages_with_users.c.chat_id,
"user_id", messages_with_users.c.user_id,
"created_at", messages_with_users.c.created_at,
"avatar_image", messages_with_users.c.avatar_image,
"username", messages_with_users.c.username,
"answer_id", messages_with_users.c.answer_id,
"answer_message", messages_with_users.c.answer_message,
"answer_image_url", messages_with_users.c.answer_image_url
)
)
)
).select_from(messages_with_users)
) )
result = await self.session.execute(query) result = await self.session.execute(query)
result = result.scalar_one() result = result.scalars().all()
pinned_messages = SPinnedMessages.model_validate(result) return result # noqa
return pinned_messages
async def test_mongo(self):
await db.message.insert_one(
{
"message": "пизда",
"chat_id": 1,
"user_id": 1,
"image_url": "",
"created_at": "2024-06-13 13:49:59.14757+00",
"visibility": True,
}
)

60
app/dao/message.py Normal file
View file

@ -0,0 +1,60 @@
from datetime import datetime, UTC
from uuid import uuid4, UUID
from app.chat.shemas import SMessageRaw, SMessageRawList
class MessageDAO:
def __init__(self, mongo_db):
self.message = mongo_db.message
async def send_message(
self,
user_id: int,
chat_id: int,
message: str | None,
image_url: str | None,
answer_id: UUID | None = None,
answer_message: str | None = None,
answer_image_url: str | None = None,
) -> UUID:
message_id = uuid4()
await self.message.insert_one(
{
"id": str(message_id),
"message": message,
"image_url": image_url,
"chat_id": chat_id,
"user_id": user_id,
"created_at": datetime.now(UTC),
"answer_id": str(answer_id) if answer_id else None,
"answer_message": answer_message,
"answer_image_url": answer_image_url,
"visibility": True,
}
)
return message_id
async def get_message_by_id(self, message_id: UUID) -> SMessageRaw:
message = self.message.find_one({"id": str(message_id)})
return SMessageRaw.model_validate(message)
async def delete_message(self, message_id: UUID) -> None:
await self.message.update_one({"id": str(message_id)}, {"$set": {"visibility": False}})
async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageRawList:
cursor = self.message.find({"visibility": True, "chat_id": chat_id})
cursor.sort("created_at").skip(message_number_from)
return SMessageRawList.model_validate({"message_raw_list": await cursor.to_list(length=messages_to_get)})
async def edit_message(self, message_id: UUID, new_message: str, new_image_url: str) -> None:
await self.message.update_one(
{"id": str(message_id)},
{"$set": {"message": new_message, "image_url": new_image_url}}
)
async def get_messages_from_ids(self, messages_ids: list[UUID]) -> SMessageRawList:
cursor = self.message.find({"visibility": True, "id": {"$in": [str(message_id) for message_id in messages_ids]}})
return SMessageRawList.model_validate({"message_raw_list": await cursor.to_list(length=None) or None})

View file

@ -1,4 +1,4 @@
from pydantic import HttpUrl, EmailStr from pydantic import EmailStr
from sqlalchemy import update, select, insert, func, or_ from sqlalchemy import update, select, insert, func, or_
from sqlalchemy.exc import MultipleResultsFound, IntegrityError, NoResultFound from sqlalchemy.exc import MultipleResultsFound, IntegrityError, NoResultFound

View file

@ -29,5 +29,5 @@ MONGO_URL = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}"
mongo_client = AsyncIOMotorClient(MONGO_URL) mongo_client = AsyncIOMotorClient(MONGO_URL)
db = mongo_client.test_db mongo_db = mongo_client.test_db

View file

@ -6,9 +6,7 @@ from alembic import context
from app.database import DATABASE_URL, Base from app.database import DATABASE_URL, Base
from app.models.users import Users # noqa from app.models.users import Users # noqa
from app.models.message_answer import MessageAnswer # noqa
from app.models.chat import Chat # noqa from app.models.chat import Chat # noqa
from app.models.message import Message # noqa
from app.models.pinned_chat import PinnedChat # noqa from app.models.pinned_chat import PinnedChat # noqa
from app.models.pinned_message import PinnedMessage # noqa from app.models.pinned_message import PinnedMessage # noqa
from app.models.user_chat import UserChat # noqa from app.models.user_chat import UserChat # noqa

View file

@ -1,8 +1,8 @@
"""Database Creation """Database Creation
Revision ID: c69369209bab Revision ID: 53fd6f2b93a4
Revises: Revises:
Create Date: 2024-06-13 18:40:03.297322 Create Date: 2024-08-20 12:30:52.435668
""" """
from typing import Sequence, Union from typing import Sequence, Union
@ -12,7 +12,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'c69369209bab' revision: str = '53fd6f2b93a4'
down_revision: Union[str, None] = None down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@ -52,18 +52,6 @@ def upgrade() -> None:
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('message',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('message', sa.String(), nullable=True),
sa.Column('image_url', sa.String(), nullable=True),
sa.Column('chat_id', sa.Integer(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('visibility', sa.Boolean(), server_default='true', nullable=False),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('pinned_chat', op.create_table('pinned_chat',
sa.Column('user_id', sa.Integer(), nullable=False), sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('chat_id', sa.Integer(), nullable=False), sa.Column('chat_id', sa.Integer(), nullable=False),
@ -71,6 +59,14 @@ def upgrade() -> None:
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('user_id', 'chat_id') sa.PrimaryKeyConstraint('user_id', 'chat_id')
) )
op.create_table('pinned_message',
sa.Column('chat_id', sa.Integer(), nullable=True),
sa.Column('message_id', sa.Uuid(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('message_id', 'user_id')
)
op.create_table('user_chat', op.create_table('user_chat',
sa.Column('user_id', sa.Integer(), nullable=False), sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('chat_id', sa.Integer(), nullable=False), sa.Column('chat_id', sa.Integer(), nullable=False),
@ -78,32 +74,14 @@ def upgrade() -> None:
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('user_id', 'chat_id') sa.PrimaryKeyConstraint('user_id', 'chat_id')
) )
op.create_table('message_answer',
sa.Column('self_id', sa.Integer(), nullable=False),
sa.Column('answer_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['answer_id'], ['message.id'], ),
sa.ForeignKeyConstraint(['self_id'], ['message.id'], ),
sa.PrimaryKeyConstraint('self_id')
)
op.create_table('pinned_message',
sa.Column('chat_id', sa.Integer(), nullable=True),
sa.Column('message_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['message_id'], ['message.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('message_id')
)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_table('pinned_message')
op.drop_table('message_answer')
op.drop_table('user_chat') op.drop_table('user_chat')
op.drop_table('pinned_message')
op.drop_table('pinned_chat') op.drop_table('pinned_chat')
op.drop_table('message')
op.drop_table('user_avatar') op.drop_table('user_avatar')
op.drop_table('chat') op.drop_table('chat')
op.drop_table('users') op.drop_table('users')

View file

@ -1,18 +0,0 @@
from datetime import datetime
from sqlalchemy import ForeignKey, func, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class Message(Base):
__tablename__ = "message"
id: Mapped[int] = mapped_column(primary_key=True)
chat_id = mapped_column(ForeignKey("chat.id"))
user_id = mapped_column(ForeignKey("users.id"))
message: Mapped[str | None]
image_url: Mapped[str | None]
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
visibility: Mapped[bool] = mapped_column(server_default="true")

View file

@ -1,11 +0,0 @@
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class MessageAnswer(Base):
__tablename__ = "message_answer"
self_id: Mapped[int] = mapped_column(ForeignKey("message.id"), primary_key=True)
answer_id: Mapped[int] = mapped_column(ForeignKey("message.id"))

View file

@ -1,5 +1,7 @@
from uuid import UUID
from sqlalchemy import ForeignKey from sqlalchemy import ForeignKey
from sqlalchemy.orm import mapped_column from sqlalchemy.orm import mapped_column, Mapped
from app.database import Base from app.database import Base
@ -8,5 +10,5 @@ class PinnedMessage(Base):
__tablename__ = "pinned_message" __tablename__ = "pinned_message"
chat_id = mapped_column(ForeignKey("chat.id")) chat_id = mapped_column(ForeignKey("chat.id"))
message_id = mapped_column(ForeignKey("message.id"), primary_key=True) message_id: Mapped[UUID] = mapped_column(primary_key=True)
user_id = mapped_column(ForeignKey("users.id")) user_id = mapped_column(ForeignKey("users.id"), primary_key=True)

View file

@ -126,5 +126,4 @@ class ChatService:
async def get_pinned_chats(uow: UnitOfWork, user_id: int) -> SPinnedChats: async def get_pinned_chats(uow: UnitOfWork, user_id: int) -> SPinnedChats:
async with uow: async with uow:
pinned_chats = await uow.chat.get_pinned_chats(user_id=user_id) pinned_chats = await uow.chat.get_pinned_chats(user_id=user_id)
await uow.chat.test_mongo()
return pinned_chats return pinned_chats

View file

@ -1,54 +1,93 @@
from app.chat.shemas import SMessage, SSendMessage, SPinnedMessages, SMessageList, SMessageRaw from datetime import datetime, UTC, timedelta
from app.chat.shemas import SMessage, SSendMessage, SPinnedMessages, SMessageList, SMessageRaw, SMessageRawList
from app.services.user_service import UserService from app.services.user_service import UserService
from app.users.schemas import SUser
from app.utils.unit_of_work import UnitOfWork from app.utils.unit_of_work import UnitOfWork
class MessageService: class MessageService:
@staticmethod users: dict[int, tuple[SUser, datetime]] = {}
async def add_avatar_image_and_username_to_message(uow: UnitOfWork, message: SMessageRaw) -> SMessage:
@classmethod
async def _get_cached_user(cls, uow: UnitOfWork, user_id: int) -> SUser:
if user_id in cls.users and cls.users[user_id][1] > datetime.now(UTC):
return cls.users[user_id][0]
user = await UserService.find_user(uow=uow, id=user_id)
cls.users[user_id] = user, datetime.now(UTC) + timedelta(minutes=5)
return user
@classmethod
async def add_avatar_image_and_username_to_message(cls, uow: UnitOfWork, message: SMessageRaw) -> SMessage:
user = await cls._get_cached_user(uow=uow, user_id=message.user_id)
message = message.model_dump() message = message.model_dump()
user = await UserService.find_user(uow=uow, id=message["user_id"]) message["avatar_image"] = str(user.avatar_image)
message["id"] = message["_id"]
message["avatar_image"] = user.avatar_image
message["username"] = user.username message["username"] = user.username
return SMessage.model_validate(message) return SMessage.model_validate(message)
@staticmethod @classmethod
async def add_avatar_image_and_username_to_message_list(
cls, uow: UnitOfWork, messages: SMessageRawList
) -> SMessageList:
return SMessageList.model_validate(
{"messages": [
await cls.add_avatar_image_and_username_to_message(uow=uow, message=message)
for message in messages.message_raw_list
] if messages.message_raw_list else None
}
)
@classmethod
async def send_message( async def send_message(
uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage, image_url: str | None = None cls, uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage, image_url: str | None = None
) -> SMessage: ) -> SMessage:
async with uow: async with uow:
message_id = await uow.chat.send_message(
user_id=user_id, chat_id=chat_id, message=message.message, image_url=image_url
)
if message.answer: if message.answer:
await uow.chat.add_answer(self_id=message_id, answer_id=message.answer) answer_message = await uow.message.get_message_by_id(message_id=message.answer)
new_message = await uow.chat.get_message_by_id(message_id=message_id) message_id = await uow.message.send_message(
await uow.commit() user_id=user_id,
chat_id=chat_id,
message=message.message,
image_url=image_url,
answer_id=answer_message.id,
answer_message=answer_message.message,
answer_image_url=answer_message.answer_image_url,
)
else:
message_id = await uow.message.send_message(
user_id=user_id,
chat_id=chat_id,
message=message.message,
image_url=image_url,
)
raw_message = await uow.message.get_message_by_id(message_id=message_id)
new_message = await cls.add_avatar_image_and_username_to_message(uow=uow, message=raw_message)
return new_message return new_message
@staticmethod @staticmethod
async def delete_message(uow: UnitOfWork, message_id: int) -> None: async def delete_message(uow: UnitOfWork, message_id: int) -> None:
async with uow: async with uow:
await uow.chat.delete_message(message_id=message_id) await uow.message.delete_message(message_id=message_id)
await uow.commit()
@staticmethod @staticmethod
async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, new_image_url: str) -> None: async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, new_image_url: str) -> None:
async with uow: async with uow:
await uow.chat.edit_message( await uow.message.edit_message(
message_id=message_id, new_message=new_message, new_image_url=new_image_url message_id=message_id, new_message=new_message, new_image_url=new_image_url
) )
await uow.commit()
@staticmethod @classmethod
async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage: async def pin_message(cls, uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage:
async with uow: async with uow:
await uow.chat.pin_message(chat_id=chat_id, message_id=message_id, user_id=user_id) await uow.chat.pin_message(chat_id=chat_id, message_id=message_id, user_id=user_id)
pinned_message = await uow.chat.get_message_by_id(message_id=message_id)
await uow.commit() await uow.commit()
return pinned_message
raw_message = await uow.message.get_message_by_id(message_id=message_id)
pinned_message = await cls.add_avatar_image_and_username_to_message(uow=uow, message=raw_message)
return pinned_message
@staticmethod @staticmethod
async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> None: async def unpin_message(uow: UnitOfWork, chat_id: int, message_id: int) -> None:
@ -56,24 +95,28 @@ class MessageService:
await uow.chat.unpin_message(chat_id=chat_id, message_id=message_id) await uow.chat.unpin_message(chat_id=chat_id, message_id=message_id)
await uow.commit() await uow.commit()
@staticmethod @classmethod
async def get_message_by_id(uow: UnitOfWork, message_id: int) -> SMessage: async def get_message_by_id(cls, uow: UnitOfWork, message_id: int) -> SMessage:
async with uow: async with uow:
message = await uow.chat.get_message_by_id(message_id=message_id) raw_message = await uow.message.get_message_by_id(message_id=message_id)
message = await cls.add_avatar_image_and_username_to_message(uow=uow, message=raw_message)
return message return message
@staticmethod @classmethod
async def get_pinned_messages(uow: UnitOfWork, chat_id: int) -> SPinnedMessages: async def get_pinned_messages(cls, uow: UnitOfWork, chat_id: int) -> SPinnedMessages:
async with uow: async with uow:
pinned_messages = await uow.chat.get_pinned_messages(chat_id=chat_id) pinned_messages_ids = await uow.chat.get_pinned_messages_ids(chat_id=chat_id)
return pinned_messages raw_messages = await uow.message.get_messages_from_ids(messages_ids=pinned_messages_ids)
pinned_messages = await cls.add_avatar_image_and_username_to_message_list(uow=uow, messages=raw_messages)
return SPinnedMessages.model_validate({"pinned_messages": pinned_messages.messages})
@staticmethod @classmethod
async def get_some_messages( async def get_some_messages(
uow: UnitOfWork, chat_id: int, message_number_from: int, messages_to_get: int cls, uow: UnitOfWork, chat_id: int, message_number_from: int, messages_to_get: int
) -> SMessageList: ) -> SMessageList:
async with uow: async with uow:
message = await uow.chat.get_some_messages( messages = await uow.message.get_some_messages(
chat_id=chat_id, message_number_from=message_number_from, messages_to_get=messages_to_get chat_id=chat_id, message_number_from=message_number_from, messages_to_get=messages_to_get
) )
return message messages = await cls.add_avatar_image_and_username_to_message_list(uow=uow, messages=messages)
return messages

View file

@ -112,7 +112,7 @@ class UserService:
avatar_image=str(user_data.avatar_url or user.avatar_image), avatar_image=str(user_data.avatar_url or user.avatar_image),
hashed_password=hashed_password hashed_password=hashed_password
) )
(await uow.user.add_user_avatar(user_id=user.id, avatar=str(user_data.avatar_url))) if user_data.avatar_url else None (await uow.user.add_user_avatar(user_id=user.id, avatar=str(user_data.avatar_url))) if user_data.avatar_url else None # noqa
await uow.commit() await uow.commit()
async with RedisService() as redis: async with RedisService() as redis:

View file

@ -23,7 +23,7 @@ def create_registration_confirmation_template(
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="preconnect" href="https://fonts.googleapis.com"> <link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin> <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Comfortaa:wght@200&display=swap" rel="stylesheet"> <link href="https://fonts.googleapis.com/css2?family=Comfortaa:wght@200&display=swap" rel="stylesheet">
<title>Submiting</title> <title>Submiting</title>
</head> </head>
<body> <body>
@ -38,9 +38,9 @@ def create_registration_confirmation_template(
background-color:#101010; background-color:#101010;
font-family: 'Comfortaa'; font-family: 'Comfortaa';
font-weight: 200; font-weight: 200;
font-style: normal; font-style: normal;
font-stretch: normal; font-stretch: normal;
font-optical-sizing: auto; font-optical-sizing: auto;
"> ">
@ -112,7 +112,7 @@ def create_data_change_confirmation_email(username: str, email_to: EmailStr, _:
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="preconnect" href="https://fonts.googleapis.com"> <link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin> <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Comfortaa:wght@200&display=swap" rel="stylesheet"> <link href="https://fonts.googleapis.com/css2?family=Comfortaa:wght@200&display=swap" rel="stylesheet">
<title>Submiting</title> <title>Submiting</title>
</head> </head>
<body> <body>
@ -128,8 +128,8 @@ def create_data_change_confirmation_email(username: str, email_to: EmailStr, _:
font-family: 'Comfortaa'; font-family: 'Comfortaa';
font-weight: 200; font-weight: 200;
font-style: normal; font-style: normal;
font-stretch: normal; font-stretch: normal;
font-optical-sizing: auto; font-optical-sizing: auto;
"> ">
@ -187,7 +187,7 @@ def create_data_change_email(username: str, email_to: EmailStr):
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="preconnect" href="https://fonts.googleapis.com"> <link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin> <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Comfortaa:wght@200&display=swap" rel="stylesheet"> <link href="https://fonts.googleapis.com/css2?family=Comfortaa:wght@200&display=swap" rel="stylesheet">
<title>Submiting</title> <title>Submiting</title>
</head> </head>
<body> <body>
@ -202,9 +202,9 @@ def create_data_change_email(username: str, email_to: EmailStr):
background-color:#101010; background-color:#101010;
font-family: 'Comfortaa'; font-family: 'Comfortaa';
font-weight: 200; font-weight: 200;
font-style: normal; font-style: normal;
font-stretch: normal; font-stretch: normal;
font-optical-sizing: auto; font-optical-sizing: auto;
"> ">

View file

@ -1,7 +1,8 @@
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from app.dao.chat import ChatDAO from app.dao.chat import ChatDAO
from app.database import async_session_maker from app.dao.message import MessageDAO
from app.database import async_session_maker, mongo_db
from app.dao.user import UserDAO from app.dao.user import UserDAO
from app.exceptions import BlackPhoenixException from app.exceptions import BlackPhoenixException
@ -9,12 +10,14 @@ from app.exceptions import BlackPhoenixException
class UnitOfWork: class UnitOfWork:
def __init__(self): def __init__(self):
self.session_factory = async_session_maker self.session_factory = async_session_maker
self.mongo_db = mongo_db
async def __aenter__(self): async def __aenter__(self):
self.session = self.session_factory() self.session = self.session_factory()
self.user = UserDAO(self.session) self.user = UserDAO(self.session)
self.chat = ChatDAO(self.session) self.chat = ChatDAO(self.session)
self.message = MessageDAO(self.mongo_db)
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
try: try: