From 62ec226a2636d864e7826ee7c92e271e48b8330b Mon Sep 17 00:00:00 2001 From: grey-cat-1908 <61203964+grey-cat-1908@users.noreply.github.com> Date: Mon, 12 Aug 2024 13:57:40 +0000 Subject: [PATCH] form create && get (by user) methods --- database/__init__.py | 1 + database/form.py | 4 ++-- models/__init__.py | 1 + models/form.py | 57 +++++++++++++++++++++++++------------------- routes/__init__.py | 2 ++ routes/form.py | 35 +++++++++++++++++++++++++++ routes/user.py | 2 +- routes/utils.py | 1 + 8 files changed, 75 insertions(+), 28 deletions(-) create mode 100644 routes/form.py diff --git a/database/__init__.py b/database/__init__.py index 38566ee..4733df9 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -11,3 +11,4 @@ class Base(DeclarativeBase): from .user import User +from .form import Form diff --git a/database/form.py b/database/form.py index 81fe25b..07a61bd 100644 --- a/database/form.py +++ b/database/form.py @@ -1,4 +1,4 @@ -from sqlalchemy import JSON +from sqlalchemy import JSON, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship from database import Base, User @@ -9,5 +9,5 @@ class Form(Base): id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] - owner: Mapped[User] = relationship(User, cascade="all, delete") + owner_id: Mapped[int] = mapped_column(ForeignKey(User.id, ondelete="CASCADE")) data: Mapped[dict] = mapped_column(JSON) diff --git a/models/__init__.py b/models/__init__.py index 93ff70c..f2dfbfd 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -8,3 +8,4 @@ class BaseModel(pydantic.BaseModel): from .settings import settings from .user import * +from .form import * diff --git a/models/form.py b/models/form.py index 995903a..ab3d705 100644 --- a/models/form.py +++ b/models/form.py @@ -1,4 +1,4 @@ -from enum import Enum, auto +from enum import IntEnum, auto from typing import TypeAlias from pydantic import Field, field_validator @@ -6,9 +6,9 @@ from pydantic import Field, field_validator from models import BaseModel -class QuestionType(Enum): - text = auto() - choice = auto() +class QuestionType(IntEnum): + text = 1 + selector = 2 class BaseQuestion(BaseModel): @@ -23,57 +23,64 @@ class Option(BaseModel): class TextQuestion(BaseQuestion): - question_type = QuestionType.text + question_type: QuestionType = QuestionType.text min_length: int | None = None max_length: int | None = None - @field_validator('min_length') + @field_validator("min_length") @classmethod - def validate_min_length(cls, v, values): + def validate_min_length(cls, v, info): if v is not None and v < 0: - raise ValueError('min_length must be greater than or equal to 0') + raise ValueError("min_length must be greater than or equal to 0") return v - @field_validator('max_length') + @field_validator("max_length") @classmethod - def validate_max_length(cls, v, values): - min_length = values.get('min_length') + def validate_max_length(cls, v, info): + min_length = info.data.get("min_length") if v is not None: if v <= 0: - raise ValueError('max_length must be greater than 0') + raise ValueError("max_length must be greater than 0") if min_length is not None and v < min_length: - raise ValueError('max_length cannot be less than min_length') + raise ValueError("max_length cannot be less than min_length") return v class SelectorQuestion(BaseQuestion): - question_type = QuestionType.choice + question_type: QuestionType = QuestionType.selector min_values: int = 1 max_values: int | None = None options: list[Option] = [] - @field_validator('min_values') + @field_validator("min_values") @classmethod - def validate_min_values(cls, v, values): - options = values.get('options') + def validate_min_values(cls, v, info): + options = info.data.get("options") if v is not None and (v < 1 or v > len(options)): - raise ValueError('min_values must be greater than or equal to 1') + raise ValueError("min_values must be greater than or equal to 1") return v - @field_validator('max_values') + @field_validator("max_values") @classmethod - def validate_max_values(cls, v, values): - min_values = values.get('min_values') - options = values.get('options') + def validate_max_values(cls, v, info): + min_values = info.data.get("min_values") + options = info.data.get("options") if v is not None and (v > len(options) or min_values > v): - raise ValueError('max_values cannot be less than min_length or greater than the number of options') + raise ValueError( + "max_values cannot be less than min_length or greater than the number of options" + ) return v Question: TypeAlias = SelectorQuestion | TextQuestion +class FormData(BaseModel): + name: str = Field(min_length=1) + description: str | None = Field(None, min_length=1) + questions: list[Question] = [] + + class Form(BaseModel): id: int - name: str - questions: list[Question] = [] + data: FormData diff --git a/routes/__init__.py b/routes/__init__.py index 246f78b..b32dc9a 100644 --- a/routes/__init__.py +++ b/routes/__init__.py @@ -2,8 +2,10 @@ from fastapi import APIRouter from . import admin from . import user +from . import form router = APIRouter() router.include_router(admin.router) router.include_router(user.router) +router.include_router(form.router) diff --git a/routes/form.py b/routes/form.py new file mode 100644 index 0000000..56617cf --- /dev/null +++ b/routes/form.py @@ -0,0 +1,35 @@ +from fastapi import APIRouter +from sqlalchemy import select + +import database +from models import FormData, Form +from .utils import User + +router = APIRouter(prefix="/form") + + +@router.post("/create") +async def create_form(user: User, form_data: FormData): + async with database.sessions.begin() as session: + form = database.Form( + name=form_data.name, owner_id=user.id, data=form_data.model_dump() + ) + + session.add(form) + await session.flush() + await session.refresh(form) + + return {"status": "success", "form_id": form.id} + + +@router.get("/my") +async def user_forms(user: User): + async with database.sessions.begin() as session: + return { + "forms": [ + Form.model_validate(item) + for item in await session.scalars( + select(database.Form).where(database.Form.owner_id == user.id) + ) + ] + } diff --git a/routes/user.py b/routes/user.py index be346a7..5cb2483 100644 --- a/routes/user.py +++ b/routes/user.py @@ -13,7 +13,7 @@ router = APIRouter(prefix="/user") @router.post("/login") -async def login(auth: models.Auth): +async def login(auth: models.Auth) -> models.Token: async with database.sessions.begin() as session: stmt = select(database.User).where( database.User.username == auth.username.strip() diff --git a/routes/utils.py b/routes/utils.py index b147b9f..fca642c 100644 --- a/routes/utils.py +++ b/routes/utils.py @@ -43,6 +43,7 @@ async def verify_user(token: Annotated[str, Header(alias="x-token")]) -> databas except jwt.exceptions.InvalidSignatureError: raise HTTPException(401, "Invalid token") + session.expunge(user) return user