From 60d754d1dffccc0008c5edc4269a7f4ff9870689 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Aug 2024 15:31:53 +1000 Subject: [PATCH] feat(api): tidy style presets import logic - Extract parsing into utility function - Log import errors - Forbid extra properties on the imported data --- invokeai/app/api/routers/style_presets.py | 46 +++-------- .../style_preset_records_common.py | 77 +++++++++++++++++-- 2 files changed, 83 insertions(+), 40 deletions(-) diff --git a/invokeai/app/api/routers/style_presets.py b/invokeai/app/api/routers/style_presets.py index ccea914750..14d8c666aa 100644 --- a/invokeai/app/api/routers/style_presets.py +++ b/invokeai/app/api/routers/style_presets.py @@ -1,27 +1,27 @@ -import csv import io import json import traceback -from codecs import iterdecode from typing import Optional import pydantic from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile from fastapi.responses import FileResponse from PIL import Image -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel, Field from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException from invokeai.app.services.style_preset_records.style_preset_records_common import ( + InvalidPresetImportDataError, PresetData, PresetType, StylePresetChanges, - StylePresetImportListTypeAdapter, StylePresetNotFoundError, StylePresetRecordWithImage, StylePresetWithoutId, + UnsupportedFileTypeError, + parse_presets_from_file, ) @@ -235,36 +235,12 @@ async def get_style_preset_image( operation_id="import_style_presets", ) async def import_style_presets(file: UploadFile = File(description="The file to import")): - if file.content_type not in ["text/csv", "application/json"]: - raise HTTPException(status_code=400, detail="Unsupported file type") - try: - if file.content_type == "text/csv": - csv_reader = csv.DictReader(iterdecode(file.file, "utf-8")) - data = list(csv_reader) - else: # file.content_type == "application/json": - json_data = await file.read() - data = json.loads(json_data) - - imported_presets = StylePresetImportListTypeAdapter.validate_python(data) - - style_presets: list[StylePresetWithoutId] = [] - - for imported in imported_presets: - preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt) - style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User) - style_presets.append(style_preset) + style_presets = await parse_presets_from_file(file) ApiDependencies.invoker.services.style_preset_records.create_many(style_presets) - except ValidationError: - if file.content_type == "text/csv": - raise HTTPException( - status_code=400, - detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'", - ) - else: # file.content_type == "application/json": - raise HTTPException( - status_code=400, - detail="Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'", - ) - finally: - file.file.close() + except InvalidPresetImportDataError as e: + ApiDependencies.invoker.services.logger.error(traceback.format_exc()) + raise HTTPException(status_code=400, detail=str(e)) + except UnsupportedFileTypeError as e: + ApiDependencies.invoker.services.logger.error(traceback.format_exc()) + raise HTTPException(status_code=415, detail=str(e)) diff --git a/invokeai/app/services/style_preset_records/style_preset_records_common.py b/invokeai/app/services/style_preset_records/style_preset_records_common.py index 2d33a7ea76..34a30d0377 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_common.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_common.py @@ -1,6 +1,11 @@ +import codecs +import csv +import json from enum import Enum from typing import Any, Optional +import pydantic +from fastapi import UploadFile from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter from invokeai.app.util.metaenum import MetaEnum @@ -10,10 +15,6 @@ class StylePresetNotFoundError(Exception): """Raised when a style preset is not found""" -class StylePresetImportValidationError(Exception): - """Raised when a style preset import is not valid""" - - class PresetData(BaseModel, extra="forbid"): positive_prompt: str = Field(description="Positive prompt") negative_prompt: str = Field(description="Negative prompt") @@ -64,8 +65,74 @@ class StylePresetImportRow(BaseModel): ) negative_prompt: str = Field(default="", description="The negative prompt for the preset.") - model_config = ConfigDict(str_strip_whitespace=True) + model_config = ConfigDict(str_strip_whitespace=True, extra="forbid") StylePresetImportList = list[StylePresetImportRow] StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList) + + +class UnsupportedFileTypeError(ValueError): + """Raised when an unsupported file type is encountered""" + + pass + + +class InvalidPresetImportDataError(ValueError): + """Raised when invalid preset import data is encountered""" + + pass + + +async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]: + """Parses style presets from a file. The file must be a CSV or JSON file. + + If CSV, the file must have the following columns: + - name + - prompt (or positive_prompt) + - negative_prompt + + If JSON, the file must be a list of objects with the following keys: + - name + - prompt (or positive_prompt) + - negative_prompt + + Args: + file (UploadFile): The file to parse. + + Returns: + list[StylePresetWithoutId]: The parsed style presets. + + Raises: + UnsupportedFileTypeError: If the file type is not supported. + InvalidPresetImportDataError: If the data in the file is invalid. + """ + if file.content_type not in ["text/csv", "application/json"]: + raise UnsupportedFileTypeError() + + if file.content_type == "text/csv": + csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8")) + data = list(csv_reader) + else: # file.content_type == "application/json": + json_data = await file.read() + data = json.loads(json_data) + + try: + imported_presets = StylePresetImportListTypeAdapter.validate_python(data) + + style_presets: list[StylePresetWithoutId] = [] + + for imported in imported_presets: + preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt) + style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User) + style_presets.append(style_preset) + except pydantic.ValidationError as e: + if file.content_type == "text/csv": + msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'" + else: # file.content_type == "application/json": + msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'" + raise InvalidPresetImportDataError(msg) from e + finally: + file.file.close() + + return style_presets