115 lines
3.1 KiB
Python
115 lines
3.1 KiB
Python
import logging
|
|
from datetime import datetime
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from .import models, schemas
|
|
|
|
# Module-level logger named after this module (records still propagate to
# the root logger's handlers, but can now be filtered/configured per module).
log = logging.getLogger(__name__)
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Return the form of ``password`` that gets persisted.

    NOTE(security): this is currently an identity placeholder, so the value
    written to the database is the plaintext password. It should be replaced
    with a salted key-derivation function (e.g. ``hashlib.scrypt`` or
    ``hashlib.pbkdf2_hmac``) plus a matching verify routine. Changing it
    alone would presumably break any existing stored passwords and any
    login code that compares against them — migrate both together.
    """
    # TODO: actually hash the password!
    stored_form = password
    return stored_form
|
|
|
|
|
|
def _fill_missing_user_fields(db_user: models.User) -> schemas.User:
    """Convert an ORM user into a ``schemas.User`` with role fields filled in.

    ``schemas.User.from_orm`` builds the base object from ``db_user``. This
    helper then derives ``is_patient`` / ``is_admin`` from whether
    ``db_user.patient`` is set. For patients it also copies ``gender`` and
    ``date_of_birth`` from the related patient record.

    Args:
        db_user: The ORM user to convert.

    Returns:
        The populated ``schemas.User`` instance.
    """
    converted = schemas.User.from_orm(db_user)

    # NOTE(review): every user without a patient record is labelled an
    # administrator here; confirm there is no other kind of user.
    if not db_user.patient:
        converted.is_patient = False
        converted.is_admin = True
        return converted

    patient_record = db_user.patient
    converted.gender = patient_record.gender
    converted.date_of_birth = patient_record.date_of_birth
    converted.is_patient = True
    converted.is_admin = False
    return converted
|
|
|
|
|
|
def create_user(db: Session, user: schemas.UserCreate):
    """Create a new user as either a patient or an administrator.

    A ``models.Patient`` record is created when ``user.is_patient`` is
    truthy. It carries ``gender`` and ``date_of_birth``. Otherwise a
    ``models.Administrator`` record is created. The session is committed and
    the user refreshed before it is converted for the caller.

    Args:
        db: Active SQLAlchemy session; this function commits it.
        user: The creation payload.

    Returns:
        The newly created user as a ``schemas.User``.
    """
    new_user = models.User(
        email=user.email,
        first_name=user.first_name,
        last_name=user.last_name,
        password=hash_password(user.password),
    )

    # Only the role record is added explicitly. The user row is presumably
    # pulled into the session through the ``user=`` relationship cascade.
    if user.is_patient:
        role_record = models.Patient(
            user=new_user,
            gender=user.gender,
            date_of_birth=user.date_of_birth,
        )
    else:
        role_record = models.Administrator(user=new_user)
    db.add(role_record)

    db.commit()

    # Reload the user's state from the database before building the response.
    db.refresh(new_user)
    return _fill_missing_user_fields(new_user)
|
|
|
|
|
|
def read_user(db: Session, id: int):
    """Fetch the user whose ``id`` column equals ``id``.

    (The parameter name shadows the ``id`` builtin. It is kept for callers
    that pass it by keyword.)

    Returns:
        The user as a ``schemas.User``, or ``None`` when no row matches.
    """
    match = db.query(models.User).filter(models.User.id == id).first()
    return _fill_missing_user_fields(match) if match else None
|
|
|
|
|
|
def read_user_by_email(db: Session, email: str):
    """Fetch the first user whose ``email`` column equals ``email``.

    Returns:
        The user as a ``schemas.User``, or ``None`` when no row matches.
    """
    by_email = db.query(models.User).filter(models.User.email == email)
    found = by_email.first()
    if not found:
        return None
    return _fill_missing_user_fields(found)
|
|
|
|
|
|
def read_users(db: Session, skip: int = 0, limit: int = 100):
    """Return one page of users.

    Results are ordered by ``models.User.id`` so that ``skip``/``limit``
    pagination is stable. SQL guarantees no row order for OFFSET/LIMIT
    without ORDER BY, so consecutive pages could otherwise overlap or miss
    rows.

    Args:
        db: Active SQLAlchemy session.
        skip: Number of users to skip (SQL OFFSET).
        limit: Maximum number of users to return (SQL LIMIT).

    Returns:
        A list of ``schemas.User``; empty when no users fall in the range.
    """
    db_users = (
        db.query(models.User)
        .order_by(models.User.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
    return [_fill_missing_user_fields(db_user) for db_user in db_users]
|
|
|
|
|
|
def update_user(db: Session, user: schemas.UserUpdate, id: int):
    """Apply a partial update to the user whose ``id`` column equals ``id``.

    Only fields of ``user`` that are not ``None`` are written, so a field
    cannot be cleared through this function.

    - ``gender`` and ``date_of_birth`` are stored on the related patient
      record. They are applied only when that record exists; for a user
      without one they are skipped and a warning is logged.
    - A new ``password`` is passed through ``hash_password`` before it is
      stored.

    Args:
        db: Active SQLAlchemy session; committed when the user exists.
        user: The partial-update payload.
        id: Id of the user to update.

    Returns:
        The updated user as a ``schemas.User``, or ``None`` when no user has
        this ``id`` (the same convention as ``read_user``).
    """
    db_user = db.query(models.User).filter(models.User.id == id).first()
    if not db_user:
        return None

    patient = db_user.patient
    for key in ('gender', 'date_of_birth'):
        value = getattr(user, key)
        if value is None:
            continue
        if not patient:
            # There is no patient row to hold this field (e.g. an
            # administrator); skip it rather than crash on setattr(None, ...).
            log.warning(
                "Ignoring patient-only field %r for user %s: no patient record",
                key,
                id,
            )
            continue
        setattr(patient, key, value)

    for key in ('email', 'first_name', 'last_name'):
        value = getattr(user, key)
        if value is not None:
            setattr(db_user, key, value)

    if user.password is not None:
        db_user.password = hash_password(user.password)

    db.commit()
    db.refresh(db_user)
    return _fill_missing_user_fields(db_user)
|
|
|
|
|
|
def delete_user(db: Session, id: int):
    """Delete the user whose ``id`` column equals ``id`` and return a snapshot.

    The ``schemas.User`` snapshot is built before deleting, because the ORM
    object's attributes may no longer be loadable once the row is deleted
    and the session committed. The snapshot's ``updated`` field is set to the
    deletion time. It uses the timezone of the previous ``updated`` value
    when one is present; otherwise it is a naive local time.

    NOTE(review): whether the related patient/administrator row is removed
    too depends on the relationship cascade configured in ``models``;
    confirm.

    Args:
        db: Active SQLAlchemy session; committed when the user exists.
        id: Id of the user to delete.

    Returns:
        The snapshot of the deleted user, or ``None`` when no user has this
        ``id`` (the same convention as ``read_user``).
    """
    db_user = db.query(models.User).filter(models.User.id == id).first()
    if not db_user:
        return None

    user_copy = _fill_missing_user_fields(db_user)

    db.delete(db_user)
    db.commit()

    # Keep the same tz-awareness as the stored timestamp; tolerate a missing one.
    previous = user_copy.updated
    tzinfo = previous.tzinfo if previous is not None else None
    user_copy.updated = datetime.now(tzinfo)
    return user_copy
|