From 2d587547896bda30d2f5bf3f67358d91bc2abd2a Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Wed, 14 Aug 2024 11:09:11 -0400 Subject: [PATCH] feat(api): add endpoint to take a CSV, parse it, validate it, and create many style preset entries --- invokeai/app/api/routers/style_presets.py | 27 +++++++++++++++ .../style_preset_records_base.py | 5 +++ .../style_preset_records_common.py | 29 +++++++++++++++- .../style_preset_records_sqlite.py | 33 +++++++++++++++++++ 4 files changed, 93 insertions(+), 1 deletion(-) diff --git a/invokeai/app/api/routers/style_presets.py b/invokeai/app/api/routers/style_presets.py index 786c522c20..d7673cc25d 100644 --- a/invokeai/app/api/routers/style_presets.py +++ b/invokeai/app/api/routers/style_presets.py @@ -16,9 +16,11 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo PresetData, PresetType, StylePresetChanges, + StylePresetImportValidationError, StylePresetNotFoundError, StylePresetRecordWithImage, StylePresetWithoutId, + parse_csv, ) @@ -225,3 +227,28 @@ async def get_style_preset_image( return response except Exception: raise HTTPException(status_code=404) + + +@style_presets_router.post( + "/import", + operation_id="import_style_presets", +) +async def import_style_presets(file: UploadFile = File(description="The file to import")): + if not file.filename.endswith(".csv"): + raise HTTPException(status_code=400, detail="Invalid file type") + + try: + parsed_data = parse_csv(file) + except StylePresetImportValidationError: + raise HTTPException( + status_code=400, detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'" + ) + + style_presets: list[StylePresetWithoutId] = [] + + for style_preset in parsed_data: + preset_data = PresetData(positive_prompt=style_preset.prompt, negative_prompt=style_preset.negative_prompt) + style_preset = StylePresetWithoutId(name=style_preset.name, preset_data=preset_data, type=PresetType.User) + style_presets.append(style_preset) + + ApiDependencies.invoker.services.style_preset_records.create_many(style_presets) diff --git a/invokeai/app/services/style_preset_records/style_preset_records_base.py b/invokeai/app/services/style_preset_records/style_preset_records_base.py index 9e3a504e06..282388c7e4 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_base.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_base.py @@ -20,6 +20,11 @@ class StylePresetRecordsStorageBase(ABC): """Creates a style preset.""" pass + @abstractmethod + def create_many(self, style_presets: list[StylePresetWithoutId]) -> None: + """Creates many style presets.""" + pass + @abstractmethod def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO: """Updates a style preset.""" diff --git a/invokeai/app/services/style_preset_records/style_preset_records_common.py b/invokeai/app/services/style_preset_records/style_preset_records_common.py index 964489b54d..d20b019e3e 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_common.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_common.py @@ -1,6 +1,9 @@ +import csv +import io from enum import Enum -from typing import Any, Optional +from typing import Any, Generator, Optional +from fastapi import UploadFile from pydantic import BaseModel, Field, TypeAdapter from invokeai.app.util.metaenum import MetaEnum @@ -10,6 +13,10 @@ class StylePresetNotFoundError(Exception): """Raised when a style preset is not found""" +class StylePresetImportValidationError(Exception): + """Raised when a style preset import is not valid""" + + class PresetData(BaseModel, extra="forbid"): positive_prompt: str = Field(description="Positive prompt") negative_prompt: str = Field(description="Negative prompt") @@ -49,3 +56,23 @@ StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO) class StylePresetRecordWithImage(StylePresetRecordDTO): image: Optional[str] = Field(description="The path for image") + + +class StylePresetImportRow(BaseModel): + name: str + prompt: str + negative_prompt: str + + +def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]: + """Yield parsed and validated rows from the CSV file.""" + file_content = file.file.read().decode("utf-8") + csv_reader = csv.DictReader(io.StringIO(file_content)) + + for row in csv_reader: + if "name" not in row or "prompt" not in row or "negative_prompt" not in row: + raise StylePresetImportValidationError() + + yield StylePresetImportRow( + name=row["name"].strip(), prompt=row["prompt"].strip(), negative_prompt=row["negative_prompt"].strip() + ) diff --git a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py index a98ff462f2..952cf35ba9 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py @@ -75,6 +75,39 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): self._lock.release() return self.get(style_preset_id) + def create_many(self, style_presets: list[StylePresetWithoutId]) -> None: + style_preset_ids = [] + try: + self._lock.acquire() + for style_preset in style_presets: + style_preset_id = uuid_string() + style_preset_ids.append(style_preset_id) + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO style_presets ( + id, + name, + preset_data, + type + ) + VALUES (?, ?, ?, ?); + """, + ( + style_preset_id, + style_preset.name, + style_preset.preset_data.model_dump_json(), + style_preset.type, + ), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + + return None + def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO: try: self._lock.acquire()