From c0f79726ecb43c1ba344b2a3baec7aa5e1a795e7 Mon Sep 17 00:00:00 2001 From: grey-cat-1908 <61203964+grey-cat-1908@users.noreply.github.com> Date: Thu, 15 Aug 2024 18:01:01 +0000 Subject: [PATCH] create answer && answer fixes --- database/__init__.py | 1 + database/answer.py | 14 ++++++++++++++ models/answer.py | 30 +++++++++++++++++++----------- models/form.py | 12 ++++++------ routes/__init__.py | 2 ++ routes/answer.py | 28 ++++++++++++++++++++++++++++ routes/form.py | 24 +++++++++++++++--------- 7 files changed, 85 insertions(+), 26 deletions(-) create mode 100644 database/answer.py create mode 100644 routes/answer.py diff --git a/database/__init__.py b/database/__init__.py index 4733df9..127972e 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -12,3 +12,4 @@ class Base(DeclarativeBase): from .user import User from .form import Form +from .answer import Answer diff --git a/database/answer.py b/database/answer.py new file mode 100644 index 0000000..95638cf --- /dev/null +++ b/database/answer.py @@ -0,0 +1,14 @@ +from sqlalchemy import JSON, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from database import Base, Form + + +class Answer(Base): + __tablename__ = "answers" + + id: Mapped[int] = mapped_column(primary_key=True) + form_id: Mapped[int] = mapped_column(ForeignKey(Form.id, ondelete="CASCADE")) + data: Mapped[dict] = mapped_column(JSON) + + form: Mapped[Form] = relationship(Form, lazy="joined") diff --git a/models/answer.py b/models/answer.py index 433f9ba..25d5a83 100644 --- a/models/answer.py +++ b/models/answer.py @@ -2,7 +2,7 @@ from enum import Enum from uuid import UUID from typing import TypeAlias -from pydantic import field_validator +from pydantic import field_validator, field_serializer from models import BaseModel, form @@ -20,6 +20,10 @@ class BaseValue(BaseModel): question_id: UUID question_type: form.QuestionType + @field_serializer("question_id") + def serialize_id(self, id: UUID): + return str(id) + class TextValue(BaseValue): question_type: form.QuestionType = form.QuestionType.text @@ -27,9 +31,9 @@ class TextValue(BaseValue): def validate(self, question: form.TextQuestion) -> None: if question.min_length and len(self.value) < question.min_length: - raise ValueError(AnswerError.TOO_SHORT) + raise ValueError(AnswerError.TOO_SHORT.value) if question.max_length and len(self.value) > question.max_length: - raise ValueError(AnswerError.TOO_LONG) + raise ValueError(AnswerError.TOO_LONG.value) class SelectorValue(BaseValue): @@ -45,15 +49,15 @@ class SelectorValue(BaseValue): ) if len(self.values) < min_values: - raise ValueError(AnswerError.TOO_FEW_SELECTED) + raise ValueError(AnswerError.TOO_FEW_SELECTED.value) if len(self.values) > max_values: - raise ValueError(AnswerError.TOO_MANY_SELECTED) + raise ValueError(AnswerError.TOO_MANY_SELECTED.value) Value: TypeAlias = SelectorValue | TextValue -class AnswerData(BaseValue): +class AnswerData(BaseModel): values: list[Value] @property @@ -68,7 +72,9 @@ class AnswerData(BaseValue): uuids.add(value.question_id) if len(v) != len(uuids): - raise ValueError(AnswerError.DUPLICATE_QUESTIONS) + raise ValueError(AnswerError.DUPLICATE_QUESTIONS.value) + + return v class Answer(BaseModel): @@ -80,14 +86,16 @@ class Answer(BaseModel): @classmethod def answer_validator(cls, v, info): uuids = v.question_uuids - questions = info.data["form"].data + questions = info.data["form"].data.questions for question in questions: if question.required and question.id not in uuids: - raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED) + raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED.value) if question.question_type != uuids[question.id].question_type: - raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED) + raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED.value) + uuids[question.id].validate(question) + del uuids[question.id] if len(uuids) > 0: - raise ValueError(AnswerError.INCORRECT_IDS) + raise ValueError(AnswerError.INCORRECT_IDS.value) return v diff --git a/models/form.py b/models/form.py index 86512a6..01da847 100644 --- a/models/form.py +++ b/models/form.py @@ -46,7 +46,7 @@ class TextQuestion(BaseQuestion): @classmethod def validate_min_length(cls, v, info): if v is not None and v < 0: - raise ValueError(FormError.MIN_LENGTH_ERR) + raise ValueError(FormError.MIN_LENGTH_ERR.value) return v @field_validator("max_length") @@ -55,9 +55,9 @@ class TextQuestion(BaseQuestion): min_length = info.data.get("min_length") if v is not None: if v <= 0: - raise ValueError(FormError.MAX_LENGTH_TOO_SMALL) + raise ValueError(FormError.MAX_LENGTH_TOO_SMALL.value) if min_length is not None and v < min_length: - raise ValueError(FormError.MAX_LENGTH_LESS_THAN_MIN_LENGTH) + raise ValueError(FormError.MAX_LENGTH_LESS_THAN_MIN_LENGTH.value) return v @@ -73,7 +73,7 @@ class SelectorQuestion(BaseQuestion): options = info.data.get("options") options = [] if not options else options if v is not None and (v < 1 or v > len(options)): - raise ValueError(FormError.MIN_VALUES_ERR) + raise ValueError(FormError.MIN_VALUES_ERR.value) return v @field_validator("max_values") @@ -82,7 +82,7 @@ class SelectorQuestion(BaseQuestion): min_values = info.data.get("min_values") options = info.data.get("options") if v is not None and (v > len(options) or min_values > v): - raise ValueError(FormError.MAX_VALUES_ERR) + raise ValueError(FormError.MAX_VALUES_ERR.value) return v @@ -102,7 +102,7 @@ class FormData(BaseModel): uuids.add(question.id) if len(v) != len(uuids): - raise ValueError(FormError.SIMMILAR_ID_ERR) + raise ValueError(FormError.SIMMILAR_ID_ERR.value) return v diff --git a/routes/__init__.py b/routes/__init__.py index b32dc9a..4bce9ec 100644 --- a/routes/__init__.py +++ b/routes/__init__.py @@ -3,9 +3,11 @@ from fastapi import APIRouter from . import admin from . import user from . import form +from . import answer router = APIRouter() router.include_router(admin.router) router.include_router(user.router) router.include_router(form.router) +router.include_router(answer.router) diff --git a/routes/answer.py b/routes/answer.py new file mode 100644 index 0000000..fdb87b6 --- /dev/null +++ b/routes/answer.py @@ -0,0 +1,28 @@ +from fastapi import APIRouter, HTTPException +from sqlalchemy import select +from pydantic import ValidationError + +import database +from models import AnswerData, Answer + +router = APIRouter(prefix="/answer") + + +@router.post("/create") +async def create_answer(form_id: int, answer_data: AnswerData): + async with database.sessions.begin() as session: + answer = database.Answer( + form_id=form_id, + data=answer_data.model_dump(), + ) + + session.add(answer) + await session.flush() + await session.refresh(answer) + + try: + answer_model = Answer.model_validate(answer) + except ValidationError as e: + raise HTTPException(400, e.errors()[0].get("msg")) + + return answer_model diff --git a/routes/form.py b/routes/form.py index 758e3ba..f0b5afa 100644 --- a/routes/form.py +++ b/routes/form.py @@ -2,18 +2,14 @@ from fastapi import APIRouter, HTTPException from sqlalchemy import select import database -from models import FormData, Form, BaseModel +from models import FormData, Form from .utils import User router = APIRouter(prefix="/form") -class CreateForm(BaseModel): - form_id: int - - @router.post("/create") -async def create_form(user: User, form_data: FormData) -> CreateForm: +async def create_form(user: User, form_data: FormData) -> Form: async with database.sessions.begin() as session: form = database.Form( name=form_data.name, owner_id=user.id, data=form_data.model_dump() @@ -23,7 +19,7 @@ async def create_form(user: User, form_data: FormData) -> CreateForm: await session.flush() await session.refresh(form) - return CreateForm(form_id=form.id) + return Form.model_validate(form) @router.delete("/delete") @@ -34,14 +30,14 @@ async def delete_form(user: User, id: int): form = db_request.scalar_one_or_none() if form is None: - raise HTTPException(404, "No form was found") + raise HTTPException(404, "Form not found") if form.owner_id != user.id: raise HTTPException(403, "Forbidden") await session.delete(form) -@router.get("/my") +@router.get("/list") async def user_forms(user: User): async with database.sessions.begin() as session: return { @@ -52,3 +48,13 @@ async def user_forms(user: User): ) ] } + + +@router.get("/get") +async def get_form(id: int) -> Form: + async with database.sessions.begin() as session: + stmt = select(database.Form).where(database.Form.id == id) + db_request = await session.execute(stmt) + form = db_request.scalar_one_or_none() + + return Form.model_validate(form)