diff --git a/app/chat/websocket.py b/app/chat/websocket.py index 762889f..e618e8b 100644 --- a/app/chat/websocket.py +++ b/app/chat/websocket.py @@ -24,8 +24,8 @@ class ConnectionManager: "unpin": self.unpin, } - async def connect(self, chat_id: int, websocket: WebSocket): - await websocket.accept() + async def connect(self, chat_id: int, websocket: WebSocket, subprotocol: str | None = None): + await websocket.accept(subprotocol=subprotocol) if chat_id not in self.active_connections: self.active_connections[chat_id] = [] self.active_connections[chat_id].append(websocket) @@ -136,11 +136,12 @@ async def websocket_endpoint( chat_id: int, websocket: WebSocket, user: SUser = Depends(get_current_user_ws), - uow=Depends(UnitOfWork) + subprotocol: str = Depends(get_subprotocol), + uow=Depends(UnitOfWork), ): await AuthService.check_verificated_user_with_exc(uow=uow, user_id=user.id) await AuthService.validate_user_access_to_chat(uow=uow, user_id=user.id, chat_id=chat_id) - await manager.connect(chat_id, websocket) + await manager.connect(chat_id, websocket, subprotocol) try: while True: data = await websocket.receive_json() diff --git a/app/main.py b/app/main.py index b4f45b3..5e223fc 100644 --- a/app/main.py +++ b/app/main.py @@ -55,5 +55,5 @@ class AddHeaderMiddleware(BaseHTTPMiddleware): return response -app.add_middleware(AddHeaderMiddleware) +# app.add_middleware(AddHeaderMiddleware) diff --git a/app/users/dependencies.py b/app/users/dependencies.py index 9d11feb..6cc1635 100644 --- a/app/users/dependencies.py +++ b/app/users/dependencies.py @@ -54,6 +54,7 @@ async def check_verificated_user_with_exc(user: SUser = Depends(get_current_user def get_token_ws(sec_websocket_protocol: Annotated[str | None, Header()] = None) -> str: if sec_websocket_protocol is None: raise TokenAbsentException + logging.critical(sec_websocket_protocol) return sec_websocket_protocol @@ -74,3 +75,6 @@ async def get_current_user_ws(token: str = Depends(get_token_ws), uow=Depends(Un raise UserIsNotPresentException return user + +async def get_subprotocol(): + pass