"""FastAPI dependencies that resolve the current user from a JWT."""

from typing import Annotated

from fastapi import Depends, Header
from fastapi.security import HTTPBearer
from jose import ExpiredSignatureError, JWTError, jwt

from app.config import settings
from app.exceptions import (
    IncorrectTokenFormatException,
    TokenExpiredException,
    TokenMissingException,
)
from app.services import user_service
from app.users.exceptions import UserMustConfirmEmailException, UserNotFoundException
from app.users.schemas import SUser
from app.utils.unit_of_work import UnitOfWork

auth_schema = HTTPBearer()


def _decode_user_id(token: str) -> int:
    """Decode *token* and return the integer user id stored in its ``sub`` claim.

    Raises:
        TokenExpiredException: the JWT library reports an expired signature.
        IncorrectTokenFormatException: the token cannot be decoded, or its
            ``sub`` claim is not convertible to ``int``.
        UserNotFoundException: the ``sub`` claim is missing or empty.
    """
    try:
        payload = jwt.decode(token, settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
    # ExpiredSignatureError must be handled before the more general JWTError.
    except ExpiredSignatureError as err:
        raise TokenExpiredException from err
    except JWTError as err:
        raise IncorrectTokenFormatException from err

    user_id = payload.get("sub")
    if not user_id:
        raise UserNotFoundException

    try:
        return int(user_id)
    except (TypeError, ValueError) as err:
        # A non-numeric subject means the token is malformed; report it as
        # such instead of letting the conversion error escape unhandled.
        raise IncorrectTokenFormatException from err


async def _find_user_by_token(token: str, uow) -> SUser:
    """Look up the user identified by *token* through ``user_service.find_user``."""
    return await user_service.find_user(uow=uow, id=_decode_user_id(token))


def get_token(token=Depends(auth_schema)) -> str:
    """Return the raw bearer credentials extracted by ``auth_schema``.

    Raises:
        TokenMissingException: no credentials object was provided.
    """
    # NOTE(review): with HTTPBearer's default ``auto_error=True`` the scheme
    # presumably rejects the request before this guard runs; it is kept in
    # case ``auto_error`` is ever disabled.
    if not token:
        raise TokenMissingException
    return token.credentials


async def get_current_user(token: str = Depends(get_token), uow=Depends(UnitOfWork)) -> SUser:
    """Resolve the user for an HTTP request from its bearer token.

    Raises:
        TokenExpiredException: the token has expired.
        IncorrectTokenFormatException: the token is malformed.
        UserNotFoundException: the token carries no subject.
    """
    return await _find_user_by_token(token, uow)


async def get_verificated_user(user: SUser = Depends(get_current_user)) -> SUser:
    """Return *user* only if its role is at least ``settings.VERIFICATED_USER``.

    Raises:
        UserMustConfirmEmailException: the user's role is below that level.
    """
    if user.role < settings.VERIFICATED_USER:
        raise UserMustConfirmEmailException
    return user


def get_token_ws(sec_websocket_protocol: Annotated[str | None, Header()] = None) -> str:
    """Extract the token from the ``Sec-WebSocket-Protocol`` header.

    The token is taken to be the last entry of the header value. Entries may
    be separated by commas, whitespace, or both.

    Raises:
        TokenMissingException: the header is absent or contains no entries.
    """
    if sec_websocket_protocol is None:
        raise TokenMissingException

    # Treat commas as separators too, so "proto,token" works like "proto, token".
    entries = sec_websocket_protocol.replace(",", " ").split()
    if not entries:
        # An empty or blank header would otherwise fail with IndexError.
        raise TokenMissingException
    return entries[-1]


async def get_current_user_ws(
    token: str = Depends(get_token_ws), uow=Depends(UnitOfWork)
) -> SUser:
    """Resolve the user for a WebSocket connection from its handshake token.

    Raises:
        TokenExpiredException: the token has expired.
        IncorrectTokenFormatException: the token is malformed.
        UserNotFoundException: the token carries no subject.
    """
    return await _find_user_by_token(token, uow)


async def get_subprotocol_ws(
    sec_websocket_protocol: Annotated[str | None, Header()] = None,
) -> str:
    """Return the first comma-separated entry of ``Sec-WebSocket-Protocol``.

    Raises:
        TokenMissingException: the header is absent.
    """
    if sec_websocket_protocol is None:
        raise TokenMissingException
    return sec_websocket_protocol.split(",")[0]