chat_back/app/dependencies.py

73 lines
2.1 KiB
Python

from typing import Annotated
from fastapi import Depends, Header
from fastapi.security import HTTPBearer
from jose import JWTError, jwt, ExpiredSignatureError
from app.config import settings
from app.exceptions import IncorrectTokenFormatException, TokenMissingException, TokenExpiredException
from app.services import user_service
from app.utils.unit_of_work import UnitOfWork
from app.users.schemas import SUser
from app.users.exceptions import UserNotFoundException, UserMustConfirmEmailException
# auto_error=False: HTTPBearer returns None (instead of raising its own
# HTTPException) when the Authorization header is missing or is not a Bearer
# credential, so get_token below can raise the project's TokenMissingException.
auth_schema = HTTPBearer(auto_error=False)


def get_token(token=Depends(auth_schema)) -> str:
    """Return the raw bearer token from the ``Authorization`` header.

    Args:
        token: Credentials parsed by ``auth_schema``; ``None`` when the client
            sent no usable Bearer credentials.

    Returns:
        The token string (``token.credentials``).

    Raises:
        TokenMissingException: no usable Bearer credentials were sent.
    """
    if not token:
        raise TokenMissingException
    return token.credentials
async def get_current_user(token: str = Depends(get_token), uow=Depends(UnitOfWork)) -> SUser:
    """Resolve the user identified by a JWT.

    Decodes ``token`` with ``settings.SECRET_KEY``, accepting only
    ``settings.ALGORITHM``. The integer user id is read from the ``sub`` claim
    and the user is loaded via ``user_service.find_user``.

    Args:
        token: Encoded JWT.
        uow: Unit of work passed through to ``user_service.find_user``.

    Returns:
        Whatever ``user_service.find_user`` returns for that id.

    Raises:
        TokenExpiredException: the token has expired.
        IncorrectTokenFormatException: the token cannot be decoded/verified,
            or its ``sub`` claim is not an integer id.
        UserNotFoundException: the token has no (or an empty) ``sub`` claim.
    """
    try:
        # Explicit single-entry allow-list: only the configured algorithm is accepted.
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except ExpiredSignatureError as err:
        # Must be caught before JWTError, of which it is a subclass.
        raise TokenExpiredException from err
    except JWTError as err:
        raise IncorrectTokenFormatException from err

    raw_user_id = payload.get("sub")
    if not raw_user_id:
        raise UserNotFoundException
    try:
        user_id = int(raw_user_id)
    except (TypeError, ValueError) as err:
        # Treat a non-numeric subject as a malformed token rather than letting
        # the conversion error surface as a server error.
        raise IncorrectTokenFormatException from err

    # NOTE(review): no None-check here — presumably user_service.find_user
    # handles the "no such user" case itself; confirm.
    return await user_service.find_user(uow=uow, id=user_id)
async def get_verificated_user(user: SUser = Depends(get_current_user)) -> SUser:
    """Return the authenticated user if their role reaches ``settings.VERIFICATED_USER``.

    Args:
        user: The user resolved by ``get_current_user``.

    Returns:
        The same ``user`` object, unchanged.

    Raises:
        UserMustConfirmEmailException: ``user.role`` is below
            ``settings.VERIFICATED_USER``.
    """
    below_required_role = user.role < settings.VERIFICATED_USER
    if below_required_role:
        raise UserMustConfirmEmailException
    return user
def get_token_ws(sec_websocket_protocol: Annotated[str | None, Header()] = None) -> str:
    """Extract the JWT from the ``Sec-WebSocket-Protocol`` request header.

    The token is expected to be the last entry of the header's list,
    e.g. ``"<subprotocol>, <jwt>"``. Presumably this is because browser
    WebSocket clients cannot set an ``Authorization`` header — confirm with
    the frontend.

    Args:
        sec_websocket_protocol: Raw header value, or ``None`` if absent.

    Returns:
        The last comma/whitespace-separated entry of the header.

    Raises:
        TokenMissingException: the header is absent or contains no entries.
    """
    if sec_websocket_protocol is None:
        raise TokenMissingException
    # Entries are comma-separated with optional whitespace. Turning commas into
    # spaces handles both "a, b" and "a,b" without leaving a comma glued to an
    # entry.
    entries = sec_websocket_protocol.replace(",", " ").split()
    if not entries:
        # A blank header would otherwise fail with IndexError on entries[-1].
        raise TokenMissingException
    return entries[-1]
async def get_current_user_ws(token: str = Depends(get_token_ws), uow=Depends(UnitOfWork)) -> SUser:
    """Resolve the user for a WebSocket connection.

    The token comes from the ``Sec-WebSocket-Protocol`` header (via
    ``get_token_ws``) instead of the ``Authorization`` header. Decoding and
    lookup are delegated to ``get_current_user`` so HTTP and WebSocket
    authentication share one implementation and cannot drift apart.

    Args:
        token: Encoded JWT extracted by ``get_token_ws``.
        uow: Unit of work passed through to the user lookup.

    Returns:
        The user returned by ``get_current_user``.

    Raises:
        TokenExpiredException: the token has expired.
        IncorrectTokenFormatException: the token cannot be decoded/verified.
        UserNotFoundException: the token has no ``sub`` claim.
    """
    return await get_current_user(token=token, uow=uow)
async def get_subprotocol_ws(sec_websocket_protocol: Annotated[str | None, Header()] = None) -> str:
    """Return the first entry of the ``Sec-WebSocket-Protocol`` header.

    Presumably this is the subprotocol echoed back when the endpoint accepts
    the connection — confirm against callers. The entry is returned without
    surrounding whitespace so it matches the value the client offered.

    Args:
        sec_websocket_protocol: Raw header value, or ``None`` if absent.

    Returns:
        The first comma/whitespace-separated entry of the header.

    Raises:
        TokenMissingException: the header is absent or contains no entries.
    """
    if sec_websocket_protocol is None:
        raise TokenMissingException
    # Same tokenization as the token extractor: commas and/or whitespace
    # separate entries, and no entry keeps stray spaces or commas.
    entries = sec_websocket_protocol.replace(",", " ").split()
    if not entries:
        raise TokenMissingException
    return entries[0]