better errors && some fixes

This commit is contained in:
grey-cat-1908 2024-08-14 13:46:48 +00:00
parent 47c219fd1f
commit ae8437b28c
3 changed files with 34 additions and 41 deletions

View file

@ -37,11 +37,7 @@ class SelectorValue(BaseValue):
values: set[int] values: set[int]
def validate(self, question: form.SelectorQuestion) -> None: def validate(self, question: form.SelectorQuestion) -> None:
min_values = ( min_values = max(question.min_values, 1) if question.min_values else 1
max(question.min_values, 1)
if question.min_values
else 1
)
max_values = ( max_values = (
min(question.max_values, question.options) min(question.max_values, question.options)
if question.max_values if question.max_values
@ -72,9 +68,8 @@ class AnswerData(BaseValue):
uuids.add(value.question_id) uuids.add(value.question_id)
if len(v) != len(uuids): if len(v) != len(uuids):
raise ValueError( raise ValueError(AnswerError.DUPLICATE_QUESTIONS)
AnswerError.DUPLICATE_QUESTIONS
)
class Answer(BaseModel): class Answer(BaseModel):
id: int id: int
@ -83,20 +78,16 @@ class Answer(BaseModel):
@field_validator("data") @field_validator("data")
@classmethod @classmethod
def answer_validator( def answer_validator(cls, v, info):
cls,
v,
info
):
uuids = v.question_uuids uuids = v.question_uuids
questions = info.data["form"].data questions = info.data["form"].data
for question in questions: for question in questions:
if question.required and question.id not in uuids: if question.required and question.id not in uuids:
raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED) raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED)
if question.question_type != uuids[question.id].question_type: if question.question_type != uuids[question.id].question_type:
raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED) raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED)
del uuids[question.id] del uuids[question.id]
if len(uuids) > 0: if len(uuids) > 0:
raise ValueError("Some questions are not known") raise ValueError(AnswerError.INCORRECT_IDS)
return v return v

View file

@ -1,12 +1,21 @@
from enum import IntEnum, auto from enum import IntEnum, Enum
from typing import TypeAlias from typing import TypeAlias
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from pydantic import Field, field_validator from pydantic import Field, field_validator, field_serializer
from models import BaseModel from models import BaseModel
class FormError(Enum):
MIN_LENGTH_ERR = "min_length must be greater than or equal to 0."
MAX_LENGTH_ERR = "max_length must be greater than 0."
MAX_LENGTH_LESS_THAN_MIN_LENGTH = "max_length cannot be less than min_length."
MIN_VALUES_ERR = "min_values must be greater than or equal to 1."
MAX_VALUES_ERR = "max_values cannot be less than min_length or greater than the number of options."
SIMMILAR_ID_ERR = "All questions must have different id's"
class QuestionType(IntEnum): class QuestionType(IntEnum):
text = 1 text = 1
selector = 2 selector = 2
@ -19,6 +28,10 @@ class BaseQuestion(BaseModel):
description: str | None = Field(None, min_length=1) description: str | None = Field(None, min_length=1)
required: bool = True required: bool = True
@field_serializer("id")
def serialize_id(self, id: UUID):
return str(id)
class Option(BaseModel): class Option(BaseModel):
label: str label: str
@ -33,7 +46,7 @@ class TextQuestion(BaseQuestion):
@classmethod @classmethod
def validate_min_length(cls, v, info): def validate_min_length(cls, v, info):
if v is not None and v < 0: if v is not None and v < 0:
raise ValueError("min_length must be greater than or equal to 0") raise ValueError(FormError.MIN_LENGTH_ERR)
return v return v
@field_validator("max_length") @field_validator("max_length")
@ -42,9 +55,9 @@ class TextQuestion(BaseQuestion):
min_length = info.data.get("min_length") min_length = info.data.get("min_length")
if v is not None: if v is not None:
if v <= 0: if v <= 0:
raise ValueError("max_length must be greater than 0") raise ValueError(FormError.MAX_LENGTH_TOO_SMALL)
if min_length is not None and v < min_length: if min_length is not None and v < min_length:
raise ValueError("max_length cannot be less than min_length") raise ValueError(FormError.MAX_LENGTH_LESS_THAN_MIN_LENGTH)
return v return v
@ -58,8 +71,9 @@ class SelectorQuestion(BaseQuestion):
@classmethod @classmethod
def validate_min_values(cls, v, info): def validate_min_values(cls, v, info):
options = info.data.get("options") options = info.data.get("options")
options = [] if not options else options
if v is not None and (v < 1 or v > len(options)): if v is not None and (v < 1 or v > len(options)):
raise ValueError("min_values must be greater than or equal to 1") raise ValueError(FormError.MIN_VALUES_ERR)
return v return v
@field_validator("max_values") @field_validator("max_values")
@ -68,9 +82,7 @@ class SelectorQuestion(BaseQuestion):
min_values = info.data.get("min_values") min_values = info.data.get("min_values")
options = info.data.get("options") options = info.data.get("options")
if v is not None and (v > len(options) or min_values > v): if v is not None and (v > len(options) or min_values > v):
raise ValueError( raise ValueError(FormError.MAX_VALUES_ERR)
"max_values cannot be less than min_length or greater than the number of options"
)
return v return v
@ -84,21 +96,13 @@ class FormData(BaseModel):
@field_validator("questions") @field_validator("questions")
@classmethod @classmethod
def validate_questions( def validate_questions(cls, v, info):
cls,
v,
info
):
questions = info.data.get("questions")
uuids = set() uuids = set()
for question in questions: for question in v:
uuids.add(question.question_id) uuids.add(question.id)
if len(questions) != len(uuids): if len(v) != len(uuids):
raise ValueError( raise ValueError(FormError.SIMMILAR_ID_ERR)
"All questions must have different id's"
)
return v return v

View file

@ -29,9 +29,7 @@ async def create_form(user: User, form_data: FormData) -> CreateForm:
@router.delete("/delete") @router.delete("/delete")
async def delete_form(user: User, id: int): async def delete_form(user: User, id: int):
async with database.sessions.begin() as session: async with database.sessions.begin() as session:
stmt = select(database.Form).where( stmt = select(database.Form).where(database.Form.id == id)
database.Form.id == id
)
db_request = await session.execute(stmt) db_request = await session.execute(stmt)
form = db_request.scalar_one_or_none() form = db_request.scalar_one_or_none()
@ -39,7 +37,7 @@ async def delete_form(user: User, id: int):
raise HTTPException(404, "No form was found") raise HTTPException(404, "No form was found")
if form.owner_id != user.id: if form.owner_id != user.id:
raise HTTPException(403, "Forbidden") raise HTTPException(403, "Forbidden")
await session.delete(form) await session.delete(form)