100 lines
3.5 KiB
Python
100 lines
3.5 KiB
Python
|
from datetime import datetime, timedelta, timezone

import jwt
from fastapi import Depends, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
from sqlalchemy.orm import Session

from todo.database.engine import get_db
from todo.config import get_settings
from todo.schemas.common import is_valid_id
from todo.schemas.users import User
import todo.crud.users as userscrud
from todo.utils.exceptions import NotFoundException
|
||
|
|
||
|
|
||
|
class AuthHandler:
    """This class handles operations related to authentication and authorization.

    It covers password hashing/verification (passlib, bcrypt), issuing and
    validating HS256-signed JWTs, and FastAPI dependencies that resolve the
    calling user from an ``Authorization: Bearer <token>`` header.
    """

    # FastAPI scheme that extracts the bearer credentials from the request;
    # used as the ``Security`` default of the dependency methods below.
    security = HTTPBearer()

    # Password hashing context. bcrypt is the only scheme; "auto" is passlib's
    # documented way to deprecate every scheme other than the default.
    crypt_context = CryptContext(schemes=['bcrypt'], deprecated='auto')

    # Signing secret, read once from the settings when the class is defined.
    jwt_secret = get_settings().jwt_secret

    # Algorithm shared by encode_token and decode_token so they cannot diverge.
    jwt_algorithm = 'HS256'

    @staticmethod
    def _require_non_empty_str(value: str) -> None:
        """Raise TypeError if ``value`` is not a str, ValueError if it is empty."""
        if not isinstance(value, str):
            raise TypeError("Expected a string.")
        if not value:
            raise ValueError("Input string cannot be empty.")

    def hash_password(self, input_str: str) -> str:
        """Returns a salted hash of the input string.

        Used to hash plaintext passwords prior to storage.

        Raises:
            TypeError: if ``input_str`` is not a string.
            ValueError: if ``input_str`` is empty.
        """
        self._require_non_empty_str(input_str)
        return self.crypt_context.hash(input_str)

    def verify_password(self, plaintext: str, hash: str) -> bool:
        """Checks whether the hashed plaintext password matches the provided hash.

        Used to compare a provided plaintext password to a stored hash during
        authentication.

        Raises:
            TypeError: if either argument is not a string.
            ValueError: if either argument is empty.
        """
        # ``hash`` shadows the builtin but is kept: callers may pass it by keyword.
        for value in (plaintext, hash):
            self._require_non_empty_str(value)
        return self.crypt_context.verify(plaintext, hash)

    def encode_token(self, user_id: int) -> str:
        """Creates a fresh JWT containing the specified user id.

        The token carries ``iat`` (now, UTC), ``exp`` (now plus
        ``jwt_expiration_time`` seconds from the settings) and ``sub`` (the
        user id as a string).

        Raises:
            ValueError: if ``is_valid_id(user_id)`` is false.
        """
        if not is_valid_id(user_id):
            raise ValueError("Invalid user ID.")

        # Take a single timezone-aware timestamp so iat and exp are consistent.
        issued_at = datetime.now(timezone.utc)
        payload = {
            'exp': issued_at + timedelta(seconds=get_settings().jwt_expiration_time),
            'iat': issued_at,
            # RFC 7519 defines "sub" as a StringOrURI, and PyJWT >= 2.10 rejects
            # non-string subjects when decoding, so the id is stored as text.
            'sub': str(user_id),
        }

        return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)

    def decode_token(self, token: str) -> int:
        """Decodes the input JWT token and returns the user id stored in it.

        Raises:
            HTTPException: 401 if the token has expired, if its signature does
                not match its contents, or if it carries no usable ``sub`` claim.
        """
        try:
            payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
        except jwt.ExpiredSignatureError as err:
            raise HTTPException(401, "JWT signature has expired.") from err
        except jwt.InvalidTokenError as err:
            raise HTTPException(401, "JWT token is invalid.") from err

        try:
            # int() accepts both the string subject written by encode_token and
            # the bare integer subject found in tokens issued by older versions.
            return int(payload['sub'])
        except (KeyError, TypeError, ValueError) as err:
            # A signed token without a usable subject is still an invalid token;
            # report it as 401 instead of letting it surface as a server error.
            raise HTTPException(401, "JWT token is invalid.") from err

    def get_current_user_id(
        self,
        auth: HTTPAuthorizationCredentials = Security(security),
    ) -> int:
        """Returns the user id encoded in the request's bearer token.

        The ``Security(security)`` default lets FastAPI inject the credentials
        when this method is used as a dependency. Without it, FastAPI would treat
        the pydantic credentials model as a request-body parameter.
        """
        return self.decode_token(auth.credentials)

    def get_current_user(
        self,
        auth: HTTPAuthorizationCredentials = Security(security),
        db: Session = Depends(get_db),
    ) -> User:
        """Returns the user identified by the request's bearer token.

        Raises:
            HTTPException: 401 if the token is expired or invalid.
        """
        user_id = self.decode_token(auth.credentials)
        # NOTE(review): behaviour when the user no longer exists depends on
        # userscrud.read_user; NotFoundException is imported in this module but
        # not handled here — confirm the intended response.
        return userscrud.read_user(db, user_id)

    def asset_current_user_is_admin(
        self,
        auth: HTTPAuthorizationCredentials = Security(security),
        db: Session = Depends(get_db),
    ) -> None:
        """Raises 403 unless the user identified by the bearer token is an admin.

        The method name is presumably a typo for "assert"; it is kept for existing
        callers, and ``assert_current_user_is_admin`` is provided as an alias.

        Raises:
            HTTPException: 401 for an expired or invalid token, and 403 when the
                user's ``is_admin`` flag is falsy.
        """
        if not self.get_current_user(auth, db).is_admin:
            raise HTTPException(403, "You are not authorized to perform this action.")

    # Correctly spelled alias for the method above.
    assert_current_user_is_admin = asset_current_user_is_admin