73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
from typing import Annotated
|
|
|
|
from fastapi import Depends, Header
|
|
from fastapi.security import HTTPBearer
|
|
from jose import JWTError, jwt, ExpiredSignatureError
|
|
|
|
from app.config import settings
|
|
from app.exceptions import IncorrectTokenFormatException, TokenMissingException, TokenExpiredException
|
|
|
|
from app.services import user_service
|
|
from app.utils.unit_of_work import UnitOfWork
|
|
from app.users.schemas import SUser
|
|
from app.users.exceptions import UserNotFoundException, UserMustConfirmEmailException
|
|
|
|
# Bearer-token security scheme for HTTP routes. auto_error=False makes it yield
# None (instead of FastAPI raising its own "Not authenticated" error) when the
# Authorization header is missing or is not a Bearer credential, so get_token
# can raise the project's TokenMissingException as intended.
auth_schema = HTTPBearer(auto_error=False)
|
|
|
|
|
|
def get_token(token=Depends(auth_schema)) -> str:
    """Return the raw bearer token string supplied by ``auth_schema``.

    ``token`` is whatever the ``HTTPBearer`` scheme resolves to; its
    ``credentials`` attribute holds the token text.

    The falsy branch only matters when ``auth_schema`` is configured to
    return ``None`` instead of rejecting the request (``auto_error=False``).

    Raises:
        TokenMissingException: the security scheme produced no credentials.
    """
    if token:
        return token.credentials
    raise TokenMissingException
|
|
|
|
|
|
async def get_current_user(token: str = Depends(get_token), uow=Depends(UnitOfWork)) -> SUser:
    """Resolve the authenticated user from a JWT.

    The token is verified with ``settings.SECRET_KEY``, and only
    ``settings.ALGORITHM`` is accepted. The user id is read from the ``sub``
    claim and looked up via ``user_service.find_user``.

    Args:
        token: Raw JWT string (provided by ``get_token`` by default).
        uow: Unit of work forwarded to the user service.

    Returns:
        The user returned by ``user_service.find_user`` for the token's subject.

    Raises:
        TokenExpiredException: the token has expired.
        IncorrectTokenFormatException: the token cannot be decoded/verified, or
            its ``sub`` claim is not an integer id.
        UserNotFoundException: the token has no (or an empty) ``sub`` claim.
    """
    try:
        # Pass the allowed algorithms as a list: python-jose checks the header's
        # "alg" with ``in``, so a bare string would become a substring test.
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except ExpiredSignatureError:
        # Must precede JWTError, of which it is a subclass.
        raise TokenExpiredException
    except JWTError:
        raise IncorrectTokenFormatException

    user_id = payload.get("sub")
    if not user_id:
        raise UserNotFoundException

    try:
        numeric_id = int(user_id)
    except (TypeError, ValueError):
        # A non-numeric subject previously escaped as an unhandled error (HTTP 500).
        raise IncorrectTokenFormatException

    return await user_service.find_user(uow=uow, id=numeric_id)
|
|
|
|
|
|
async def get_verificated_user(user: SUser = Depends(get_current_user)) -> SUser:
    """Pass the current user through only if their role reaches the verified level.

    The user's ``role`` is compared against ``settings.VERIFICATED_USER``.

    Raises:
        UserMustConfirmEmailException: ``user.role`` is below ``settings.VERIFICATED_USER``.
    """
    below_verified_level = user.role < settings.VERIFICATED_USER
    if below_verified_level:
        raise UserMustConfirmEmailException
    return user
|
|
|
|
|
|
def get_token_ws(sec_websocket_protocol: Annotated[str | None, Header()] = None) -> str:
    """Extract the JWT carried in the ``Sec-WebSocket-Protocol`` header.

    The token is the last entry of the header value. Entries may be separated
    by commas and/or whitespace, so ``"proto, <jwt>"``, ``"proto,<jwt>"`` and
    ``"proto <jwt>"`` all yield ``<jwt>``; a single-entry header yields that entry.

    Raises:
        TokenMissingException: the header is absent or contains no entries.
    """
    if sec_websocket_protocol is None:
        raise TokenMissingException
    # The header is a comma-separated list and the space after each comma is
    # optional; a whitespace-only split would return "proto,<jwt>" as the token
    # when the client omits it.
    entries = sec_websocket_protocol.replace(",", " ").split()
    if not entries:
        # An empty or blank header previously raised IndexError (HTTP 500).
        raise TokenMissingException
    return entries[-1]
|
|
|
|
|
|
async def get_current_user_ws(token: str = Depends(get_token_ws), uow=Depends(UnitOfWork)) -> SUser:
    """WebSocket counterpart of ``get_current_user``.

    Only the token source differs: ``get_token_ws`` reads it from the
    ``Sec-WebSocket-Protocol`` header. Decoding, claim validation and the user
    lookup are delegated to ``get_current_user``, so HTTP and WebSocket
    authentication cannot drift apart and raise the same exceptions.

    Args:
        token: Raw JWT string (provided by ``get_token_ws`` by default).
        uow: Unit of work forwarded to the user lookup.

    Returns:
        The user resolved by ``get_current_user`` for this token.
    """
    return await get_current_user(token=token, uow=uow)
|
|
|
|
|
|
async def get_subprotocol_ws(sec_websocket_protocol: Annotated[str | None, Header()] = None) -> str:
    """Return the first entry of the ``Sec-WebSocket-Protocol`` header.

    Everything before the first comma is returned with surrounding whitespace
    stripped, so ``"proto , <jwt>"`` yields ``"proto"`` rather than ``"proto "``.
    A header with no comma is returned whole (stripped).

    NOTE(review): presumably callers echo this value back as the accepted
    subprotocol, which must match an offered entry exactly — confirm.

    Raises:
        TokenMissingException: the header is absent.
    """
    if sec_websocket_protocol is None:
        raise TokenMissingException
    first_entry, _, _ = sec_websocket_protocol.partition(",")
    return first_entry.strip()
|