create answer && answer fixes

This commit is contained in:
grey-cat-1908 2024-08-15 18:01:01 +00:00
parent ae8437b28c
commit c0f79726ec
7 changed files with 85 additions and 26 deletions

View file

@ -12,3 +12,4 @@ class Base(DeclarativeBase):
from .user import User from .user import User
from .form import Form from .form import Form
from .answer import Answer

14
database/answer.py Normal file
View file

@ -0,0 +1,14 @@
from sqlalchemy import JSON, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from database import Base, Form
class Answer(Base):
__tablename__ = "answers"
id: Mapped[int] = mapped_column(primary_key=True)
form_id: Mapped[int] = mapped_column(ForeignKey(Form.id, ondelete="CASCADE"))
data: Mapped[dict] = mapped_column(JSON)
form: Mapped[Form] = relationship(Form, lazy="joined")

View file

@ -2,7 +2,7 @@ from enum import Enum
from uuid import UUID from uuid import UUID
from typing import TypeAlias from typing import TypeAlias
from pydantic import field_validator from pydantic import field_validator, field_serializer
from models import BaseModel, form from models import BaseModel, form
@ -20,6 +20,10 @@ class BaseValue(BaseModel):
question_id: UUID question_id: UUID
question_type: form.QuestionType question_type: form.QuestionType
@field_serializer("question_id")
def serialize_id(self, id: UUID):
return str(id)
class TextValue(BaseValue): class TextValue(BaseValue):
question_type: form.QuestionType = form.QuestionType.text question_type: form.QuestionType = form.QuestionType.text
@ -27,9 +31,9 @@ class TextValue(BaseValue):
def validate(self, question: form.TextQuestion) -> None: def validate(self, question: form.TextQuestion) -> None:
if question.min_length and len(self.value) < question.min_length: if question.min_length and len(self.value) < question.min_length:
raise ValueError(AnswerError.TOO_SHORT) raise ValueError(AnswerError.TOO_SHORT.value)
if question.max_length and len(self.value) > question.max_length: if question.max_length and len(self.value) > question.max_length:
raise ValueError(AnswerError.TOO_LONG) raise ValueError(AnswerError.TOO_LONG.value)
class SelectorValue(BaseValue): class SelectorValue(BaseValue):
@ -45,15 +49,15 @@ class SelectorValue(BaseValue):
) )
if len(self.values) < min_values: if len(self.values) < min_values:
raise ValueError(AnswerError.TOO_FEW_SELECTED) raise ValueError(AnswerError.TOO_FEW_SELECTED.value)
if len(self.values) > max_values: if len(self.values) > max_values:
raise ValueError(AnswerError.TOO_MANY_SELECTED) raise ValueError(AnswerError.TOO_MANY_SELECTED.value)
Value: TypeAlias = SelectorValue | TextValue Value: TypeAlias = SelectorValue | TextValue
class AnswerData(BaseValue): class AnswerData(BaseModel):
values: list[Value] values: list[Value]
@property @property
@ -68,7 +72,9 @@ class AnswerData(BaseValue):
uuids.add(value.question_id) uuids.add(value.question_id)
if len(v) != len(uuids): if len(v) != len(uuids):
raise ValueError(AnswerError.DUPLICATE_QUESTIONS) raise ValueError(AnswerError.DUPLICATE_QUESTIONS.value)
return v
class Answer(BaseModel): class Answer(BaseModel):
@ -80,14 +86,16 @@ class Answer(BaseModel):
@classmethod @classmethod
def answer_validator(cls, v, info): def answer_validator(cls, v, info):
uuids = v.question_uuids uuids = v.question_uuids
questions = info.data["form"].data questions = info.data["form"].data.questions
for question in questions: for question in questions:
if question.required and question.id not in uuids: if question.required and question.id not in uuids:
raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED) raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED.value)
if question.question_type != uuids[question.id].question_type: if question.question_type != uuids[question.id].question_type:
raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED) raise ValueError(AnswerError.REQUIRED_QUIESTION_NOT_ANSWERED.value)
uuids[question.id].validate(question)
del uuids[question.id] del uuids[question.id]
if len(uuids) > 0: if len(uuids) > 0:
raise ValueError(AnswerError.INCORRECT_IDS) raise ValueError(AnswerError.INCORRECT_IDS.value)
return v return v

View file

@ -46,7 +46,7 @@ class TextQuestion(BaseQuestion):
@classmethod @classmethod
def validate_min_length(cls, v, info): def validate_min_length(cls, v, info):
if v is not None and v < 0: if v is not None and v < 0:
raise ValueError(FormError.MIN_LENGTH_ERR) raise ValueError(FormError.MIN_LENGTH_ERR.value)
return v return v
@field_validator("max_length") @field_validator("max_length")
@ -55,9 +55,9 @@ class TextQuestion(BaseQuestion):
min_length = info.data.get("min_length") min_length = info.data.get("min_length")
if v is not None: if v is not None:
if v <= 0: if v <= 0:
raise ValueError(FormError.MAX_LENGTH_TOO_SMALL) raise ValueError(FormError.MAX_LENGTH_TOO_SMALL.value)
if min_length is not None and v < min_length: if min_length is not None and v < min_length:
raise ValueError(FormError.MAX_LENGTH_LESS_THAN_MIN_LENGTH) raise ValueError(FormError.MAX_LENGTH_LESS_THAN_MIN_LENGTH.value)
return v return v
@ -73,7 +73,7 @@ class SelectorQuestion(BaseQuestion):
options = info.data.get("options") options = info.data.get("options")
options = [] if not options else options options = [] if not options else options
if v is not None and (v < 1 or v > len(options)): if v is not None and (v < 1 or v > len(options)):
raise ValueError(FormError.MIN_VALUES_ERR) raise ValueError(FormError.MIN_VALUES_ERR.value)
return v return v
@field_validator("max_values") @field_validator("max_values")
@ -82,7 +82,7 @@ class SelectorQuestion(BaseQuestion):
min_values = info.data.get("min_values") min_values = info.data.get("min_values")
options = info.data.get("options") options = info.data.get("options")
if v is not None and (v > len(options) or min_values > v): if v is not None and (v > len(options) or min_values > v):
raise ValueError(FormError.MAX_VALUES_ERR) raise ValueError(FormError.MAX_VALUES_ERR.value)
return v return v
@ -102,7 +102,7 @@ class FormData(BaseModel):
uuids.add(question.id) uuids.add(question.id)
if len(v) != len(uuids): if len(v) != len(uuids):
raise ValueError(FormError.SIMMILAR_ID_ERR) raise ValueError(FormError.SIMMILAR_ID_ERR.value)
return v return v

View file

@ -3,9 +3,11 @@ from fastapi import APIRouter
from . import admin from . import admin
from . import user from . import user
from . import form from . import form
from . import answer
router = APIRouter() router = APIRouter()
router.include_router(admin.router) router.include_router(admin.router)
router.include_router(user.router) router.include_router(user.router)
router.include_router(form.router) router.include_router(form.router)
router.include_router(answer.router)

28
routes/answer.py Normal file
View file

@ -0,0 +1,28 @@
from fastapi import APIRouter, HTTPException
from sqlalchemy import select
from pydantic import ValidationError
import database
from models import AnswerData, Answer
router = APIRouter(prefix="/answer")
@router.post("/create")
async def create_answer(form_id: int, answer_data: AnswerData):
async with database.sessions.begin() as session:
answer = database.Answer(
form_id=form_id,
data=answer_data.model_dump(),
)
session.add(answer)
await session.flush()
await session.refresh(answer)
try:
answer_model = Answer.model_validate(answer)
except ValidationError as e:
raise HTTPException(400, e.errors()[0].get("msg"))
return answer_model

View file

@ -2,18 +2,14 @@ from fastapi import APIRouter, HTTPException
from sqlalchemy import select from sqlalchemy import select
import database import database
from models import FormData, Form, BaseModel from models import FormData, Form
from .utils import User from .utils import User
router = APIRouter(prefix="/form") router = APIRouter(prefix="/form")
class CreateForm(BaseModel):
form_id: int
@router.post("/create") @router.post("/create")
async def create_form(user: User, form_data: FormData) -> CreateForm: async def create_form(user: User, form_data: FormData) -> Form:
async with database.sessions.begin() as session: async with database.sessions.begin() as session:
form = database.Form( form = database.Form(
name=form_data.name, owner_id=user.id, data=form_data.model_dump() name=form_data.name, owner_id=user.id, data=form_data.model_dump()
@ -23,7 +19,7 @@ async def create_form(user: User, form_data: FormData) -> CreateForm:
await session.flush() await session.flush()
await session.refresh(form) await session.refresh(form)
return CreateForm(form_id=form.id) return Form.model_validate(form)
@router.delete("/delete") @router.delete("/delete")
@ -34,14 +30,14 @@ async def delete_form(user: User, id: int):
form = db_request.scalar_one_or_none() form = db_request.scalar_one_or_none()
if form is None: if form is None:
raise HTTPException(404, "No form was found") raise HTTPException(404, "Form not found")
if form.owner_id != user.id: if form.owner_id != user.id:
raise HTTPException(403, "Forbidden") raise HTTPException(403, "Forbidden")
await session.delete(form) await session.delete(form)
@router.get("/my") @router.get("/list")
async def user_forms(user: User): async def user_forms(user: User):
async with database.sessions.begin() as session: async with database.sessions.begin() as session:
return { return {
@ -52,3 +48,13 @@ async def user_forms(user: User):
) )
] ]
} }
@router.get("/get")
async def get_form(id: int) -> Form:
async with database.sessions.begin() as session:
stmt = select(database.Form).where(database.Form.id == id)
db_request = await session.execute(stmt)
form = db_request.scalar_one_or_none()
return Form.model_validate(form)