feat(api): use pydantic validation during style preset import

- Enforce name is present and not an empty string
- Provide empty string as default for positive and negative prompt
- Add `positive_prompt` as validation alias for `prompt` field
- Strip whitespace automatically
- Create `TypeAdapter` to validate the whole list in one go
This commit is contained in:
psychedelicious 2024-08-15 14:44:13 +10:00 committed by Mary Hipp Rogers
parent 15415c6d85
commit deb917825e

View File

@ -4,7 +4,7 @@ from enum import Enum
from typing import Any, Generator, Optional from typing import Any, Generator, Optional
from fastapi import UploadFile from fastapi import UploadFile
from pydantic import BaseModel, Field, TypeAdapter from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
@ -59,9 +59,19 @@ class StylePresetRecordWithImage(StylePresetRecordDTO):
class StylePresetImportRow(BaseModel): class StylePresetImportRow(BaseModel):
name: str name: str = Field(min_length=1, description="The name of the preset.")
prompt: str positive_prompt: str = Field(
negative_prompt: str default="",
description="The positive prompt for the preset.",
validation_alias=AliasChoices("positive_prompt", "prompt"),
)
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
model_config = ConfigDict(str_strip_whitespace=True)
StylePresetImportList = list[StylePresetImportRow]
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]: def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]:
@ -74,5 +84,5 @@ def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]:
raise StylePresetImportValidationError() raise StylePresetImportValidationError()
yield StylePresetImportRow( yield StylePresetImportRow(
name=row["name"].strip(), prompt=row["prompt"].strip(), negative_prompt=row["negative_prompt"].strip() name=row["name"].strip(), positive_prompt=row["prompt"].strip(), negative_prompt=row["negative_prompt"].strip()
) )