"""This module handles CRUD operations for users in the database, based on pydantic schemas."""

from datetime import datetime

from sqlalchemy.orm import Session

from backend.models import users as usermodel
from backend.schemas import users as userschema
from backend.exceptions import NotFoundException


def hash_password(password: str) -> str:
    """This is a placeholder for a secure password hashing algorithm.

    It will convert a plaintext password into a secure, salted hash, for
    storage in the database.
    """
    # TODO actually hash the password! Until then the input is returned
    # unchanged, i.e. passwords are stored in plaintext.
    return password


def _fill_missing_user_fields(db_user: usermodel.User) -> userschema.User:
    """Fills all the fields of an instance of userschema.User that cannot be filled by pydantic.

    This function is necessary because the userschema is not a one-to-one
    reflection of the database data model. I did not want the 'patient' and
    'administrator' database tables to be encoded as their own top-level JSON
    keys in a serialized user object. Instead, the user schema combines all
    fields from all user types. This function fills the optional fields,
    depending on what type of user is passed in.

    Args:
        db_user: The ORM user row to convert.

    Returns:
        A userschema.User with the patient-specific fields and the
        is_patient/is_admin flags populated.
    """
    full_user = userschema.User.from_orm(db_user)
    if db_user.patient:
        full_user.devices = db_user.patient.devices
        full_user.gender = db_user.patient.gender
        full_user.date_of_birth = db_user.patient.date_of_birth
        full_user.is_patient = True
        full_user.is_admin = False
    else:
        # Any user without a patient row is treated as an administrator.
        full_user.is_patient = False
        full_user.is_admin = True
    return full_user


def _get_db_user_or_raise(db: Session, id: int) -> usermodel.User:
    """Return the ORM user with the given id, or raise NotFoundException if absent."""
    db_user = db.query(usermodel.User).filter(usermodel.User.id == id).first()
    if not db_user:
        raise NotFoundException(f"User with id '{id}' not found.")
    return db_user


def create_user(db: Session, user: userschema.UserCreate) -> userschema.User:
    """Creates the specified user in the database.

    A patient row is created when ``user.is_patient`` is truthy; otherwise an
    administrator row is created. Both are linked to the new user row.

    Args:
        db: The active database session; it is committed by this function.
        user: The user to create. The password is passed through hash_password.

    Returns:
        The created user as a fully populated userschema.User.
    """
    db_user = usermodel.User(
        email=user.email,
        first_name=user.first_name,
        last_name=user.last_name,
        password=hash_password(user.password),
    )
    if user.is_patient:
        db_patient = usermodel.Patient(
            user=db_user,
            gender=user.gender,
            date_of_birth=user.date_of_birth,
        )
        db.add(db_patient)
    else:
        db_administrator = usermodel.Administrator(
            user=db_user,
        )
        db.add(db_administrator)
    db.commit()
    db.refresh(db_user)
    return _fill_missing_user_fields(db_user)


def read_user(db: Session, id: int) -> userschema.User:
    """Queries the db for a user with the specified id and returns them.

    Raises:
        NotFoundException: If no user with the given id exists.
    """
    return _fill_missing_user_fields(_get_db_user_or_raise(db, id))


def read_user_by_email(db: Session, email: str) -> userschema.User:
    """Queries the db for a user with the specified email and returns them.

    Raises:
        NotFoundException: If no user with the given email exists.
    """
    db_user = db.query(usermodel.User).filter(usermodel.User.email == email).first()
    if not db_user:
        raise NotFoundException(f"User with email '{email}' not found.")
    return _fill_missing_user_fields(db_user)


def read_users(db: Session, skip: int = 0, limit: int = 100) -> list[userschema.User]:
    """Returns an unfiltered range of users in the database.

    Args:
        db: The active database session.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Note:
        The query applies no ORDER BY, so the row order is whatever the
        database returns.
    """
    db_users = db.query(usermodel.User).offset(skip).limit(limit).all()
    return [_fill_missing_user_fields(db_user) for db_user in db_users]


def update_user(db: Session, user: userschema.UserUpdate, id: int) -> userschema.User:
    """Updates the user with the provided id with all non-None fields from the input user.

    The patient-only fields (gender, date_of_birth) are applied only when the
    user has a patient row. For other users they are ignored.

    Raises:
        NotFoundException: If no user with the given id exists.
    """
    db_user = _get_db_user_or_raise(db, id)

    # Administrators have no patient row. Writing patient fields onto
    # ``db_user.patient`` would dereference None and raise AttributeError.
    if db_user.patient is not None:
        for key in ("gender", "date_of_birth"):
            value = getattr(user, key)
            if value is not None:
                setattr(db_user.patient, key, value)

    for key in ("email", "first_name", "last_name"):
        value = getattr(user, key)
        if value is not None:
            setattr(db_user, key, value)

    if user.password is not None:
        db_user.password = hash_password(user.password)

    db.commit()
    db.refresh(db_user)
    return _fill_missing_user_fields(db_user)


def delete_user(db: Session, id: int) -> userschema.User:
    """Deletes the user with the provided id from the db.

    Returns:
        A snapshot of the user taken before deletion, with its ``updated``
        field set to the current time.

    Raises:
        NotFoundException: If no user with the given id exists.
    """
    db_user = _get_db_user_or_raise(db, id)
    # Build the schema object before deleting, while the ORM row (and its
    # relationships) can still be read.
    user_copy = _fill_missing_user_fields(db_user)
    db.delete(db_user)
    db.commit()
    # Reuse the stored timestamp's tzinfo so naive and aware datetimes are not
    # mixed. Guard against a missing timestamp instead of crashing on None.
    tzinfo = user_copy.updated.tzinfo if user_copy.updated is not None else None
    user_copy.updated = datetime.now(tzinfo)
    return user_copy