form create && get (by user) methods

This commit is contained in:
grey-cat-1908 2024-08-12 13:57:40 +00:00
parent 2d84b36000
commit 62ec226a26
8 changed files with 75 additions and 28 deletions

View file

@ -11,3 +11,4 @@ class Base(DeclarativeBase):
from .user import User from .user import User
from .form import Form

View file

@ -1,4 +1,4 @@
from sqlalchemy import JSON from sqlalchemy import JSON, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from database import Base, User from database import Base, User
@ -9,5 +9,5 @@ class Form(Base):
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] name: Mapped[str]
owner: Mapped[User] = relationship(User, cascade="all, delete") owner_id: Mapped[int] = mapped_column(ForeignKey(User.id, ondelete="CASCADE"))
data: Mapped[dict] = mapped_column(JSON) data: Mapped[dict] = mapped_column(JSON)

View file

@ -8,3 +8,4 @@ class BaseModel(pydantic.BaseModel):
from .settings import settings from .settings import settings
from .user import * from .user import *
from .form import *

View file

@ -1,4 +1,4 @@
from enum import Enum, auto from enum import IntEnum, auto
from typing import TypeAlias from typing import TypeAlias
from pydantic import Field, field_validator from pydantic import Field, field_validator
@ -6,9 +6,9 @@ from pydantic import Field, field_validator
from models import BaseModel from models import BaseModel
class QuestionType(Enum): class QuestionType(IntEnum):
text = auto() text = 1
choice = auto() selector = 2
class BaseQuestion(BaseModel): class BaseQuestion(BaseModel):
@ -23,57 +23,64 @@ class Option(BaseModel):
class TextQuestion(BaseQuestion): class TextQuestion(BaseQuestion):
question_type = QuestionType.text question_type: QuestionType = QuestionType.text
min_length: int | None = None min_length: int | None = None
max_length: int | None = None max_length: int | None = None
@field_validator('min_length') @field_validator("min_length")
@classmethod @classmethod
def validate_min_length(cls, v, values): def validate_min_length(cls, v, info):
if v is not None and v < 0: if v is not None and v < 0:
raise ValueError('min_length must be greater than or equal to 0') raise ValueError("min_length must be greater than or equal to 0")
return v return v
@field_validator('max_length') @field_validator("max_length")
@classmethod @classmethod
def validate_max_length(cls, v, values): def validate_max_length(cls, v, info):
min_length = values.get('min_length') min_length = info.data.get("min_length")
if v is not None: if v is not None:
if v <= 0: if v <= 0:
raise ValueError('max_length must be greater than 0') raise ValueError("max_length must be greater than 0")
if min_length is not None and v < min_length: if min_length is not None and v < min_length:
raise ValueError('max_length cannot be less than min_length') raise ValueError("max_length cannot be less than min_length")
return v return v
class SelectorQuestion(BaseQuestion): class SelectorQuestion(BaseQuestion):
question_type = QuestionType.choice question_type: QuestionType = QuestionType.selector
min_values: int = 1 min_values: int = 1
max_values: int | None = None max_values: int | None = None
options: list[Option] = [] options: list[Option] = []
@field_validator('min_values') @field_validator("min_values")
@classmethod @classmethod
def validate_min_values(cls, v, values): def validate_min_values(cls, v, info):
options = values.get('options') options = info.data.get("options")
if v is not None and (v < 1 or v > len(options)): if v is not None and (v < 1 or v > len(options)):
raise ValueError('min_values must be greater than or equal to 1') raise ValueError("min_values must be greater than or equal to 1")
return v return v
@field_validator('max_values') @field_validator("max_values")
@classmethod @classmethod
def validate_max_values(cls, v, values): def validate_max_values(cls, v, info):
min_values = values.get('min_values') min_values = info.data.get("min_values")
options = values.get('options') options = info.data.get("options")
if v is not None and (v > len(options) or min_values > v): if v is not None and (v > len(options) or min_values > v):
raise ValueError('max_values cannot be less than min_length or greater than the number of options') raise ValueError(
"max_values cannot be less than min_length or greater than the number of options"
)
return v return v
Question: TypeAlias = SelectorQuestion | TextQuestion Question: TypeAlias = SelectorQuestion | TextQuestion
class FormData(BaseModel):
name: str = Field(min_length=1)
description: str | None = Field(None, min_length=1)
questions: list[Question] = []
class Form(BaseModel): class Form(BaseModel):
id: int id: int
name: str data: FormData
questions: list[Question] = []

View file

@ -2,8 +2,10 @@ from fastapi import APIRouter
from . import admin from . import admin
from . import user from . import user
from . import form
router = APIRouter() router = APIRouter()
router.include_router(admin.router) router.include_router(admin.router)
router.include_router(user.router) router.include_router(user.router)
router.include_router(form.router)

35
routes/form.py Normal file
View file

@ -0,0 +1,35 @@
from fastapi import APIRouter
from sqlalchemy import select
import database
from models import FormData, Form
from .utils import User
router = APIRouter(prefix="/form")
@router.post("/create")
async def create_form(user: User, form_data: FormData):
async with database.sessions.begin() as session:
form = database.Form(
name=form_data.name, owner_id=user.id, data=form_data.model_dump()
)
session.add(form)
await session.flush()
await session.refresh(form)
return {"status": "success", "form_id": form.id}
@router.get("/my")
async def user_forms(user: User):
async with database.sessions.begin() as session:
return {
"forms": [
Form.model_validate(item)
for item in await session.scalars(
select(database.Form).where(database.Form.owner_id == user.id)
)
]
}

View file

@ -13,7 +13,7 @@ router = APIRouter(prefix="/user")
@router.post("/login") @router.post("/login")
async def login(auth: models.Auth): async def login(auth: models.Auth) -> models.Token:
async with database.sessions.begin() as session: async with database.sessions.begin() as session:
stmt = select(database.User).where( stmt = select(database.User).where(
database.User.username == auth.username.strip() database.User.username == auth.username.strip()

View file

@ -43,6 +43,7 @@ async def verify_user(token: Annotated[str, Header(alias="x-token")]) -> databas
except jwt.exceptions.InvalidSignatureError: except jwt.exceptions.InvalidSignatureError:
raise HTTPException(401, "Invalid token") raise HTTPException(401, "Invalid token")
session.expunge(user)
return user return user