mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(api): add endpoint to take a CSV, parse it, validate it, and create many style preset entries
This commit is contained in:
parent
d36c43a10f
commit
2d58754789
@ -16,9 +16,11 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo
|
||||
PresetData,
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetImportValidationError,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordWithImage,
|
||||
StylePresetWithoutId,
|
||||
parse_csv,
|
||||
)
|
||||
|
||||
|
||||
@ -225,3 +227,28 @@ async def get_style_preset_image(
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@style_presets_router.post(
|
||||
"/import",
|
||||
operation_id="import_style_presets",
|
||||
)
|
||||
async def import_style_presets(file: UploadFile = File(description="The file to import")):
|
||||
if not file.filename.endswith(".csv"):
|
||||
raise HTTPException(status_code=400, detail="Invalid file type")
|
||||
|
||||
try:
|
||||
parsed_data = parse_csv(file)
|
||||
except StylePresetImportValidationError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'"
|
||||
)
|
||||
|
||||
style_presets: list[StylePresetWithoutId] = []
|
||||
|
||||
for style_preset in parsed_data:
|
||||
preset_data = PresetData(positive_prompt=style_preset.prompt, negative_prompt=style_preset.negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=style_preset.name, preset_data=preset_data, type=PresetType.User)
|
||||
style_presets.append(style_preset)
|
||||
|
||||
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
||||
|
@ -20,6 +20,11 @@ class StylePresetRecordsStorageBase(ABC):
|
||||
"""Creates a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
"""Creates many style presets."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
"""Updates a style preset."""
|
||||
|
@ -1,6 +1,9 @@
|
||||
import csv
|
||||
import io
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Generator, Optional
|
||||
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
@ -10,6 +13,10 @@ class StylePresetNotFoundError(Exception):
|
||||
"""Raised when a style preset is not found"""
|
||||
|
||||
|
||||
class StylePresetImportValidationError(Exception):
|
||||
"""Raised when a style preset import is not valid"""
|
||||
|
||||
|
||||
class PresetData(BaseModel, extra="forbid"):
|
||||
positive_prompt: str = Field(description="Positive prompt")
|
||||
negative_prompt: str = Field(description="Negative prompt")
|
||||
@ -49,3 +56,23 @@ StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
|
||||
|
||||
class StylePresetRecordWithImage(StylePresetRecordDTO):
|
||||
image: Optional[str] = Field(description="The path for image")
|
||||
|
||||
|
||||
class StylePresetImportRow(BaseModel):
|
||||
name: str
|
||||
prompt: str
|
||||
negative_prompt: str
|
||||
|
||||
|
||||
def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]:
|
||||
"""Yield parsed and validated rows from the CSV file."""
|
||||
file_content = file.file.read().decode("utf-8")
|
||||
csv_reader = csv.DictReader(io.StringIO(file_content))
|
||||
|
||||
for row in csv_reader:
|
||||
if "name" not in row or "prompt" not in row or "negative_prompt" not in row:
|
||||
raise StylePresetImportValidationError()
|
||||
|
||||
yield StylePresetImportRow(
|
||||
name=row["name"].strip(), prompt=row["prompt"].strip(), negative_prompt=row["negative_prompt"].strip()
|
||||
)
|
||||
|
@ -75,6 +75,39 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
name,
|
||||
preset_data,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
style_preset_id,
|
||||
style_preset.name,
|
||||
style_preset.preset_data.model_dump_json(),
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
Loading…
Reference in New Issue
Block a user