feat(api): add endpoint to take a CSV, parse it, validate it, and create many style preset entries

This commit is contained in:
Mary Hipp 2024-08-14 11:09:11 -04:00 committed by Mary Hipp Rogers
parent d36c43a10f
commit 2d58754789
4 changed files with 93 additions and 1 deletions

View File

@ -16,9 +16,11 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo
PresetData,
PresetType,
StylePresetChanges,
StylePresetImportValidationError,
StylePresetNotFoundError,
StylePresetRecordWithImage,
StylePresetWithoutId,
parse_csv,
)
@ -225,3 +227,28 @@ async def get_style_preset_image(
return response
except Exception:
raise HTTPException(status_code=404)
@style_presets_router.post(
"/import",
operation_id="import_style_presets",
)
async def import_style_presets(file: UploadFile = File(description="The file to import")):
if not file.filename.endswith(".csv"):
raise HTTPException(status_code=400, detail="Invalid file type")
try:
parsed_data = parse_csv(file)
except StylePresetImportValidationError:
raise HTTPException(
status_code=400, detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'"
)
style_presets: list[StylePresetWithoutId] = []
for style_preset in parsed_data:
preset_data = PresetData(positive_prompt=style_preset.prompt, negative_prompt=style_preset.negative_prompt)
style_preset = StylePresetWithoutId(name=style_preset.name, preset_data=preset_data, type=PresetType.User)
style_presets.append(style_preset)
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)

View File

@ -20,6 +20,11 @@ class StylePresetRecordsStorageBase(ABC):
"""Creates a style preset."""
pass
@abstractmethod
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
"""Creates many style presets."""
pass
@abstractmethod
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
"""Updates a style preset."""

View File

@ -1,6 +1,9 @@
import csv
import io
from enum import Enum
from typing import Any, Optional
from typing import Any, Generator, Optional
from fastapi import UploadFile
from pydantic import BaseModel, Field, TypeAdapter
from invokeai.app.util.metaenum import MetaEnum
@ -10,6 +13,10 @@ class StylePresetNotFoundError(Exception):
"""Raised when a style preset is not found"""
class StylePresetImportValidationError(Exception):
"""Raised when a style preset import is not valid"""
class PresetData(BaseModel, extra="forbid"):
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
@ -49,3 +56,23 @@ StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
class StylePresetRecordWithImage(StylePresetRecordDTO):
image: Optional[str] = Field(description="The path for image")
class StylePresetImportRow(BaseModel):
name: str
prompt: str
negative_prompt: str
def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]:
"""Yield parsed and validated rows from the CSV file."""
file_content = file.file.read().decode("utf-8")
csv_reader = csv.DictReader(io.StringIO(file_content))
for row in csv_reader:
if "name" not in row or "prompt" not in row or "negative_prompt" not in row:
raise StylePresetImportValidationError()
yield StylePresetImportRow(
name=row["name"].strip(), prompt=row["prompt"].strip(), negative_prompt=row["negative_prompt"].strip()
)

View File

@ -75,6 +75,39 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
self._lock.release()
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
self._lock.acquire()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
name,
preset_data,
type
)
VALUES (?, ?, ?, ?);
""",
(
style_preset_id,
style_preset.name,
style_preset.preset_data.model_dump_json(),
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
self._lock.acquire()