mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(api): add endpoint to take a CSV, parse it, validate it, and create many style preset entries
This commit is contained in:
parent
d36c43a10f
commit
2d58754789
@ -16,9 +16,11 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo
|
|||||||
PresetData,
|
PresetData,
|
||||||
PresetType,
|
PresetType,
|
||||||
StylePresetChanges,
|
StylePresetChanges,
|
||||||
|
StylePresetImportValidationError,
|
||||||
StylePresetNotFoundError,
|
StylePresetNotFoundError,
|
||||||
StylePresetRecordWithImage,
|
StylePresetRecordWithImage,
|
||||||
StylePresetWithoutId,
|
StylePresetWithoutId,
|
||||||
|
parse_csv,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -225,3 +227,28 @@ async def get_style_preset_image(
|
|||||||
return response
|
return response
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
|
@style_presets_router.post(
|
||||||
|
"/import",
|
||||||
|
operation_id="import_style_presets",
|
||||||
|
)
|
||||||
|
async def import_style_presets(file: UploadFile = File(description="The file to import")):
|
||||||
|
if not file.filename.endswith(".csv"):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid file type")
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_data = parse_csv(file)
|
||||||
|
except StylePresetImportValidationError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'"
|
||||||
|
)
|
||||||
|
|
||||||
|
style_presets: list[StylePresetWithoutId] = []
|
||||||
|
|
||||||
|
for style_preset in parsed_data:
|
||||||
|
preset_data = PresetData(positive_prompt=style_preset.prompt, negative_prompt=style_preset.negative_prompt)
|
||||||
|
style_preset = StylePresetWithoutId(name=style_preset.name, preset_data=preset_data, type=PresetType.User)
|
||||||
|
style_presets.append(style_preset)
|
||||||
|
|
||||||
|
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
||||||
|
@ -20,6 +20,11 @@ class StylePresetRecordsStorageBase(ABC):
|
|||||||
"""Creates a style preset."""
|
"""Creates a style preset."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||||
|
"""Creates many style presets."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||||
"""Updates a style preset."""
|
"""Updates a style preset."""
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
|
import csv
|
||||||
|
import io
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Generator, Optional
|
||||||
|
|
||||||
|
from fastapi import UploadFile
|
||||||
from pydantic import BaseModel, Field, TypeAdapter
|
from pydantic import BaseModel, Field, TypeAdapter
|
||||||
|
|
||||||
from invokeai.app.util.metaenum import MetaEnum
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
@ -10,6 +13,10 @@ class StylePresetNotFoundError(Exception):
|
|||||||
"""Raised when a style preset is not found"""
|
"""Raised when a style preset is not found"""
|
||||||
|
|
||||||
|
|
||||||
|
class StylePresetImportValidationError(Exception):
|
||||||
|
"""Raised when a style preset import is not valid"""
|
||||||
|
|
||||||
|
|
||||||
class PresetData(BaseModel, extra="forbid"):
|
class PresetData(BaseModel, extra="forbid"):
|
||||||
positive_prompt: str = Field(description="Positive prompt")
|
positive_prompt: str = Field(description="Positive prompt")
|
||||||
negative_prompt: str = Field(description="Negative prompt")
|
negative_prompt: str = Field(description="Negative prompt")
|
||||||
@ -49,3 +56,23 @@ StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
|
|||||||
|
|
||||||
class StylePresetRecordWithImage(StylePresetRecordDTO):
|
class StylePresetRecordWithImage(StylePresetRecordDTO):
|
||||||
image: Optional[str] = Field(description="The path for image")
|
image: Optional[str] = Field(description="The path for image")
|
||||||
|
|
||||||
|
|
||||||
|
class StylePresetImportRow(BaseModel):
|
||||||
|
name: str
|
||||||
|
prompt: str
|
||||||
|
negative_prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
def parse_csv(file: UploadFile) -> Generator[StylePresetImportRow, None, None]:
|
||||||
|
"""Yield parsed and validated rows from the CSV file."""
|
||||||
|
file_content = file.file.read().decode("utf-8")
|
||||||
|
csv_reader = csv.DictReader(io.StringIO(file_content))
|
||||||
|
|
||||||
|
for row in csv_reader:
|
||||||
|
if "name" not in row or "prompt" not in row or "negative_prompt" not in row:
|
||||||
|
raise StylePresetImportValidationError()
|
||||||
|
|
||||||
|
yield StylePresetImportRow(
|
||||||
|
name=row["name"].strip(), prompt=row["prompt"].strip(), negative_prompt=row["negative_prompt"].strip()
|
||||||
|
)
|
||||||
|
@ -75,6 +75,39 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
return self.get(style_preset_id)
|
return self.get(style_preset_id)
|
||||||
|
|
||||||
|
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||||
|
style_preset_ids = []
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
for style_preset in style_presets:
|
||||||
|
style_preset_id = uuid_string()
|
||||||
|
style_preset_ids.append(style_preset_id)
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO style_presets (
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
preset_data,
|
||||||
|
type
|
||||||
|
)
|
||||||
|
VALUES (?, ?, ?, ?);
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
style_preset_id,
|
||||||
|
style_preset.name,
|
||||||
|
style_preset.preset_data.model_dump_json(),
|
||||||
|
style_preset.type,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
Loading…
Reference in New Issue
Block a user