feat(api): tidy style presets import logic

- Extract parsing into utility function
- Log import errors
- Forbid extra properties on the imported data
This commit is contained in:
psychedelicious 2024-08-15 15:31:53 +10:00 committed by Mary Hipp Rogers
parent bd07c86db9
commit 60d754d1df
2 changed files with 83 additions and 40 deletions

View File

@ -1,27 +1,27 @@
import csv
import io import io
import json import json
import traceback import traceback
from codecs import iterdecode
from typing import Optional from typing import Optional
import pydantic import pydantic
from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
from invokeai.app.services.style_preset_records.style_preset_records_common import ( from invokeai.app.services.style_preset_records.style_preset_records_common import (
InvalidPresetImportDataError,
PresetData, PresetData,
PresetType, PresetType,
StylePresetChanges, StylePresetChanges,
StylePresetImportListTypeAdapter,
StylePresetNotFoundError, StylePresetNotFoundError,
StylePresetRecordWithImage, StylePresetRecordWithImage,
StylePresetWithoutId, StylePresetWithoutId,
UnsupportedFileTypeError,
parse_presets_from_file,
) )
@ -235,36 +235,12 @@ async def get_style_preset_image(
operation_id="import_style_presets", operation_id="import_style_presets",
) )
async def import_style_presets(file: UploadFile = File(description="The file to import")): async def import_style_presets(file: UploadFile = File(description="The file to import")):
if file.content_type not in ["text/csv", "application/json"]:
raise HTTPException(status_code=400, detail="Unsupported file type")
try: try:
if file.content_type == "text/csv": style_presets = await parse_presets_from_file(file)
csv_reader = csv.DictReader(iterdecode(file.file, "utf-8"))
data = list(csv_reader)
else: # file.content_type == "application/json":
json_data = await file.read()
data = json.loads(json_data)
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
style_presets: list[StylePresetWithoutId] = []
for imported in imported_presets:
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
style_presets.append(style_preset)
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets) ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
except ValidationError: except InvalidPresetImportDataError as e:
if file.content_type == "text/csv": ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException( raise HTTPException(status_code=400, detail=str(e))
status_code=400, except UnsupportedFileTypeError as e:
detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'", ApiDependencies.invoker.services.logger.error(traceback.format_exc())
) raise HTTPException(status_code=415, detail=str(e))
else: # file.content_type == "application/json":
raise HTTPException(
status_code=400,
detail="Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'",
)
finally:
file.file.close()

View File

@ -1,6 +1,11 @@
import codecs
import csv
import json
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
import pydantic
from fastapi import UploadFile
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
@ -10,10 +15,6 @@ class StylePresetNotFoundError(Exception):
"""Raised when a style preset is not found""" """Raised when a style preset is not found"""
class StylePresetImportValidationError(Exception):
"""Raised when a style preset import is not valid"""
class PresetData(BaseModel, extra="forbid"): class PresetData(BaseModel, extra="forbid"):
positive_prompt: str = Field(description="Positive prompt") positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt") negative_prompt: str = Field(description="Negative prompt")
@ -64,8 +65,74 @@ class StylePresetImportRow(BaseModel):
) )
negative_prompt: str = Field(default="", description="The negative prompt for the preset.") negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
model_config = ConfigDict(str_strip_whitespace=True) model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
StylePresetImportList = list[StylePresetImportRow] StylePresetImportList = list[StylePresetImportRow]
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList) StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
class UnsupportedFileTypeError(ValueError):
"""Raised when an unsupported file type is encountered"""
pass
class InvalidPresetImportDataError(ValueError):
"""Raised when invalid preset import data is encountered"""
pass
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
"""Parses style presets from a file. The file must be a CSV or JSON file.
If CSV, the file must have the following columns:
- name
- prompt (or positive_prompt)
- negative_prompt
If JSON, the file must be a list of objects with the following keys:
- name
- prompt (or positive_prompt)
- negative_prompt
Args:
file (UploadFile): The file to parse.
Returns:
list[StylePresetWithoutId]: The parsed style presets.
Raises:
UnsupportedFileTypeError: If the file type is not supported.
InvalidPresetImportDataError: If the data in the file is invalid.
"""
if file.content_type not in ["text/csv", "application/json"]:
raise UnsupportedFileTypeError()
if file.content_type == "text/csv":
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
data = list(csv_reader)
else: # file.content_type == "application/json":
json_data = await file.read()
data = json.loads(json_data)
try:
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
style_presets: list[StylePresetWithoutId] = []
for imported in imported_presets:
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
style_presets.append(style_preset)
except pydantic.ValidationError as e:
if file.content_type == "text/csv":
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'"
else: # file.content_type == "application/json":
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'"
raise InvalidPresetImportDataError(msg) from e
finally:
file.file.close()
return style_presets