feat(api): support JSON for preset imports

This allows us to support Fooocus format presets.
This commit is contained in:
psychedelicious 2024-08-15 14:45:04 +10:00 committed by Mary Hipp Rogers
parent deb917825e
commit 356661459b
2 changed files with 33 additions and 34 deletions

View File

@ -1,13 +1,15 @@
import csv
import io
import json
import traceback
from codecs import iterdecode
from typing import Optional
import pydantic
from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ValidationError
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
@ -16,11 +18,10 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo
PresetData,
PresetType,
StylePresetChanges,
StylePresetImportValidationError,
StylePresetImportListTypeAdapter,
StylePresetNotFoundError,
StylePresetRecordWithImage,
StylePresetWithoutId,
parse_csv,
)
@ -234,21 +235,36 @@ async def get_style_preset_image(
operation_id="import_style_presets",
)
async def import_style_presets(file: UploadFile = File(description="The file to import")):
if not file.filename.endswith(".csv"):
raise HTTPException(status_code=400, detail="Invalid file type")
if file.content_type not in ["text/csv", "application/json"]:
raise HTTPException(status_code=400, detail="Unsupported file type")
try:
parsed_data = parse_csv(file)
except StylePresetImportValidationError:
raise HTTPException(
status_code=400, detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'"
)
if file.content_type == "text/csv":
csv_reader = csv.DictReader(iterdecode(file.file, "utf-8"))
data = list(csv_reader)
else: # file.content_type == "application/json":
json_data = await file.read()
data = json.loads(json_data)
style_presets: list[StylePresetWithoutId] = []
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
for style_preset in parsed_data:
preset_data = PresetData(positive_prompt=style_preset.prompt, negative_prompt=style_preset.negative_prompt)
style_preset = StylePresetWithoutId(name=style_preset.name, preset_data=preset_data, type=PresetType.User)
style_presets.append(style_preset)
style_presets: list[StylePresetWithoutId] = []
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
for imported in imported_presets:
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
style_presets.append(style_preset)
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
except ValidationError:
if file.content_type == "text/csv":
raise HTTPException(
status_code=400,
detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'",
)
else: # file.content_type == "application/json":
raise HTTPException(
status_code=400,
detail="Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'",
)
finally:
file.file.close()

View File

@ -1,9 +1,6 @@
import csv
import io
from enum import Enum
from typing import Any, Generator, Optional
from typing import Any, Optional
from fastapi import UploadFile
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
from invokeai.app.util.metaenum import MetaEnum
@ -72,17 +69,3 @@ class StylePresetImportRow(BaseModel):
StylePresetImportList = list[StylePresetImportRow]
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]:
"""Yield parsed and validated rows from the CSV file."""
file_content = file.file.read().decode("utf-8")
csv_reader = csv.DictReader(io.StringIO(file_content))
for row in csv_reader:
if "name" not in row or "prompt" not in row or "negative_prompt" not in row:
raise StylePresetImportValidationError()
yield StylePresetImportRow(
name=row["name"].strip(), positive_prompt=row["prompt"].strip(), negative_prompt=row["negative_prompt"].strip()
)