refactor(backend): subfolders and docstrings
This commit is contained in:
parent
f9e970afdb
commit
8b07b03ccd
@ -1,3 +1,9 @@
|
|||||||
|
"""This module provides global application settings.
|
||||||
|
|
||||||
|
All settings are read from environment variables, but defaults are provided below
|
||||||
|
if the respective envvar is unset.
|
||||||
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from urllib.parse import quote_plus as url_encode
|
from urllib.parse import quote_plus as url_encode
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
@ -15,16 +21,17 @@ class Settings(BaseSettings):
|
|||||||
app_name: str = os.getenv("APP_NAME", "MEDWingS")
|
app_name: str = os.getenv("APP_NAME", "MEDWingS")
|
||||||
admin_email: str = os.getenv("ADMIN_EMAIL", "admin@example.com")
|
admin_email: str = os.getenv("ADMIN_EMAIL", "admin@example.com")
|
||||||
|
|
||||||
|
# Debug mode has the following effects:
|
||||||
|
# - logs SQL operations
|
||||||
debug_mode: bool = False
|
debug_mode: bool = False
|
||||||
if os.getenv("DEBUG_MODE", "false").lower() == "true":
|
if os.getenv("DEBUG_MODE", "false").lower() == "true":
|
||||||
debug_mode = True
|
debug_mode = True
|
||||||
|
|
||||||
_pg_hostname = os.getenv("POSTGRES_HOST", "db")
|
pg_hostname = os.getenv("POSTGRES_HOST", "db")
|
||||||
_pg_port = os.getenv("POSTGRES_PORT", "5432")
|
pg_port = os.getenv("POSTGRES_PORT", "5432")
|
||||||
_pg_dbname = os.getenv("POSTGRES_DB", "medwings")
|
pg_dbname = os.getenv("POSTGRES_DB", "medwings")
|
||||||
_pg_user = url_encode(os.getenv("POSTGRES_USER", "medwings"))
|
pg_user = url_encode(os.getenv("POSTGRES_USER", "medwings"))
|
||||||
_pg_password = url_encode(os.getenv("POSTGRES_PASSWORD", "medwings"))
|
pg_password = url_encode(os.getenv("POSTGRES_PASSWORD", "medwings"))
|
||||||
pg_dsn: PostgresDsn = f"postgresql://{_pg_user}:{_pg_password}@{_pg_hostname}:{_pg_port}/{_pg_dbname}"
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
|
114
backend/crud.py
114
backend/crud.py
@ -1,114 +0,0 @@
|
|||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from .import models, schemas
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
|
||||||
# TODO actually hash the password!
|
|
||||||
return password
|
|
||||||
|
|
||||||
|
|
||||||
def _fill_missing_user_fields(db_user: models.User) -> schemas.User:
|
|
||||||
full_user = schemas.User.from_orm(db_user)
|
|
||||||
if db_user.patient:
|
|
||||||
full_user.gender = db_user.patient.gender
|
|
||||||
full_user.date_of_birth = db_user.patient.date_of_birth
|
|
||||||
full_user.is_patient = True
|
|
||||||
full_user.is_admin = False
|
|
||||||
else:
|
|
||||||
full_user.is_patient = False
|
|
||||||
full_user.is_admin = True
|
|
||||||
|
|
||||||
return full_user
|
|
||||||
|
|
||||||
|
|
||||||
def create_user(db: Session, user: schemas.UserCreate):
|
|
||||||
"""Creates a new user as either a patient or an administrator."""
|
|
||||||
|
|
||||||
db_user = models.User(
|
|
||||||
email=user.email,
|
|
||||||
first_name=user.first_name,
|
|
||||||
last_name=user.last_name,
|
|
||||||
password=hash_password(user.password),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add user to database
|
|
||||||
if user.is_patient:
|
|
||||||
db_patient = models.Patient(
|
|
||||||
user=db_user,
|
|
||||||
gender=user.gender,
|
|
||||||
date_of_birth=user.date_of_birth,
|
|
||||||
)
|
|
||||||
db.add(db_patient)
|
|
||||||
else:
|
|
||||||
db_administrator = models.Administrator(
|
|
||||||
user=db_user,
|
|
||||||
)
|
|
||||||
db.add(db_administrator)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
# Construct the updated user to return
|
|
||||||
db.refresh(db_user)
|
|
||||||
return _fill_missing_user_fields(db_user)
|
|
||||||
|
|
||||||
|
|
||||||
def read_user(db: Session, id: int):
|
|
||||||
db_user = db.query(models.User).filter(models.User.id == id).first()
|
|
||||||
if not db_user:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return _fill_missing_user_fields(db_user)
|
|
||||||
|
|
||||||
|
|
||||||
def read_user_by_email(db: Session, email: str):
|
|
||||||
db_user = db.query(models.User).filter(models.User.email == email).first()
|
|
||||||
if not db_user:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return _fill_missing_user_fields(db_user)
|
|
||||||
|
|
||||||
|
|
||||||
def read_users(db: Session, skip: int = 0, limit: int = 100):
|
|
||||||
db_users = db.query(models.User).offset(skip).limit(limit).all()
|
|
||||||
|
|
||||||
full_users = []
|
|
||||||
for db_user in db_users:
|
|
||||||
full_users.append(_fill_missing_user_fields(db_user))
|
|
||||||
return full_users
|
|
||||||
|
|
||||||
|
|
||||||
def update_user(db: Session, user: schemas.UserUpdate, id: int):
|
|
||||||
db_user = db.query(models.User).filter(models.User.id == id).first()
|
|
||||||
current_user = _fill_missing_user_fields(db_user)
|
|
||||||
|
|
||||||
for key in ['gender', 'date_of_birth']:
|
|
||||||
value = getattr(user, key)
|
|
||||||
if value is not None:
|
|
||||||
setattr(db_user.patient, key, value)
|
|
||||||
for key in ['email', 'first_name', 'last_name']:
|
|
||||||
value = getattr(user, key)
|
|
||||||
if value is not None:
|
|
||||||
setattr(db_user, key, value)
|
|
||||||
if user.password is not None:
|
|
||||||
db_user.password = hash_password(user.password)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
db.refresh(db_user)
|
|
||||||
return _fill_missing_user_fields(db_user)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_user(db: Session, id: int):
|
|
||||||
db_user = db.query(models.User).filter(models.User.id == id).first()
|
|
||||||
user_copy = _fill_missing_user_fields(db_user)
|
|
||||||
|
|
||||||
db.delete(db_user)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
user_copy.updated = datetime.now(user_copy.updated.tzinfo)
|
|
||||||
return user_copy
|
|
0
backend/crud/__init__.py
Normal file
0
backend/crud/__init__.py
Normal file
144
backend/crud/users.py
Normal file
144
backend/crud/users.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
"""This module handles CRUD operations for users in the database, based on pydanctic schemas."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from backend.models import users as usermodel
|
||||||
|
from backend.schemas import users as userschema
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
"""This is a placeholder for a secure password hashing algorithm.
|
||||||
|
|
||||||
|
It will convert a plaintext password into a secure, salted hash, for storage
|
||||||
|
in the database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO actually hash the password!
|
||||||
|
return password
|
||||||
|
|
||||||
|
|
||||||
|
def _fill_missing_user_fields(db_user: usermodel.User) -> userschema.User:
|
||||||
|
"""Fills all the fields of an instance of userschema.User that cannot be filled by pydantic.
|
||||||
|
|
||||||
|
This function is necessary because the userschema is not a one-to-one reflection
|
||||||
|
of the database data model. I did not want the 'patient' and 'administrator'
|
||||||
|
database table to be encoded as their own top level JSON keys in serialized
|
||||||
|
user object. Instead, the user schema combines all fields from all user types.
|
||||||
|
This function fills the optional fields, depending on what type of user is
|
||||||
|
passed in.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
full_user = userschema.User.from_orm(db_user)
|
||||||
|
if db_user.patient:
|
||||||
|
full_user.gender = db_user.patient.gender
|
||||||
|
full_user.date_of_birth = db_user.patient.date_of_birth
|
||||||
|
full_user.is_patient = True
|
||||||
|
full_user.is_admin = False
|
||||||
|
else:
|
||||||
|
full_user.is_patient = False
|
||||||
|
full_user.is_admin = True
|
||||||
|
|
||||||
|
return full_user
|
||||||
|
|
||||||
|
|
||||||
|
def create_user(db: Session, user: userschema.UserCreate) -> userschema.User:
|
||||||
|
"""Creates the specified user in the database."""
|
||||||
|
|
||||||
|
db_user = usermodel.User(
|
||||||
|
email=user.email,
|
||||||
|
first_name=user.first_name,
|
||||||
|
last_name=user.last_name,
|
||||||
|
password=hash_password(user.password),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add user to database
|
||||||
|
if user.is_patient:
|
||||||
|
db_patient = usermodel.Patient(
|
||||||
|
user=db_user,
|
||||||
|
gender=user.gender,
|
||||||
|
date_of_birth=user.date_of_birth,
|
||||||
|
)
|
||||||
|
db.add(db_patient)
|
||||||
|
else:
|
||||||
|
db_administrator = usermodel.Administrator(
|
||||||
|
user=db_user,
|
||||||
|
)
|
||||||
|
db.add(db_administrator)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# Construct the updated user to return
|
||||||
|
db.refresh(db_user)
|
||||||
|
return _fill_missing_user_fields(db_user)
|
||||||
|
|
||||||
|
|
||||||
|
def read_user(db: Session, id: int) -> userschema.User | None:
|
||||||
|
"""Queries the db for a user with the specified id and returns them if they exist."""
|
||||||
|
|
||||||
|
db_user = db.query(usermodel.User).filter(usermodel.User.id == id).first()
|
||||||
|
if not db_user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _fill_missing_user_fields(db_user)
|
||||||
|
|
||||||
|
|
||||||
|
def read_user_by_email(db: Session, email: str) -> userschema.User | None:
|
||||||
|
"""Queries the db for a user with the specified email and returns them if they exist."""
|
||||||
|
|
||||||
|
db_user = db.query(usermodel.User).filter(usermodel.User.email == email).first()
|
||||||
|
if not db_user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _fill_missing_user_fields(db_user)
|
||||||
|
|
||||||
|
|
||||||
|
def read_users(db: Session, skip: int = 0, limit: int = 100) -> list[userschema.User]:
|
||||||
|
"""Returns an unfiltered range (by id) of users in the database."""
|
||||||
|
|
||||||
|
db_users = db.query(usermodel.User).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
full_users = []
|
||||||
|
for db_user in db_users:
|
||||||
|
full_users.append(_fill_missing_user_fields(db_user))
|
||||||
|
return full_users
|
||||||
|
|
||||||
|
|
||||||
|
def update_user(db: Session, user: userschema.UserUpdate, id: int) -> userschema.User:
|
||||||
|
"""Updates the user with the provided id with all non-None fields from the input user."""
|
||||||
|
|
||||||
|
db_user = db.query(usermodel.User).filter(usermodel.User.id == id).first()
|
||||||
|
if not db_user:
|
||||||
|
raise RuntimeError("Query returned no user.") # should be checked by caller
|
||||||
|
|
||||||
|
for key in ['gender', 'date_of_birth']:
|
||||||
|
value = getattr(user, key)
|
||||||
|
if value is not None:
|
||||||
|
setattr(db_user.patient, key, value)
|
||||||
|
for key in ['email', 'first_name', 'last_name']:
|
||||||
|
value = getattr(user, key)
|
||||||
|
if value is not None:
|
||||||
|
setattr(db_user, key, value)
|
||||||
|
if user.password is not None:
|
||||||
|
db_user.password = hash_password(user.password)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_user)
|
||||||
|
return _fill_missing_user_fields(db_user)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_user(db: Session, id: int) -> userschema.User:
|
||||||
|
"""Deletes the user with the provided id from the db."""
|
||||||
|
|
||||||
|
db_user = db.query(usermodel.User).filter(usermodel.User.id == id).first()
|
||||||
|
if not db_user:
|
||||||
|
raise RuntimeError("Query returned no user.") # should be checked by caller
|
||||||
|
user_copy = _fill_missing_user_fields(db_user)
|
||||||
|
|
||||||
|
db.delete(db_user)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
user_copy.updated = datetime.now(user_copy.updated.tzinfo)
|
||||||
|
return user_copy
|
@ -1,13 +0,0 @@
|
|||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
|
|
||||||
from .config import get_settings
|
|
||||||
|
|
||||||
engine = create_engine(
|
|
||||||
get_settings().pg_dsn, # Get connection string from global settings
|
|
||||||
echo=get_settings().debug_mode # Get debugmode status from global settings
|
|
||||||
)
|
|
||||||
SessionLocal = sessionmaker(engine)
|
|
||||||
|
|
||||||
Base = declarative_base()
|
|
0
backend/database/__init__.py
Normal file
0
backend/database/__init__.py
Normal file
19
backend/database/engine.py
Normal file
19
backend/database/engine.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
"""This module configures and provides the sqlalchemy session factory and base model."""
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
from backend.config import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
s = get_settings()
|
||||||
|
|
||||||
|
# The SQL driver is specified by the DSN-prefix below.
|
||||||
|
_pg_dsn = f"postgresql+psycopg2://{s.pg_user}:{s.pg_password}@{s.pg_hostname}:{s.pg_port}/{s.pg_dbname}"
|
||||||
|
engine = create_engine(_pg_dsn, echo=s.debug_mode)
|
||||||
|
|
||||||
|
# SQLalchemy session factory
|
||||||
|
SessionLocal = sessionmaker(engine)
|
||||||
|
# SQLalchemy base model
|
||||||
|
Base = declarative_base()
|
@ -1,16 +1,17 @@
|
|||||||
import logging
|
"""Main entry point for the MEDWingS backend.
|
||||||
|
|
||||||
|
This module defines the API routes provided by the backend.
|
||||||
|
"""
|
||||||
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException
|
from fastapi import Depends, FastAPI, HTTPException
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from . import crud, models, schemas
|
import backend.models.users as usermodel
|
||||||
from.database import engine, SessionLocal
|
import backend.schemas.users as userschema
|
||||||
|
import backend.crud.users as usercrud
|
||||||
|
from backend.database.engine import SessionLocal
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
models.Base.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
@ -24,43 +25,44 @@ def get_db():
|
|||||||
|
|
||||||
@app.get("/hello/")
|
@app.get("/hello/")
|
||||||
def hello():
|
def hello():
|
||||||
|
"""Placeholder for a proper healthcheck endpoint."""
|
||||||
|
|
||||||
return "Hello World!"
|
return "Hello World!"
|
||||||
|
|
||||||
|
|
||||||
@app.post("/users/", response_model=schemas.User)
|
@app.post("/users/", response_model=userschema.User)
|
||||||
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
|
def create_user(user: userschema.UserCreate, db: Session = Depends(get_db)):
|
||||||
existing_user = crud.read_user_by_email(db, email=user.email)
|
existing_user = usercrud.read_user_by_email(db, email=user.email)
|
||||||
if existing_user:
|
if existing_user:
|
||||||
raise HTTPException(status_code=400, detail="A user with this email address is already registered.")
|
raise HTTPException(status_code=400, detail="A user with this email address is already registered.")
|
||||||
|
return usercrud.create_user(db=db, user=user)
|
||||||
return crud.create_user(db=db, user=user)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/users/{id}", response_model=schemas.User)
|
@app.get("/users/{id}", response_model=userschema.User)
|
||||||
def read_user(id: int, db: Session = Depends(get_db)):
|
def read_user(id: int, db: Session = Depends(get_db)):
|
||||||
user = crud.read_user(db=db, id=id)
|
user = usercrud.read_user(db=db, id=id)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=404, detail=f"No user with id '{id}' found.")
|
raise HTTPException(status_code=404, detail=f"No user with id '{id}' found.")
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@app.get("/users/", response_model=list[schemas.User])
|
@app.get("/users/", response_model=list[userschema.User])
|
||||||
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||||
users = crud.read_users(db=db, skip=skip, limit=limit)
|
users = usercrud.read_users(db=db, skip=skip, limit=limit)
|
||||||
return users
|
return users
|
||||||
|
|
||||||
|
|
||||||
@app.patch("/users/{id}", response_model=schemas.User)
|
@app.patch("/users/{id}", response_model=userschema.User)
|
||||||
def update_user(id: int, user: schemas.UserUpdate, db: Session = Depends(get_db)):
|
def update_user(id: int, user: userschema.UserUpdate, db: Session = Depends(get_db)):
|
||||||
current_user = crud.read_user(db=db, id=id)
|
current_user = usercrud.read_user(db=db, id=id)
|
||||||
if not current_user:
|
if not current_user:
|
||||||
raise HTTPException(status_code=404, detail=f"No user with id '{id}' found.")
|
raise HTTPException(status_code=404, detail=f"No user with id '{id}' found.")
|
||||||
return crud.update_user(db=db, user=user, id=id)
|
return usercrud.update_user(db=db, user=user, id=id)
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/users/{id}", response_model=schemas.User)
|
@app.delete("/users/{id}", response_model=userschema.User)
|
||||||
def delete_user(id: int, db: Session = Depends(get_db)):
|
def delete_user(id: int, db: Session = Depends(get_db)):
|
||||||
user = crud.read_user(db=db, id=id)
|
user = usercrud.read_user(db=db, id=id)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=404, detail=f"No user with id '{id}' found.")
|
raise HTTPException(status_code=404, detail=f"No user with id '{id}' found.")
|
||||||
return crud.delete_user(db=db, id=id)
|
return usercrud.delete_user(db=db, id=id)
|
||||||
|
0
backend/models/__init__.py
Normal file
0
backend/models/__init__.py
Normal file
@ -1,13 +1,20 @@
|
|||||||
|
"""This module defines the SQL user model for users.
|
||||||
|
|
||||||
|
All users are either Patients or Administrators.
|
||||||
|
"""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
|
||||||
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Date, Enum, CheckConstraint
|
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Date, Enum, CheckConstraint
|
||||||
from sqlalchemy.sql.functions import now
|
from sqlalchemy.sql.functions import now
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from .database import Base
|
from backend.database.engine import Base
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
|
"""Model for the users table. Contains user info common to all user classes."""
|
||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
||||||
@ -20,11 +27,11 @@ class User(Base):
|
|||||||
|
|
||||||
administrator = relationship("Administrator", back_populates="user", uselist=False, cascade="all, delete")
|
administrator = relationship("Administrator", back_populates="user", uselist=False, cascade="all, delete")
|
||||||
patient = relationship("Patient", back_populates="user", uselist=False, cascade="all, delete")
|
patient = relationship("Patient", back_populates="user", uselist=False, cascade="all, delete")
|
||||||
#patient = Column(Integer, ForeignKey('patients.id'), nullable=True)
|
|
||||||
#CheckConstraint("(administrator=NULL AND patient!=NULL) OR (administrator!=NULL AND patient=NULL)")
|
|
||||||
|
|
||||||
|
|
||||||
class Administrator(Base):
|
class Administrator(Base):
|
||||||
|
"""Model for the administrators table. Contains user info specific to administrators."""
|
||||||
|
|
||||||
__tablename__ = "administrators"
|
__tablename__ = "administrators"
|
||||||
|
|
||||||
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), primary_key=True,)
|
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), primary_key=True,)
|
||||||
@ -32,11 +39,15 @@ class Administrator(Base):
|
|||||||
|
|
||||||
|
|
||||||
class Gender(enum.Enum):
|
class Gender(enum.Enum):
|
||||||
|
"""Gender (as assigned at birth) of a patient."""
|
||||||
|
|
||||||
male = 'm'
|
male = 'm'
|
||||||
female = 'f'
|
female = 'f'
|
||||||
|
|
||||||
|
|
||||||
class Patient(Base):
|
class Patient(Base):
|
||||||
|
"""Model for the patients table. Contains user info specific to patients."""
|
||||||
|
|
||||||
__tablename__ = "patients"
|
__tablename__ = "patients"
|
||||||
|
|
||||||
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), primary_key=True)
|
user_id = Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), primary_key=True)
|
0
backend/schemas/__init__.py
Normal file
0
backend/schemas/__init__.py
Normal file
@ -1,13 +1,25 @@
|
|||||||
|
"""This module declared the pydantic schema representation for users.
|
||||||
|
|
||||||
|
Note that it is not a direct representation of how users are modeled in the
|
||||||
|
database. Instead, the User schema class contains all attributes from all user classes
|
||||||
|
as optional attributes.
|
||||||
|
|
||||||
|
I haven't figured out a smart way to do this with pydantic yet, so behold the
|
||||||
|
inheritance hellhole below.
|
||||||
|
"""
|
||||||
|
|
||||||
from datetime import datetime, date
|
from datetime import datetime, date
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, validator
|
||||||
|
|
||||||
from .models import Gender
|
from backend.models.users import Gender
|
||||||
|
|
||||||
|
|
||||||
class AbstractUserInfoValidation(BaseModel, ABC):
|
class AbstractUserInfoValidation(BaseModel, ABC):
|
||||||
|
"""Base class providing common field validators."""
|
||||||
|
|
||||||
@validator('email', check_fields=False)
|
@validator('email', check_fields=False)
|
||||||
def assert_email_is_valid(cls, email):
|
def assert_email_is_valid(cls, email):
|
||||||
if email is not None:
|
if email is not None:
|
||||||
@ -37,7 +49,14 @@ class AbstractUserInfoValidation(BaseModel, ABC):
|
|||||||
raise ValueError("Date of birth cannot be in the future.")
|
raise ValueError("Date of birth cannot be in the future.")
|
||||||
return dob
|
return dob
|
||||||
|
|
||||||
|
|
||||||
class AbstractUser(AbstractUserInfoValidation, ABC):
|
class AbstractUser(AbstractUserInfoValidation, ABC):
|
||||||
|
"""Base class for attributes common to user creation and user representation.
|
||||||
|
|
||||||
|
A user must be either a patient or an administrator. If a user is a patient,
|
||||||
|
they must specify valid 'date_of_birth' and 'gender' attributes.
|
||||||
|
"""
|
||||||
|
|
||||||
email: str
|
email: str
|
||||||
first_name: str
|
first_name: str
|
||||||
last_name: str
|
last_name: str
|
||||||
@ -50,6 +69,8 @@ class AbstractUser(AbstractUserInfoValidation, ABC):
|
|||||||
|
|
||||||
@validator('is_admin')
|
@validator('is_admin')
|
||||||
def assert_tegridy(cls, is_admin, values):
|
def assert_tegridy(cls, is_admin, values):
|
||||||
|
"""Ensures logical model integrity when optional fields are set."""
|
||||||
|
|
||||||
if values['is_patient']:
|
if values['is_patient']:
|
||||||
if is_admin:
|
if is_admin:
|
||||||
raise ValueError('User cannot be both patient and admin.')
|
raise ValueError('User cannot be both patient and admin.')
|
||||||
@ -62,6 +83,8 @@ class AbstractUser(AbstractUserInfoValidation, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class UserCreate(AbstractUser):
|
class UserCreate(AbstractUser):
|
||||||
|
"""Scheme for user creation."""
|
||||||
|
|
||||||
password: str
|
password: str
|
||||||
password_confirmation: str
|
password_confirmation: str
|
||||||
|
|
||||||
@ -76,6 +99,16 @@ class UserCreate(AbstractUser):
|
|||||||
|
|
||||||
|
|
||||||
class UserUpdate(AbstractUserInfoValidation):
|
class UserUpdate(AbstractUserInfoValidation):
|
||||||
|
"""Scheme for user updates.
|
||||||
|
|
||||||
|
All fields here are optional, but passwords must match if at least one was
|
||||||
|
provided.
|
||||||
|
Note that even administrator updates can specify 'gender' and 'date_of_birth'
|
||||||
|
fields, the function inserting the update into the db should handle this (and
|
||||||
|
just ignore the fields).
|
||||||
|
Switching user types is prohibited.
|
||||||
|
"""
|
||||||
|
|
||||||
email: Optional[str]
|
email: Optional[str]
|
||||||
first_name: Optional[str]
|
first_name: Optional[str]
|
||||||
last_name: Optional[str]
|
last_name: Optional[str]
|
||||||
@ -99,6 +132,12 @@ class UserUpdate(AbstractUserInfoValidation):
|
|||||||
|
|
||||||
|
|
||||||
class User(AbstractUser):
|
class User(AbstractUser):
|
||||||
|
"""Final representation of all types of users, wrapped into one User schema.
|
||||||
|
|
||||||
|
The id, created and updated fields are filled by the db during creation, so
|
||||||
|
they are not needed in the parent classes.
|
||||||
|
"""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
created: datetime
|
created: datetime
|
||||||
updated: datetime
|
updated: datetime
|
Loading…
x
Reference in New Issue
Block a user