"""FastAPI dependencies that authenticate callers with JWT tokens.

HTTP endpoints take the token from the ``Authorization: Bearer`` header;
WebSocket endpoints take it from the ``Sec-WebSocket-Protocol`` header.
"""

from typing import Annotated

from fastapi import Depends, Header
from fastapi.security import HTTPBearer
from jose import JWTError, jwt, ExpiredSignatureError

from app.config import settings
from app.exceptions import IncorrectTokenFormatException, TokenMissingException, TokenExpiredException
from app.services import user_service
from app.utils.unit_of_work import UnitOfWork
from app.users.schemas import SUser
from app.users.exceptions import UserNotFoundException, UserMustConfirmEmailException

auth_schema = HTTPBearer()


def get_token(token=Depends(auth_schema)) -> str:
    """Return the raw bearer token extracted by ``auth_schema``.

    Raises:
        TokenMissingException: if no credentials were provided.
    """
    # NOTE(review): per FastAPI docs, HTTPBearer() defaults to auto_error=True,
    # which rejects requests without a bearer token before this dependency
    # runs; this branch is only reachable with HTTPBearer(auto_error=False).
    if not token:
        raise TokenMissingException
    return token.credentials


def _decode_user_id(token: str) -> int:
    """Verify *token* and return the integer user id stored in its ``sub`` claim.

    Raises:
        TokenExpiredException: if the token has expired.
        IncorrectTokenFormatException: if the token fails decoding/verification
            or its ``sub`` claim is not an integer.
        UserNotFoundException: if the token carries no (or an empty) ``sub`` claim.
    """
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=settings.ALGORITHM)
    except ExpiredSignatureError:
        raise TokenExpiredException
    except JWTError:
        raise IncorrectTokenFormatException

    user_id: str | None = payload.get("sub")
    if not user_id:
        raise UserNotFoundException
    try:
        return int(user_id)
    except (TypeError, ValueError):
        # A non-numeric subject means a malformed token, not a server error.
        raise IncorrectTokenFormatException from None


async def get_current_user(token: str = Depends(get_token), uow: UnitOfWork = Depends(UnitOfWork)) -> SUser:
    """Return the user identified by the HTTP bearer token.

    Raises:
        TokenExpiredException, IncorrectTokenFormatException, UserNotFoundException:
            propagated from token decoding (see ``_decode_user_id``).
    """
    user_id = _decode_user_id(token)
    return await user_service.find_user(uow=uow, id=user_id)


async def get_verificated_user(user: SUser = Depends(get_current_user)) -> SUser:
    """Return the current user, requiring at least the verified-user role.

    Raises:
        UserMustConfirmEmailException: if ``user.role`` is below
            ``settings.VERIFICATED_USER``.
    """
    if user.role < settings.VERIFICATED_USER:
        raise UserMustConfirmEmailException
    return user


def get_token_ws(sec_websocket_protocol: Annotated[str | None, Header()] = None) -> str:
    """Extract the JWT from the WebSocket ``Sec-WebSocket-Protocol`` header.

    The header is expected to look like ``"<subprotocol>, <token>"``; the token
    is the last entry.

    Raises:
        TokenMissingException: if the header is absent or has no entries.
    """
    if sec_websocket_protocol is None:
        raise TokenMissingException
    # Entries are comma-separated (optionally with spaces); treat commas and
    # whitespace alike so "proto,token" and "proto token" both work.
    entries = sec_websocket_protocol.replace(",", " ").split()
    if not entries:
        raise TokenMissingException
    return entries[-1]


async def get_current_user_ws(token: str = Depends(get_token_ws), uow: UnitOfWork = Depends(UnitOfWork)) -> SUser:
    """Return the user identified by the WebSocket subprotocol token.

    Raises:
        TokenExpiredException, IncorrectTokenFormatException, UserNotFoundException:
            propagated from token decoding (see ``_decode_user_id``).
    """
    user_id = _decode_user_id(token)
    return await user_service.find_user(uow=uow, id=user_id)


async def get_subprotocol_ws(sec_websocket_protocol: Annotated[str | None, Header()] = None) -> str:
    """Return the first (subprotocol) entry of the ``Sec-WebSocket-Protocol`` header.

    Raises:
        TokenMissingException: if the header is absent.
    """
    if sec_websocket_protocol is None:
        raise TokenMissingException
    # Strip so a value like "chat ,token" yields "chat", not "chat ".
    return sec_websocket_protocol.split(",")[0].strip()