mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(api): support JSON for preset imports
This allows us to support Fooocus format presets.
This commit is contained in:
parent
deb917825e
commit
356661459b
@ -1,13 +1,15 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import traceback
|
||||
from codecs import iterdecode
|
||||
from typing import Optional
|
||||
|
||||
import pydantic
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
|
||||
@ -16,11 +18,10 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo
|
||||
PresetData,
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetImportValidationError,
|
||||
StylePresetImportListTypeAdapter,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordWithImage,
|
||||
StylePresetWithoutId,
|
||||
parse_csv,
|
||||
)
|
||||
|
||||
|
||||
@ -234,21 +235,36 @@ async def get_style_preset_image(
|
||||
operation_id="import_style_presets",
|
||||
)
|
||||
async def import_style_presets(file: UploadFile = File(description="The file to import")):
|
||||
if not file.filename.endswith(".csv"):
|
||||
raise HTTPException(status_code=400, detail="Invalid file type")
|
||||
if file.content_type not in ["text/csv", "application/json"]:
|
||||
raise HTTPException(status_code=400, detail="Unsupported file type")
|
||||
|
||||
try:
|
||||
parsed_data = parse_csv(file)
|
||||
except StylePresetImportValidationError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'"
|
||||
)
|
||||
if file.content_type == "text/csv":
|
||||
csv_reader = csv.DictReader(iterdecode(file.file, "utf-8"))
|
||||
data = list(csv_reader)
|
||||
else: # file.content_type == "application/json":
|
||||
json_data = await file.read()
|
||||
data = json.loads(json_data)
|
||||
|
||||
style_presets: list[StylePresetWithoutId] = []
|
||||
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
|
||||
|
||||
for style_preset in parsed_data:
|
||||
preset_data = PresetData(positive_prompt=style_preset.prompt, negative_prompt=style_preset.negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=style_preset.name, preset_data=preset_data, type=PresetType.User)
|
||||
style_presets.append(style_preset)
|
||||
style_presets: list[StylePresetWithoutId] = []
|
||||
|
||||
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
||||
for imported in imported_presets:
|
||||
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
|
||||
style_presets.append(style_preset)
|
||||
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
||||
except ValidationError:
|
||||
if file.content_type == "text/csv":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'",
|
||||
)
|
||||
else: # file.content_type == "application/json":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'",
|
||||
)
|
||||
finally:
|
||||
file.file.close()
|
||||
|
@ -1,9 +1,6 @@
|
||||
import csv
|
||||
import io
|
||||
from enum import Enum
|
||||
from typing import Any, Generator, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import UploadFile
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
@ -72,17 +69,3 @@ class StylePresetImportRow(BaseModel):
|
||||
|
||||
StylePresetImportList = list[StylePresetImportRow]
|
||||
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
|
||||
|
||||
|
||||
def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]:
|
||||
"""Yield parsed and validated rows from the CSV file."""
|
||||
file_content = file.file.read().decode("utf-8")
|
||||
csv_reader = csv.DictReader(io.StringIO(file_content))
|
||||
|
||||
for row in csv_reader:
|
||||
if "name" not in row or "prompt" not in row or "negative_prompt" not in row:
|
||||
raise StylePresetImportValidationError()
|
||||
|
||||
yield StylePresetImportRow(
|
||||
name=row["name"].strip(), positive_prompt=row["prompt"].strip(), negative_prompt=row["negative_prompt"].strip()
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user