Переделал чат на монгу

This commit is contained in:
urec56 2024-08-20 12:34:04 +04:00
parent 6e3ad02e6b
commit afc77b2000
17 changed files with 212 additions and 361 deletions

View file

@ -197,7 +197,7 @@ async def get_some_messages(
uow=uow,
chat_id=chat_id,
message_number_from=last_messages.messages_loaded,
messages_to_get=last_messages.messages_to_get
messages_to_get=last_messages.messages_to_get,
)
return messages

View file

@ -1,10 +1,11 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, HttpUrl
class SMessage(BaseModel):
id: int
id: int | UUID # TODO: Заменить на UUID
message: str | None = None
image_url: str | None = None
chat_id: int
@ -12,19 +13,19 @@ class SMessage(BaseModel):
username: str
created_at: datetime
avatar_image: str
answer_id: int | None
answer_id: int | None | UUID # TODO: Заменить на UUID
answer_message: str | None
answer_image_url: str | None
class SMessageRaw(BaseModel):
_id: str
id: UUID
message: str | None = None
image_url: str | None = None
chat_id: int
user_id: int
created_at: datetime
answer_id: int | None = None
answer_id: UUID | None = None
answer_message: str | None = None
answer_image_url: str | None = None
@ -76,7 +77,7 @@ class SSendMessage(BaseModel):
flag: str
message: str
image_url: str | None
answer: int | None
answer: UUID | None
class SDeleteMessage(BaseModel):

View file

@ -4,7 +4,7 @@ from collections import defaultdict
import websockets
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException, status
from app.chat.exceptions import UseWSException
from app.chat.exceptions import UseWSException, MessageNotFoundException, MessageAlreadyPinnedException
from app.exceptions import IncorrectDataException
from app.chat.exceptions import UserDontHavePermissionException
from app.services.message_service import MessageService
@ -20,33 +20,37 @@ class ConnectionManager:
def __init__(self):
self.active_connections: dict[int, list[WebSocket]] = defaultdict(list)
self.message_methods = {
"send": self.send,
"delete": self.delete,
"edit": self.edit,
"pin": self.pin,
"unpin": self.unpin,
"send": self._send,
"delete": self._delete,
"edit": self._edit,
"pin": self._pin,
"unpin": self._unpin,
}
async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None):
async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None) -> None:
await websocket.accept(subprotocol=subprotocol)
self.active_connections[chat_id].append(websocket)
async def disconnect(self, chat_id: int, websocket: WebSocket, code_and_reason: tuple[int, str] | None = None):
async def disconnect(
self, chat_id: int, websocket: WebSocket, code_and_reason: tuple[int, str] | None = None
) -> None:
self.active_connections[chat_id].remove(websocket)
if code_and_reason:
await websocket.close(code=code_and_reason[0], reason=code_and_reason[1])
async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict):
async def broadcast(self, uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> None:
try:
new_message = await self.message_methods[message["flag"]](uow, user_id, chat_id, message)
for websocket in self.active_connections[chat_id]:
await websocket.send_json(new_message)
await polling_manager.send(chat_id, new_message)
except (MessageNotFoundException, MessageAlreadyPinnedException):
pass
except KeyError:
raise IncorrectDataException
@staticmethod
async def send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
async def _send(uow: UnitOfWork, user_id: int, chat_id: int, message: dict) -> dict:
message = SSendMessage.model_validate(message)
new_message = await MessageService.send_message(
@ -58,7 +62,7 @@ class ConnectionManager:
return new_message
@staticmethod
async def delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
async def _delete(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
message = SDeleteMessage.model_validate(message)
if message.user_id != user_id:
@ -69,7 +73,7 @@ class ConnectionManager:
return new_message
@staticmethod
async def edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
async def _edit(uow: UnitOfWork, user_id: int, _: int, message: dict) -> dict:
message = SEditMessage.model_validate(message)
if message.user_id != user_id:
@ -87,7 +91,7 @@ class ConnectionManager:
return new_message
@staticmethod
async def pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
async def _pin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
message = SPinMessage.model_validate(message)
pinned_message = await MessageService.pin_message(
uow=uow, chat_id=chat_id, user_id=message.user_id, message_id=message.id
@ -98,7 +102,7 @@ class ConnectionManager:
return new_message
@staticmethod
async def unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
async def _unpin(uow: UnitOfWork, _: int, chat_id: int, message: dict) -> dict:
message = SUnpinMessage.model_validate(message)
await MessageService.unpin_message(uow=uow, chat_id=chat_id, message_id=message.id)
new_message = {"flag": "unpin", "id": message.id}
@ -179,7 +183,7 @@ class PollingManager:
self.waiters: dict[int, list[asyncio.Future]] = defaultdict(list)
self.messages: dict[int, list[dict]] = defaultdict(list)
async def poll(self, chat_id: int):
async def poll(self, chat_id: int) -> dict:
future = asyncio.Future()
self.waiters[chat_id].append(future)
try:
@ -196,7 +200,7 @@ class PollingManager:
user_id: int,
chat_id: int,
message: SSendMessage | SDeleteMessage | SEditMessage | SPinMessage | SUnpinMessage,
):
) -> None:
message = message.model_dump()
await manager.broadcast(uow=uow, user_id=user_id, chat_id=chat_id, message=message)
@ -206,7 +210,7 @@ class PollingManager:
self,
chat_id: int,
message: dict,
):
) -> None:
self.messages[chat_id].append(message)
while self.waiters[chat_id]:
waiter = self.waiters[chat_id].pop(0)

View file

@ -1,22 +1,18 @@
from uuid import UUID
from pydantic import HttpUrl
from sqlalchemy import insert, select, update, delete, func
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import aliased
from app.dao.base import BaseDAO
from app.chat.exceptions import (
UserAlreadyInChatException,
UserAlreadyPinnedChatException,
MessageNotFoundException,
MessageAlreadyPinnedException,
ChatNotFoundException,
)
from app.chat.shemas import SMessage, SMessageList, SPinnedMessages, SPinnedChats, SChat
from app.database import db
from app.models.users import Users
from app.models.message_answer import MessageAnswer
from app.chat.shemas import SPinnedChats, SChat
from app.models.chat import Chat
from app.models.message import Message
from app.models.pinned_chat import PinnedChat
from app.models.pinned_message import PinnedMessage
from app.models.user_chat import UserChat
@ -64,142 +60,6 @@ class ChatDAO(BaseDAO):
except IntegrityError:
raise UserAlreadyInChatException
async def send_message(self, user_id: int, chat_id: int, message: str, image_url: str | None = None) -> int:
stmt = (
insert(Message)
.values(chat_id=chat_id, user_id=user_id, message=message, image_url=image_url)
.returning(Message.id)
)
result = await self.session.execute(stmt)
return result.scalar()
async def get_message_by_id(self, message_id: int) -> SMessage:
try:
msg = aliased(Message, name="msg")
query = (
select(
func.json_build_object(
"id", Message.id,
"message", Message.message,
"image_url", Message.image_url,
"chat_id", Message.chat_id,
"user_id", Message.user_id,
"created_at", Message.created_at,
"avatar_image", Users.avatar_image,
"username", Users.username,
"answer_id", MessageAnswer.answer_id,
"answer_message", select(
msg.message
)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery(),
"answer_image_url", select(
msg.image_url
)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery(),
)
)
.select_from(Message)
.join(Users, Users.id == Message.user_id)
.join(MessageAnswer, Message.id == MessageAnswer.self_id, isouter=True)
.where(Message.id == message_id, Message.visibility == True) # noqa: E712
)
result = await self.session.execute(query)
result = result.scalar_one()
return SMessage.model_validate(result)
except NoResultFound:
raise MessageNotFoundException
async def delete_message(self, message_id: int) -> None:
stmt = update(Message).values(visibility=False).where(Message.id == message_id)
await self.session.execute(stmt)
async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageList:
msg = aliased(Message, name="msg")
messages_with_users = (
select(
Message.id,
Message.message,
Message.image_url,
Message.chat_id,
Message.user_id,
Message.created_at,
Users.username,
Users.avatar_image,
MessageAnswer.answer_id,
(
select(msg.message)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery()
).label("answer_message"),
(
select(msg.image_url)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery()
).label("answer_image_url"),
)
.select_from(Message)
.join(Users, Message.user_id == Users.id)
.join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
.where(Message.chat_id == chat_id, Message.visibility == True) # noqa: E712
.order_by(Message.created_at.desc())
.limit(messages_to_get)
.offset(message_number_from)
.cte("messages_with_users")
)
query = (
select(
func.json_build_object(
"messages",
func.json_agg(
func.json_build_object(
"id", messages_with_users.c.id,
"message", messages_with_users.c.message,
"image_url", messages_with_users.c.image_url,
"chat_id", messages_with_users.c.chat_id,
"user_id", messages_with_users.c.user_id,
"created_at", messages_with_users.c.created_at,
"avatar_image", messages_with_users.c.avatar_image,
"username", messages_with_users.c.username,
"answer_id", messages_with_users.c.answer_id,
"answer_message", messages_with_users.c.answer_message,
"answer_image_url", messages_with_users.c.answer_image_url,
)
)
)
).select_from(messages_with_users)
)
result = await self.session.execute(query)
result = result.scalar()
return SMessageList.model_validate(result)
async def edit_message(self, message_id: int, new_message: str, new_image_url: str) -> None:
stmt = update(Message).where(Message.id == message_id).values(message=new_message, image_url=new_image_url)
await self.session.execute(stmt)
async def add_answer(self, self_id: int, answer_id: int) -> None:
stmt = (
insert(MessageAnswer)
.values(self_id=self_id, answer_id=answer_id)
)
await self.session.execute(stmt)
async def change_data(self, chat_id: int, chat_name: str, avatar_image: HttpUrl) -> None:
stmt = (
update(Chat)
@ -252,94 +112,26 @@ class ChatDAO(BaseDAO):
result = result.scalar_one()
return SPinnedChats.model_validate(result)
async def pin_message(self, chat_id: int, message_id: int, user_id: int) -> None:
async def pin_message(self, chat_id: int, message_id: UUID, user_id: int) -> None:
try:
stmt = insert(PinnedMessage).values(chat_id=chat_id, message_id=message_id, user_id=user_id)
await self.session.execute(stmt)
except IntegrityError:
raise MessageAlreadyPinnedException
async def unpin_message(self, chat_id: int, message_id: int) -> None:
async def unpin_message(self, chat_id: int, message_id: UUID) -> None:
stmt = delete(PinnedMessage).where(PinnedMessage.chat_id == chat_id, PinnedMessage.message_id == message_id)
await self.session.execute(stmt)
async def get_pinned_messages(self, chat_id: int) -> SPinnedMessages:
msg = aliased(Message, name="msg")
messages_with_users = (
select(
Message.id,
Message.message,
Message.image_url,
Message.chat_id,
Message.user_id,
Message.created_at,
Message.visibility,
Users.username,
Users.avatar_image,
MessageAnswer.self_id,
MessageAnswer.answer_id,
(
select(msg.message)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True) # noqa
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery()
).label("answer_message"),
(
select(msg.image_url)
.select_from(MessageAnswer)
.join(msg, msg.id == MessageAnswer.answer_id, isouter=True)
.where(MessageAnswer.self_id == Message.id, msg.visibility == True) # noqa: E712
.scalar_subquery()
).label("answer_image_url"),
)
.select_from(PinnedMessage)
.join(Message, PinnedMessage.message_id == Message.id)
.join(Users, Message.user_id == Users.id)
.join(MessageAnswer, MessageAnswer.self_id == Message.id, isouter=True)
.where(PinnedMessage.chat_id == chat_id, Message.visibility == True) # noqa: E712
.order_by(Message.created_at.desc())
.cte("messages_with_users")
)
async def get_pinned_messages_ids(self, chat_id: int) -> list[UUID]:
query = (
select(
func.json_build_object(
"pinned_messages",
func.json_agg(
func.json_build_object(
"id", messages_with_users.c.id,
"message", messages_with_users.c.message,
"image_url", messages_with_users.c.image_url,
"chat_id", messages_with_users.c.chat_id,
"user_id", messages_with_users.c.user_id,
"created_at", messages_with_users.c.created_at,
"avatar_image", messages_with_users.c.avatar_image,
"username", messages_with_users.c.username,
"answer_id", messages_with_users.c.answer_id,
"answer_message", messages_with_users.c.answer_message,
"answer_image_url", messages_with_users.c.answer_image_url
)
)
)
).select_from(messages_with_users)
select(PinnedMessage.message_id)
.where(PinnedMessage.chat_id == chat_id)
)
result = await self.session.execute(query)
result = result.scalar_one()
pinned_messages = SPinnedMessages.model_validate(result)
return pinned_messages
result = result.scalars().all()
return result # noqa
async def test_mongo(self):
await db.message.insert_one(
{
"message": "пизда",
"chat_id": 1,
"user_id": 1,
"image_url": "",
"created_at": "2024-06-13 13:49:59.14757+00",
"visibility": True,
}
)

60
app/dao/message.py Normal file
View file

@ -0,0 +1,60 @@
from datetime import datetime, UTC
from uuid import uuid4, UUID
from app.chat.shemas import SMessageRaw, SMessageRawList
class MessageDAO:
def __init__(self, mongo_db):
self.message = mongo_db.message
async def send_message(
self,
user_id: int,
chat_id: int,
message: str | None,
image_url: str | None,
answer_id: UUID | None = None,
answer_message: str | None = None,
answer_image_url: str | None = None,
) -> UUID:
message_id = uuid4()
await self.message.insert_one(
{
"id": str(message_id),
"message": message,
"image_url": image_url,
"chat_id": chat_id,
"user_id": user_id,
"created_at": datetime.now(UTC),
"answer_id": str(answer_id) if answer_id else None,
"answer_message": answer_message,
"answer_image_url": answer_image_url,
"visibility": True,
}
)
return message_id
async def get_message_by_id(self, message_id: UUID) -> SMessageRaw:
message = self.message.find_one({"id": str(message_id)})
return SMessageRaw.model_validate(message)
async def delete_message(self, message_id: UUID) -> None:
await self.message.update_one({"id": str(message_id)}, {"$set": {"visibility": False}})
async def get_some_messages(self, chat_id: int, message_number_from: int, messages_to_get: int) -> SMessageRawList:
cursor = self.message.find({"visibility": True, "chat_id": chat_id})
cursor.sort("created_at").skip(message_number_from)
return SMessageRawList.model_validate({"message_raw_list": await cursor.to_list(length=messages_to_get)})
async def edit_message(self, message_id: UUID, new_message: str, new_image_url: str) -> None:
await self.message.update_one(
{"id": str(message_id)},
{"$set": {"message": new_message, "image_url": new_image_url}}
)
async def get_messages_from_ids(self, messages_ids: list[UUID]) -> SMessageRawList:
cursor = self.message.find({"visibility": True, "id": {"$in": [str(message_id) for message_id in messages_ids]}})
return SMessageRawList.model_validate({"message_raw_list": await cursor.to_list(length=None) or None})

View file

@ -1,4 +1,4 @@
from pydantic import HttpUrl, EmailStr
from pydantic import EmailStr
from sqlalchemy import update, select, insert, func, or_
from sqlalchemy.exc import MultipleResultsFound, IntegrityError, NoResultFound

View file

@ -29,5 +29,5 @@ MONGO_URL = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}"
mongo_client = AsyncIOMotorClient(MONGO_URL)
db = mongo_client.test_db
mongo_db = mongo_client.test_db

View file

@ -6,9 +6,7 @@ from alembic import context
from app.database import DATABASE_URL, Base
from app.models.users import Users # noqa
from app.models.message_answer import MessageAnswer # noqa
from app.models.chat import Chat # noqa
from app.models.message import Message # noqa
from app.models.pinned_chat import PinnedChat # noqa
from app.models.pinned_message import PinnedMessage # noqa
from app.models.user_chat import UserChat # noqa

View file

@ -1,8 +1,8 @@
"""Database Creation
Revision ID: c69369209bab
Revision ID: 53fd6f2b93a4
Revises:
Create Date: 2024-06-13 18:40:03.297322
Create Date: 2024-08-20 12:30:52.435668
"""
from typing import Sequence, Union
@ -12,7 +12,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'c69369209bab'
revision: str = '53fd6f2b93a4'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@ -52,18 +52,6 @@ def upgrade() -> None:
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('message',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('message', sa.String(), nullable=True),
sa.Column('image_url', sa.String(), nullable=True),
sa.Column('chat_id', sa.Integer(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('visibility', sa.Boolean(), server_default='true', nullable=False),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('pinned_chat',
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('chat_id', sa.Integer(), nullable=False),
@ -71,6 +59,14 @@ def upgrade() -> None:
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('user_id', 'chat_id')
)
op.create_table('pinned_message',
sa.Column('chat_id', sa.Integer(), nullable=True),
sa.Column('message_id', sa.Uuid(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('message_id', 'user_id')
)
op.create_table('user_chat',
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('chat_id', sa.Integer(), nullable=False),
@ -78,32 +74,14 @@ def upgrade() -> None:
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('user_id', 'chat_id')
)
op.create_table('message_answer',
sa.Column('self_id', sa.Integer(), nullable=False),
sa.Column('answer_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['answer_id'], ['message.id'], ),
sa.ForeignKeyConstraint(['self_id'], ['message.id'], ),
sa.PrimaryKeyConstraint('self_id')
)
op.create_table('pinned_message',
sa.Column('chat_id', sa.Integer(), nullable=True),
sa.Column('message_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ),
sa.ForeignKeyConstraint(['message_id'], ['message.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('message_id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('pinned_message')
op.drop_table('message_answer')
op.drop_table('user_chat')
op.drop_table('pinned_message')
op.drop_table('pinned_chat')
op.drop_table('message')
op.drop_table('user_avatar')
op.drop_table('chat')
op.drop_table('users')

View file

@ -1,18 +0,0 @@
from datetime import datetime
from sqlalchemy import ForeignKey, func, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class Message(Base):
__tablename__ = "message"
id: Mapped[int] = mapped_column(primary_key=True)
chat_id = mapped_column(ForeignKey("chat.id"))
user_id = mapped_column(ForeignKey("users.id"))
message: Mapped[str | None]
image_url: Mapped[str | None]
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
visibility: Mapped[bool] = mapped_column(server_default="true")

View file

@ -1,11 +0,0 @@
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class MessageAnswer(Base):
__tablename__ = "message_answer"
self_id: Mapped[int] = mapped_column(ForeignKey("message.id"), primary_key=True)
answer_id: Mapped[int] = mapped_column(ForeignKey("message.id"))

View file

@ -1,5 +1,7 @@
from uuid import UUID
from sqlalchemy import ForeignKey
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import mapped_column, Mapped
from app.database import Base
@ -8,5 +10,5 @@ class PinnedMessage(Base):
__tablename__ = "pinned_message"
chat_id = mapped_column(ForeignKey("chat.id"))
message_id = mapped_column(ForeignKey("message.id"), primary_key=True)
user_id = mapped_column(ForeignKey("users.id"))
message_id: Mapped[UUID] = mapped_column(primary_key=True)
user_id = mapped_column(ForeignKey("users.id"), primary_key=True)

View file

@ -126,5 +126,4 @@ class ChatService:
async def get_pinned_chats(uow: UnitOfWork, user_id: int) -> SPinnedChats:
async with uow:
pinned_chats = await uow.chat.get_pinned_chats(user_id=user_id)
await uow.chat.test_mongo()
return pinned_chats

View file

@ -1,53 +1,92 @@
from app.chat.shemas import SMessage, SSendMessage, SPinnedMessages, SMessageList, SMessageRaw
from datetime import datetime, UTC, timedelta
from app.chat.shemas import SMessage, SSendMessage, SPinnedMessages, SMessageList, SMessageRaw, SMessageRawList
from app.services.user_service import UserService
from app.users.schemas import SUser
from app.utils.unit_of_work import UnitOfWork
class MessageService:
@staticmethod
async def add_avatar_image_and_username_to_message(uow: UnitOfWork, message: SMessageRaw) -> SMessage:
users: dict[int, tuple[SUser, datetime]] = {}
@classmethod
async def _get_cached_user(cls, uow: UnitOfWork, user_id: int) -> SUser:
if user_id in cls.users and cls.users[user_id][1] > datetime.now(UTC):
return cls.users[user_id][0]
user = await UserService.find_user(uow=uow, id=user_id)
cls.users[user_id] = user, datetime.now(UTC) + timedelta(minutes=5)
return user
@classmethod
async def add_avatar_image_and_username_to_message(cls, uow: UnitOfWork, message: SMessageRaw) -> SMessage:
user = await cls._get_cached_user(uow=uow, user_id=message.user_id)
message = message.model_dump()
user = await UserService.find_user(uow=uow, id=message["user_id"])
message["id"] = message["_id"]
message["avatar_image"] = user.avatar_image
message["avatar_image"] = str(user.avatar_image)
message["username"] = user.username
return SMessage.model_validate(message)
@staticmethod
@classmethod
async def add_avatar_image_and_username_to_message_list(
cls, uow: UnitOfWork, messages: SMessageRawList
) -> SMessageList:
return SMessageList.model_validate(
{"messages": [
await cls.add_avatar_image_and_username_to_message(uow=uow, message=message)
for message in messages.message_raw_list
] if messages.message_raw_list else None
}
)
@classmethod
async def send_message(
uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage, image_url: str | None = None
cls, uow: UnitOfWork, user_id: int, chat_id: int, message: SSendMessage, image_url: str | None = None
) -> SMessage:
async with uow:
message_id = await uow.chat.send_message(
user_id=user_id, chat_id=chat_id, message=message.message, image_url=image_url
)
if message.answer:
await uow.chat.add_answer(self_id=message_id, answer_id=message.answer)
new_message = await uow.chat.get_message_by_id(message_id=message_id)
await uow.commit()
answer_message = await uow.message.get_message_by_id(message_id=message.answer)
message_id = await uow.message.send_message(
user_id=user_id,
chat_id=chat_id,
message=message.message,
image_url=image_url,
answer_id=answer_message.id,
answer_message=answer_message.message,
answer_image_url=answer_message.answer_image_url,
)
else:
message_id = await uow.message.send_message(
user_id=user_id,
chat_id=chat_id,
message=message.message,
image_url=image_url,
)
raw_message = await uow.message.get_message_by_id(message_id=message_id)
new_message = await cls.add_avatar_image_and_username_to_message(uow=uow, message=raw_message)
return new_message
@staticmethod
async def delete_message(uow: UnitOfWork, message_id: int) -> None:
async with uow:
await uow.chat.delete_message(message_id=message_id)
await uow.commit()
await uow.message.delete_message(message_id=message_id)
@staticmethod
async def edit_message(uow: UnitOfWork, message_id: int, new_message: str, new_image_url: str) -> None:
async with uow:
await uow.chat.edit_message(
await uow.message.edit_message(
message_id=message_id, new_message=new_message, new_image_url=new_image_url
)
await uow.commit()
@staticmethod
async def pin_message(uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage:
@classmethod
async def pin_message(cls, uow: UnitOfWork, chat_id: int, user_id: int, message_id: int) -> SMessage:
async with uow:
await uow.chat.pin_message(chat_id=chat_id, message_id=message_id, user_id=user_id)
pinned_message = await uow.chat.get_message_by_id(message_id=message_id)
await uow.commit()
raw_message = await uow.message.get_message_by_id(message_id=message_id)
pinned_message = await cls.add_avatar_image_and_username_to_message(uow=uow, message=raw_message)
return pinned_message
@staticmethod
@ -56,24 +95,28 @@ class MessageService:
await uow.chat.unpin_message(chat_id=chat_id, message_id=message_id)
await uow.commit()
@staticmethod
async def get_message_by_id(uow: UnitOfWork, message_id: int) -> SMessage:
@classmethod
async def get_message_by_id(cls, uow: UnitOfWork, message_id: int) -> SMessage:
async with uow:
message = await uow.chat.get_message_by_id(message_id=message_id)
raw_message = await uow.message.get_message_by_id(message_id=message_id)
message = await cls.add_avatar_image_and_username_to_message(uow=uow, message=raw_message)
return message
@staticmethod
async def get_pinned_messages(uow: UnitOfWork, chat_id: int) -> SPinnedMessages:
@classmethod
async def get_pinned_messages(cls, uow: UnitOfWork, chat_id: int) -> SPinnedMessages:
async with uow:
pinned_messages = await uow.chat.get_pinned_messages(chat_id=chat_id)
return pinned_messages
pinned_messages_ids = await uow.chat.get_pinned_messages_ids(chat_id=chat_id)
raw_messages = await uow.message.get_messages_from_ids(messages_ids=pinned_messages_ids)
pinned_messages = await cls.add_avatar_image_and_username_to_message_list(uow=uow, messages=raw_messages)
return SPinnedMessages.model_validate({"pinned_messages": pinned_messages.messages})
@staticmethod
@classmethod
async def get_some_messages(
uow: UnitOfWork, chat_id: int, message_number_from: int, messages_to_get: int
cls, uow: UnitOfWork, chat_id: int, message_number_from: int, messages_to_get: int
) -> SMessageList:
async with uow:
message = await uow.chat.get_some_messages(
messages = await uow.message.get_some_messages(
chat_id=chat_id, message_number_from=message_number_from, messages_to_get=messages_to_get
)
return message
messages = await cls.add_avatar_image_and_username_to_message_list(uow=uow, messages=messages)
return messages

View file

@ -112,7 +112,7 @@ class UserService:
avatar_image=str(user_data.avatar_url or user.avatar_image),
hashed_password=hashed_password
)
(await uow.user.add_user_avatar(user_id=user.id, avatar=str(user_data.avatar_url))) if user_data.avatar_url else None
(await uow.user.add_user_avatar(user_id=user.id, avatar=str(user_data.avatar_url))) if user_data.avatar_url else None # noqa
await uow.commit()
async with RedisService() as redis:

View file

@ -1,7 +1,8 @@
from sqlalchemy.exc import SQLAlchemyError
from app.dao.chat import ChatDAO
from app.database import async_session_maker
from app.dao.message import MessageDAO
from app.database import async_session_maker, mongo_db
from app.dao.user import UserDAO
from app.exceptions import BlackPhoenixException
@ -9,12 +10,14 @@ from app.exceptions import BlackPhoenixException
class UnitOfWork:
def __init__(self):
self.session_factory = async_session_maker
self.mongo_db = mongo_db
async def __aenter__(self):
self.session = self.session_factory()
self.user = UserDAO(self.session)
self.chat = ChatDAO(self.session)
self.message = MessageDAO(self.mongo_db)
async def __aexit__(self, exc_type, exc, tb):
try: