feat(api): add endpoint to take a CSV, parse it, validate it, and create many style preset entries

This commit is contained in:
Mary Hipp 2024-08-14 11:09:11 -04:00 committed by Mary Hipp Rogers
parent d36c43a10f
commit 2d58754789
4 changed files with 93 additions and 1 deletions

View File

@ -16,9 +16,11 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo
PresetData, PresetData,
PresetType, PresetType,
StylePresetChanges, StylePresetChanges,
StylePresetImportValidationError,
StylePresetNotFoundError, StylePresetNotFoundError,
StylePresetRecordWithImage, StylePresetRecordWithImage,
StylePresetWithoutId, StylePresetWithoutId,
parse_csv,
) )
@ -225,3 +227,28 @@ async def get_style_preset_image(
return response return response
except Exception: except Exception:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@style_presets_router.post(
"/import",
operation_id="import_style_presets",
)
async def import_style_presets(file: UploadFile = File(description="The file to import")):
if not file.filename.endswith(".csv"):
raise HTTPException(status_code=400, detail="Invalid file type")
try:
parsed_data = parse_csv(file)
except StylePresetImportValidationError:
raise HTTPException(
status_code=400, detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'"
)
style_presets: list[StylePresetWithoutId] = []
for style_preset in parsed_data:
preset_data = PresetData(positive_prompt=style_preset.prompt, negative_prompt=style_preset.negative_prompt)
style_preset = StylePresetWithoutId(name=style_preset.name, preset_data=preset_data, type=PresetType.User)
style_presets.append(style_preset)
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)

View File

@ -20,6 +20,11 @@ class StylePresetRecordsStorageBase(ABC):
"""Creates a style preset.""" """Creates a style preset."""
pass pass
@abstractmethod
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
"""Creates many style presets."""
pass
@abstractmethod @abstractmethod
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO: def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
"""Updates a style preset.""" """Updates a style preset."""

View File

@ -1,6 +1,9 @@
import csv
import io
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Generator, Optional
from fastapi import UploadFile
from pydantic import BaseModel, Field, TypeAdapter from pydantic import BaseModel, Field, TypeAdapter
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
@ -10,6 +13,10 @@ class StylePresetNotFoundError(Exception):
"""Raised when a style preset is not found""" """Raised when a style preset is not found"""
class StylePresetImportValidationError(Exception):
"""Raised when a style preset import is not valid"""
class PresetData(BaseModel, extra="forbid"): class PresetData(BaseModel, extra="forbid"):
positive_prompt: str = Field(description="Positive prompt") positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt") negative_prompt: str = Field(description="Negative prompt")
@ -49,3 +56,23 @@ StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
class StylePresetRecordWithImage(StylePresetRecordDTO): class StylePresetRecordWithImage(StylePresetRecordDTO):
image: Optional[str] = Field(description="The path for image") image: Optional[str] = Field(description="The path for image")
class StylePresetImportRow(BaseModel):
name: str
prompt: str
negative_prompt: str
def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]:
"""Yield parsed and validated rows from the CSV file."""
file_content = file.file.read().decode("utf-8")
csv_reader = csv.DictReader(io.StringIO(file_content))
for row in csv_reader:
if "name" not in row or "prompt" not in row or "negative_prompt" not in row:
raise StylePresetImportValidationError()
yield StylePresetImportRow(
name=row["name"].strip(), prompt=row["prompt"].strip(), negative_prompt=row["negative_prompt"].strip()
)

View File

@ -75,6 +75,39 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
self._lock.release() self._lock.release()
return self.get(style_preset_id) return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
self._lock.acquire()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
name,
preset_data,
type
)
VALUES (?, ?, ?, ?);
""",
(
style_preset_id,
style_preset.name,
style_preset.preset_data.model_dump_json(),
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO: def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try: try:
self._lock.acquire() self._lock.acquire()