mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(api): tidy style presets import logic
- Extract parsing into utility function - Log import errors - Forbid extra properties on the imported data
This commit is contained in:
parent
bd07c86db9
commit
60d754d1df
@ -1,27 +1,27 @@
|
|||||||
import csv
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from codecs import iterdecode
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile
|
from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.api.dependencies import ApiDependencies
|
from invokeai.app.api.dependencies import ApiDependencies
|
||||||
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
|
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
|
||||||
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
|
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||||
|
InvalidPresetImportDataError,
|
||||||
PresetData,
|
PresetData,
|
||||||
PresetType,
|
PresetType,
|
||||||
StylePresetChanges,
|
StylePresetChanges,
|
||||||
StylePresetImportListTypeAdapter,
|
|
||||||
StylePresetNotFoundError,
|
StylePresetNotFoundError,
|
||||||
StylePresetRecordWithImage,
|
StylePresetRecordWithImage,
|
||||||
StylePresetWithoutId,
|
StylePresetWithoutId,
|
||||||
|
UnsupportedFileTypeError,
|
||||||
|
parse_presets_from_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -235,36 +235,12 @@ async def get_style_preset_image(
|
|||||||
operation_id="import_style_presets",
|
operation_id="import_style_presets",
|
||||||
)
|
)
|
||||||
async def import_style_presets(file: UploadFile = File(description="The file to import")):
|
async def import_style_presets(file: UploadFile = File(description="The file to import")):
|
||||||
if file.content_type not in ["text/csv", "application/json"]:
|
|
||||||
raise HTTPException(status_code=400, detail="Unsupported file type")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if file.content_type == "text/csv":
|
style_presets = await parse_presets_from_file(file)
|
||||||
csv_reader = csv.DictReader(iterdecode(file.file, "utf-8"))
|
|
||||||
data = list(csv_reader)
|
|
||||||
else: # file.content_type == "application/json":
|
|
||||||
json_data = await file.read()
|
|
||||||
data = json.loads(json_data)
|
|
||||||
|
|
||||||
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
|
|
||||||
|
|
||||||
style_presets: list[StylePresetWithoutId] = []
|
|
||||||
|
|
||||||
for imported in imported_presets:
|
|
||||||
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
|
|
||||||
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
|
|
||||||
style_presets.append(style_preset)
|
|
||||||
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
||||||
except ValidationError:
|
except InvalidPresetImportDataError as e:
|
||||||
if file.content_type == "text/csv":
|
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
status_code=400,
|
except UnsupportedFileTypeError as e:
|
||||||
detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'",
|
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||||
)
|
raise HTTPException(status_code=415, detail=str(e))
|
||||||
else: # file.content_type == "application/json":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'",
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
file.file.close()
|
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
|
import codecs
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
from fastapi import UploadFile
|
||||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
|
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
|
||||||
|
|
||||||
from invokeai.app.util.metaenum import MetaEnum
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
@ -10,10 +15,6 @@ class StylePresetNotFoundError(Exception):
|
|||||||
"""Raised when a style preset is not found"""
|
"""Raised when a style preset is not found"""
|
||||||
|
|
||||||
|
|
||||||
class StylePresetImportValidationError(Exception):
|
|
||||||
"""Raised when a style preset import is not valid"""
|
|
||||||
|
|
||||||
|
|
||||||
class PresetData(BaseModel, extra="forbid"):
|
class PresetData(BaseModel, extra="forbid"):
|
||||||
positive_prompt: str = Field(description="Positive prompt")
|
positive_prompt: str = Field(description="Positive prompt")
|
||||||
negative_prompt: str = Field(description="Negative prompt")
|
negative_prompt: str = Field(description="Negative prompt")
|
||||||
@ -64,8 +65,74 @@ class StylePresetImportRow(BaseModel):
|
|||||||
)
|
)
|
||||||
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
|
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
|
||||||
|
|
||||||
model_config = ConfigDict(str_strip_whitespace=True)
|
model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
|
||||||
|
|
||||||
|
|
||||||
StylePresetImportList = list[StylePresetImportRow]
|
StylePresetImportList = list[StylePresetImportRow]
|
||||||
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
|
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedFileTypeError(ValueError):
|
||||||
|
"""Raised when an unsupported file type is encountered"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPresetImportDataError(ValueError):
|
||||||
|
"""Raised when invalid preset import data is encountered"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
|
||||||
|
"""Parses style presets from a file. The file must be a CSV or JSON file.
|
||||||
|
|
||||||
|
If CSV, the file must have the following columns:
|
||||||
|
- name
|
||||||
|
- prompt (or positive_prompt)
|
||||||
|
- negative_prompt
|
||||||
|
|
||||||
|
If JSON, the file must be a list of objects with the following keys:
|
||||||
|
- name
|
||||||
|
- prompt (or positive_prompt)
|
||||||
|
- negative_prompt
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (UploadFile): The file to parse.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[StylePresetWithoutId]: The parsed style presets.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
UnsupportedFileTypeError: If the file type is not supported.
|
||||||
|
InvalidPresetImportDataError: If the data in the file is invalid.
|
||||||
|
"""
|
||||||
|
if file.content_type not in ["text/csv", "application/json"]:
|
||||||
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
|
if file.content_type == "text/csv":
|
||||||
|
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
|
||||||
|
data = list(csv_reader)
|
||||||
|
else: # file.content_type == "application/json":
|
||||||
|
json_data = await file.read()
|
||||||
|
data = json.loads(json_data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
|
||||||
|
|
||||||
|
style_presets: list[StylePresetWithoutId] = []
|
||||||
|
|
||||||
|
for imported in imported_presets:
|
||||||
|
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
|
||||||
|
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
|
||||||
|
style_presets.append(style_preset)
|
||||||
|
except pydantic.ValidationError as e:
|
||||||
|
if file.content_type == "text/csv":
|
||||||
|
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'"
|
||||||
|
else: # file.content_type == "application/json":
|
||||||
|
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'"
|
||||||
|
raise InvalidPresetImportDataError(msg) from e
|
||||||
|
finally:
|
||||||
|
file.file.close()
|
||||||
|
|
||||||
|
return style_presets
|
||||||
|
Loading…
Reference in New Issue
Block a user