chat_back/app/users/dependencies.py
2024-06-07 14:54:50 +05:00

80 lines
2.1 KiB
Python

import logging
from typing import Annotated
from fastapi import Depends, Header
from fastapi.security import HTTPBearer
from jose import JWTError, jwt, ExpiredSignatureError
from app.config import settings
from app.exceptions import (
IncorrectTokenFormatException,
TokenAbsentException,
TokenExpiredException,
UserIsNotPresentException,
UserMustConfirmEmailException,
)
from app.services.user_service import UserService
from app.unit_of_work import UnitOfWork
from app.users.schemas import SUser
# ``auto_error=False`` makes HTTPBearer return None (instead of raising its own
# HTTPException) when the Authorization header is missing or is not a Bearer
# header. That lets ``get_token`` raise the app's TokenAbsentException; with the
# default ``auto_error=True`` that branch could never be reached.
# NOTE(review): if other modules depend on ``auth_schema`` directly, they must
# now handle a None result themselves.
auth_schema = HTTPBearer(auto_error=False)
def get_token(token=Depends(auth_schema)) -> str:
    """Return the raw bearer token string from the credentials ``auth_schema`` yields.

    Raises:
        TokenAbsentException: the security scheme produced no credentials.
    """
    if token:
        return token.credentials
    raise TokenAbsentException
async def get_current_user(token: str = Depends(get_token), uow=Depends(UnitOfWork)) -> SUser:
    """Resolve the authenticated user from a bearer JWT.

    The token is verified with ``settings.SECRET_KEY``; only
    ``settings.ALGORITHM`` is accepted. The ``sub`` claim is converted to an
    integer user id and looked up via ``UserService.find_one_or_none``.

    Args:
        token: Raw JWT string, supplied by ``get_token``.
        uow: Unit of work passed through to ``UserService.find_one_or_none``.

    Returns:
        The user whose id matches the token's ``sub`` claim.

    Raises:
        TokenExpiredException: the token's signature has expired.
        IncorrectTokenFormatException: the token cannot be decoded or verified.
        UserIsNotPresentException: ``sub`` is missing, is not an integer id,
            or no user with that id exists.
    """
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except ExpiredSignatureError:
        # Deliberate translation into the app's HTTP-facing exception.
        raise TokenExpiredException from None
    except JWTError:
        raise IncorrectTokenFormatException from None

    subject = payload.get("sub")
    if not subject:
        raise UserIsNotPresentException
    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        # A correctly signed token with a non-numeric ``sub`` would otherwise
        # escape as an unhandled ValueError (HTTP 500).
        raise UserIsNotPresentException from None

    user = await UserService.find_one_or_none(uow=uow, user_id=user_id)
    if not user:
        raise UserIsNotPresentException
    return user
async def check_verificated_user_with_exc(user: SUser = Depends(get_current_user)) -> SUser:
    """Pass the current user through only if their role reaches ``settings.VERIFICATED_USER``.

    Raises:
        UserMustConfirmEmailException: ``user.role`` is below ``settings.VERIFICATED_USER``.
    """
    is_verified = user.role >= settings.VERIFICATED_USER
    if is_verified:
        return user
    raise UserMustConfirmEmailException
def get_token_ws(sec_websocket_protocol: Annotated[str | None, Header()] = None) -> str:
    """Return the JWT a WebSocket client sends in the ``Sec-WebSocket-Protocol`` header.

    Args:
        sec_websocket_protocol: Value of the ``Sec-WebSocket-Protocol`` request
            header, or None when the header is absent.

    Returns:
        The header value, used as the raw token.

    Raises:
        TokenAbsentException: the header is missing or empty.
    """
    # The value is a bearer credential, so it is deliberately never logged.
    # An empty header is treated like a missing one, consistent with get_token.
    if not sec_websocket_protocol:
        raise TokenAbsentException
    return sec_websocket_protocol
async def get_current_user_ws(token: str = Depends(get_token_ws), uow=Depends(UnitOfWork)) -> SUser:
    """WebSocket counterpart of ``get_current_user``.

    Takes the JWT from ``get_token_ws`` and resolves it by delegating to
    ``get_current_user``. This keeps HTTP and WebSocket authentication on a
    single validation path instead of two hand-maintained copies.

    Args:
        token: Raw JWT string, supplied by ``get_token_ws``.
        uow: Unit of work passed through to the user lookup.

    Returns:
        The user identified by the token.

    Raises:
        Whatever ``get_current_user`` raises: TokenExpiredException,
        IncorrectTokenFormatException or UserIsNotPresentException.
    """
    # Calling the dependency function directly with explicit arguments
    # bypasses its Depends() defaults.
    return await get_current_user(token=token, uow=uow)
async def get_subprotocol():
    """Placeholder with no implementation; awaiting it yields None.

    NOTE(review): presumably intended to supply the subprotocol echoed back
    in ``websocket.accept(subprotocol=...)``. Confirm, then implement or remove.
    """
    return None