"""This module handles CRUD operations for users in the database, based on pydanctic schemas.""" from datetime import datetime from sqlalchemy.orm import Session from backend.models import users as usermodel from backend.schemas import users as userschema def hash_password(password: str) -> str: """This is a placeholder for a secure password hashing algorithm. It will convert a plaintext password into a secure, salted hash, for storage in the database. """ # TODO actually hash the password! return password def _fill_missing_user_fields(db_user: usermodel.User) -> userschema.User: """Fills all the fields of an instance of userschema.User that cannot be filled by pydantic. This function is necessary because the userschema is not a one-to-one reflection of the database data model. I did not want the 'patient' and 'administrator' database table to be encoded as their own top level JSON keys in serialized user object. Instead, the user schema combines all fields from all user types. This function fills the optional fields, depending on what type of user is passed in. """ full_user = userschema.User.from_orm(db_user) if db_user.patient: full_user.gender = db_user.patient.gender full_user.date_of_birth = db_user.patient.date_of_birth full_user.is_patient = True full_user.is_admin = False else: full_user.is_patient = False full_user.is_admin = True return full_user def create_user(db: Session, user: userschema.UserCreate) -> userschema.User: """Creates the specified user in the database.""" db_user = usermodel.User( email=user.email, first_name=user.first_name, last_name=user.last_name, password=hash_password(user.password), ) # Add user to database if user.is_patient: db_patient = usermodel.Patient( user=db_user, gender=user.gender, date_of_birth=user.date_of_birth, ) db.add(db_patient) else: db_administrator = usermodel.Administrator( user=db_user, ) db.add(db_administrator) db.commit() # Construct the updated user to return db.refresh(db_user) return _fill_missing_user_fields(db_user) def read_user(db: Session, id: int) -> userschema.User | None: """Queries the db for a user with the specified id and returns them if they exist.""" db_user = db.query(usermodel.User).filter(usermodel.User.id == id).first() if not db_user: return None return _fill_missing_user_fields(db_user) def read_user_by_email(db: Session, email: str) -> userschema.User | None: """Queries the db for a user with the specified email and returns them if they exist.""" db_user = db.query(usermodel.User).filter(usermodel.User.email == email).first() if not db_user: return None return _fill_missing_user_fields(db_user) def read_users(db: Session, skip: int = 0, limit: int = 100) -> list[userschema.User]: """Returns an unfiltered range (by id) of users in the database.""" db_users = db.query(usermodel.User).offset(skip).limit(limit).all() full_users = [] for db_user in db_users: full_users.append(_fill_missing_user_fields(db_user)) return full_users def update_user(db: Session, user: userschema.UserUpdate, id: int) -> userschema.User: """Updates the user with the provided id with all non-None fields from the input user.""" db_user = db.query(usermodel.User).filter(usermodel.User.id == id).first() if not db_user: raise RuntimeError("Query returned no user.") # should be checked by caller for key in ['gender', 'date_of_birth']: value = getattr(user, key) if value is not None: setattr(db_user.patient, key, value) for key in ['email', 'first_name', 'last_name']: value = getattr(user, key) if value is not None: setattr(db_user, key, value) if user.password is not None: db_user.password = hash_password(user.password) db.commit() db.refresh(db_user) return _fill_missing_user_fields(db_user) def delete_user(db: Session, id: int) -> userschema.User: """Deletes the user with the provided id from the db.""" db_user = db.query(usermodel.User).filter(usermodel.User.id == id).first() if not db_user: raise RuntimeError("Query returned no user.") # should be checked by caller user_copy = _fill_missing_user_fields(db_user) db.delete(db_user) db.commit() user_copy.updated = datetime.now(user_copy.updated.tzinfo) return user_copy