add type safety / validation to form data payloads and allow type to be passed through api

This commit is contained in:
Mary Hipp 2024-08-13 13:00:31 -04:00
parent b0760710d5
commit e5f7c2a9b7
3 changed files with 40 additions and 10 deletions

View File

@ -3,21 +3,35 @@ import json
import traceback import traceback
from typing import Optional from typing import Optional
import pydantic
from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
from invokeai.app.services.style_preset_records.style_preset_records_common import ( from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetData, PresetData,
PresetType,
StylePresetChanges, StylePresetChanges,
StylePresetNotFoundError, StylePresetNotFoundError,
StylePresetRecordWithImage, StylePresetRecordWithImage,
StylePresetWithoutId, StylePresetWithoutId,
) )
class StylePresetUpdateFormData(BaseModel):
name: str = Field(description="Preset name")
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
class StylePresetCreateFormData(StylePresetUpdateFormData):
type: PresetType = Field(description="Preset type")
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"]) style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
@ -75,11 +89,16 @@ async def update_style_preset(
except StylePresetImageFileNotFoundException: except StylePresetImageFileNotFoundException:
pass pass
try:
parsed_data = json.loads(data) parsed_data = json.loads(data)
validated_data = StylePresetUpdateFormData(**parsed_data)
name = parsed_data.get("name", "") name = validated_data.name
positive_prompt = parsed_data.get("positive_prompt", "") positive_prompt = validated_data.positive_prompt
negative_prompt = parsed_data.get("negative_prompt", "") negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt) preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
changes = StylePresetChanges(name=name, preset_data=preset_data) changes = StylePresetChanges(name=name, preset_data=preset_data)
@ -120,14 +139,20 @@ async def create_style_preset(
) -> StylePresetRecordWithImage: ) -> StylePresetRecordWithImage:
"""Creates a style preset""" """Creates a style preset"""
try:
parsed_data = json.loads(data) parsed_data = json.loads(data)
validated_data = StylePresetCreateFormData(**parsed_data)
name = parsed_data.get("name", "") name = validated_data.name
positive_prompt = parsed_data.get("positive_prompt", "") type = validated_data.type
negative_prompt = parsed_data.get("negative_prompt", "") positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt) preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data) style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset) new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
if image is not None: if image is not None:

View File

@ -7,6 +7,7 @@ import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form'; import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { Trans, useTranslation } from 'react-i18next'; import { Trans, useTranslation } from 'react-i18next';
import type { PresetType } from 'services/api/endpoints/stylePresets';
import { useCreateStylePresetMutation, useUpdateStylePresetMutation } from 'services/api/endpoints/stylePresets'; import { useCreateStylePresetMutation, useUpdateStylePresetMutation } from 'services/api/endpoints/stylePresets';
import { StylePresetImageField } from './StylePresetImageField'; import { StylePresetImageField } from './StylePresetImageField';
@ -47,6 +48,7 @@ export const StylePresetForm = ({
name: data.name, name: data.name,
positive_prompt: data.positivePrompt, positive_prompt: data.positivePrompt,
negative_prompt: data.negativePrompt, negative_prompt: data.negativePrompt,
type: 'user' as PresetType,
}, },
image: data.image, image: data.image,
}; };

View File

@ -1,10 +1,13 @@
import type { paths } from 'services/api/schema'; import type { paths } from 'services/api/schema';
import type { S } from 'services/api/types';
import { api, buildV1Url, LIST_TAG } from '..'; import { api, buildV1Url, LIST_TAG } from '..';
export type StylePresetRecordWithImage = export type StylePresetRecordWithImage =
paths['/api/v1/style_presets/i/{style_preset_id}']['get']['responses']['200']['content']['application/json']; paths['/api/v1/style_presets/i/{style_preset_id}']['get']['responses']['200']['content']['application/json'];
export type PresetType = S['PresetType'];
/** /**
* Builds an endpoint URL for the style_presets router * Builds an endpoint URL for the style_presets router
* @example * @example
@ -37,7 +40,7 @@ export const stylePresetsApi = api.injectEndpoints({
}), }),
createStylePreset: build.mutation< createStylePreset: build.mutation<
paths['/api/v1/style_presets/']['post']['responses']['200']['content']['application/json'], paths['/api/v1/style_presets/']['post']['responses']['200']['content']['application/json'],
{ data: { name: string; positive_prompt: string; negative_prompt: string }; image: Blob | null } { data: { name: string; positive_prompt: string; negative_prompt: string; type: PresetType }; image: Blob | null }
>({ >({
query: ({ data, image }) => { query: ({ data, image }) => {
const formData = new FormData(); const formData = new FormData();