answer model && some form updates

This commit is contained in:
grey-cat-1908 2024-08-14 13:21:23 +00:00
parent 69e21b1290
commit 47c219fd1f
4 changed files with 143 additions and 1 deletions

View file

@ -9,3 +9,4 @@ class BaseModel(pydantic.BaseModel):
from .settings import settings
from .user import *
from .form import *
from .answer import *

102
models/answer.py Normal file
View file

@ -0,0 +1,102 @@
from enum import Enum
from uuid import UUID
from typing import TypeAlias
from pydantic import field_validator
from models import BaseModel, form
class AnswerError(Enum):
TOO_SHORT = "The text value is shorter than the minimum allowed length."
TOO_LONG = "The text value is longer than the maximum allowed length."
TOO_FEW_SELECTED = "The number of selected items is less than the minimum required."
TOO_MANY_SELECTED = "The number of selected items is more than the maximum allowed."
DUPLICATE_QUESTIONS = "Each value must correspond to a different question."
INCORRECT_IDS = "The ids for some questions are incorrect."
class BaseValue(BaseModel):
question_id: UUID
question_type: form.QuestionType
class TextValue(BaseValue):
question_type: form.QuestionType = form.QuestionType.text
value: str
def validate(self, question: form.TextQuestion) -> None:
if question.min_length and len(self.value) < question.min_length:
raise ValueError(AnswerError.TOO_SHORT)
if question.max_length and len(self.value) > question.max_length:
raise ValueError(AnswerError.TOO_LONG)
class SelectorValue(BaseValue):
question_type: form.QuestionType = form.QuestionType.selector
values: set[int]
def validate(self, question: form.SelectorQuestion) -> None:
min_values = (
max(question.min_values, 1)
if question.min_values
else 1
)
max_values = (
min(question.max_values, question.options)
if question.max_values
else len(question.options)
)
if len(self.values) < min_values:
raise ValueError(AnswerError.TOO_FEW_SELECTED)
if len(self.values) > max_values:
raise ValueError(AnswerError.TOO_MANY_SELECTED)
Value: TypeAlias = SelectorValue | TextValue
class AnswerData(BaseValue):
values: list[Value]
@property
def question_uuids(self) -> dict[UUID, Value]:
return {value.question_id: value for value in self.values}
@field_validator("values")
@classmethod
def validate_values(cls, v, info):
uuids = set()
for value in v:
uuids.add(value.question_id)
if len(v) != len(uuids):
raise ValueError(
AnswerError.DUPLICATE_QUESTIONS
)
class Answer(BaseModel):
id: int
form: form.Form
data: AnswerData
@field_validator("data")
@classmethod
def answer_validator(
cls,
v,
info
):
uuids = v.question_uuids
questions = info.data["form"].data
for question in questions:
if question.required and question.id not in uuids:
raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED)
if question.question_type != uuids[question.id].question_type:
raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED)
del uuids[question.id]
if len(uuids) > 0:
raise ValueError("Some questions are not known")
return v

View file

@ -1,5 +1,6 @@
from enum import IntEnum, auto
from typing import TypeAlias
from uuid import UUID, uuid4
from pydantic import Field, field_validator
@ -12,6 +13,7 @@ class QuestionType(IntEnum):
class BaseQuestion(BaseModel):
id: UUID = Field(default_factory=uuid4)
question_type: QuestionType
label: str = Field(min_length=1)
description: str | None = Field(None, min_length=1)
@ -80,6 +82,26 @@ class FormData(BaseModel):
description: str | None = Field(None, min_length=1)
questions: list[Question] = []
@field_validator("questions")
@classmethod
def validate_questions(
cls,
v,
info
):
questions = info.data.get("questions")
uuids = set()
for question in questions:
uuids.add(question.question_id)
if len(questions) != len(uuids):
raise ValueError(
"All questions must have different id's"
)
return v
class Form(BaseModel):
id: int

View file

@ -1,4 +1,4 @@
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from sqlalchemy import select
import database
@ -26,6 +26,23 @@ async def create_form(user: User, form_data: FormData) -> CreateForm:
return CreateForm(form_id=form.id)
@router.delete("/delete")
async def delete_form(user: User, id: int):
async with database.sessions.begin() as session:
stmt = select(database.Form).where(
database.Form.id == id
)
db_request = await session.execute(stmt)
form = db_request.scalar_one_or_none()
if form is None:
raise HTTPException(404, "No form was found")
if form.owner_id != user.id:
raise HTTPException(403, "Forbidden")
await session.delete(form)
@router.get("/my")
async def user_forms(user: User):
async with database.sessions.begin() as session: