diff --git a/models/answer.py b/models/answer.py index 3c74fb7..433f9ba 100644 --- a/models/answer.py +++ b/models/answer.py @@ -37,11 +37,7 @@ class SelectorValue(BaseValue): values: set[int] def validate(self, question: form.SelectorQuestion) -> None: - min_values = ( - max(question.min_values, 1) - if question.min_values - else 1 - ) + min_values = max(question.min_values, 1) if question.min_values else 1 max_values = ( min(question.max_values, question.options) if question.max_values @@ -72,9 +68,8 @@ class AnswerData(BaseValue): uuids.add(value.question_id) if len(v) != len(uuids): - raise ValueError( - AnswerError.DUPLICATE_QUESTIONS - ) + raise ValueError(AnswerError.DUPLICATE_QUESTIONS) + class Answer(BaseModel): id: int @@ -83,20 +78,16 @@ class Answer(BaseModel): @field_validator("data") @classmethod - def answer_validator( - cls, - v, - info - ): + def answer_validator(cls, v, info): uuids = v.question_uuids questions = info.data["form"].data for question in questions: if question.required and question.id not in uuids: raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED) if question.question_type != uuids[question.id].question_type: - raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED) + raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED) del uuids[question.id] if len(uuids) > 0: - raise ValueError("Some questions are not known") + raise ValueError(AnswerError.INCORRECT_IDS) return v diff --git a/models/form.py b/models/form.py index 3e78275..86512a6 100644 --- a/models/form.py +++ b/models/form.py @@ -1,12 +1,21 @@ -from enum import IntEnum, auto +from enum import IntEnum, Enum from typing import TypeAlias from uuid import UUID, uuid4 -from pydantic import Field, field_validator +from pydantic import Field, field_validator, field_serializer from models import BaseModel +class FormError(Enum): + MIN_LENGTH_ERR = "min_length must be greater than or equal to 0." + MAX_LENGTH_ERR = "max_length must be greater than 0." + MAX_LENGTH_LESS_THAN_MIN_LENGTH = "max_length cannot be less than min_length." + MIN_VALUES_ERR = "min_values must be greater than or equal to 1." + MAX_VALUES_ERR = "max_values cannot be less than min_length or greater than the number of options." + SIMMILAR_ID_ERR = "All questions must have different id's" + + class QuestionType(IntEnum): text = 1 selector = 2 @@ -19,6 +28,10 @@ class BaseQuestion(BaseModel): description: str | None = Field(None, min_length=1) required: bool = True + @field_serializer("id") + def serialize_id(self, id: UUID): + return str(id) + class Option(BaseModel): label: str @@ -33,7 +46,7 @@ class TextQuestion(BaseQuestion): @classmethod def validate_min_length(cls, v, info): if v is not None and v < 0: - raise ValueError("min_length must be greater than or equal to 0") + raise ValueError(FormError.MIN_LENGTH_ERR) return v @field_validator("max_length") @@ -42,9 +55,9 @@ class TextQuestion(BaseQuestion): min_length = info.data.get("min_length") if v is not None: if v <= 0: - raise ValueError("max_length must be greater than 0") + raise ValueError(FormError.MAX_LENGTH_TOO_SMALL) if min_length is not None and v < min_length: - raise ValueError("max_length cannot be less than min_length") + raise ValueError(FormError.MAX_LENGTH_LESS_THAN_MIN_LENGTH) return v @@ -58,8 +71,9 @@ class SelectorQuestion(BaseQuestion): @classmethod def validate_min_values(cls, v, info): options = info.data.get("options") + options = [] if not options else options if v is not None and (v < 1 or v > len(options)): - raise ValueError("min_values must be greater than or equal to 1") + raise ValueError(FormError.MIN_VALUES_ERR) return v @field_validator("max_values") @@ -68,9 +82,7 @@ class SelectorQuestion(BaseQuestion): min_values = info.data.get("min_values") options = info.data.get("options") if v is not None and (v > len(options) or min_values > v): - raise ValueError( - "max_values cannot be less than min_length or greater than the number of options" - ) + raise ValueError(FormError.MAX_VALUES_ERR) return v @@ -84,21 +96,13 @@ class FormData(BaseModel): @field_validator("questions") @classmethod - def validate_questions( - cls, - v, - info - ): - questions = info.data.get("questions") - + def validate_questions(cls, v, info): uuids = set() - for question in questions: - uuids.add(question.question_id) + for question in v: + uuids.add(question.id) - if len(questions) != len(uuids): - raise ValueError( - "All questions must have different id's" - ) + if len(v) != len(uuids): + raise ValueError(FormError.SIMMILAR_ID_ERR) return v diff --git a/routes/form.py b/routes/form.py index e46d97f..758e3ba 100644 --- a/routes/form.py +++ b/routes/form.py @@ -29,9 +29,7 @@ async def create_form(user: User, form_data: FormData) -> CreateForm: @router.delete("/delete") async def delete_form(user: User, id: int): async with database.sessions.begin() as session: - stmt = select(database.Form).where( - database.Form.id == id - ) + stmt = select(database.Form).where(database.Form.id == id) db_request = await session.execute(stmt) form = db_request.scalar_one_or_none() @@ -39,7 +37,7 @@ async def delete_form(user: User, id: int): raise HTTPException(404, "No form was found") if form.owner_id != user.id: raise HTTPException(403, "Forbidden") - + await session.delete(form)