mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db,api): create new table for style presets, build out record storage service for style presets
This commit is contained in:
parent
94d64b8a78
commit
b76bf50b93
@ -31,6 +31,7 @@ from invokeai.app.services.session_processor.session_processor_default import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||||
|
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
||||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||||
@ -109,6 +110,7 @@ class ApiDependencies:
|
|||||||
session_queue = SqliteSessionQueue(db=db)
|
session_queue = SqliteSessionQueue(db=db)
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||||
|
style_presets = SqliteStylePresetRecordsStorage(db=db)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
board_image_records=board_image_records,
|
board_image_records=board_image_records,
|
||||||
@ -134,6 +136,7 @@ class ApiDependencies:
|
|||||||
workflow_records=workflow_records,
|
workflow_records=workflow_records,
|
||||||
tensors=tensors,
|
tensors=tensors,
|
||||||
conditioning=conditioning,
|
conditioning=conditioning,
|
||||||
|
style_presets=style_presets,
|
||||||
)
|
)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||||
|
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
@ -61,6 +62,7 @@ class InvocationServices:
|
|||||||
workflow_records: "WorkflowRecordsStorageBase",
|
workflow_records: "WorkflowRecordsStorageBase",
|
||||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||||
|
style_presets: "StylePresetRecordsStorageBase",
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.board_image_records = board_image_records
|
self.board_image_records = board_image_records
|
||||||
@ -85,3 +87,4 @@ class InvocationServices:
|
|||||||
self.workflow_records = workflow_records
|
self.workflow_records = workflow_records
|
||||||
self.tensors = tensors
|
self.tensors = tensors
|
||||||
self.conditioning = conditioning
|
self.conditioning = conditioning
|
||||||
|
self.style_presets = style_presets
|
||||||
|
@ -16,6 +16,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import
|
|||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
||||||
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||||
|
|
||||||
|
|
||||||
@ -49,6 +50,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
|||||||
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
||||||
migrator.register_migration(build_migration_12(app_config=config))
|
migrator.register_migration(build_migration_12(app_config=config))
|
||||||
migrator.register_migration(build_migration_13())
|
migrator.register_migration(build_migration_13())
|
||||||
|
migrator.register_migration(build_migration_14())
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
|
||||||
return db
|
return db
|
||||||
|
@ -0,0 +1,60 @@
|
|||||||
|
import sqlite3
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||||
|
|
||||||
|
|
||||||
|
class Migration14Callback:
|
||||||
|
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||||
|
self._create_style_presets(cursor)
|
||||||
|
|
||||||
|
def _create_style_presets(self, cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""Create the table used to store model metadata downloaded from remote sources."""
|
||||||
|
tables = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS style_presets (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
preset_data TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
triggers = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS style_presets
|
||||||
|
AFTER UPDATE
|
||||||
|
ON style_presets FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE id = old.id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add indexes for searchable fields
|
||||||
|
indices = [
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);",
|
||||||
|
]
|
||||||
|
|
||||||
|
for stmt in tables + indices + triggers:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def build_migration_14() -> Migration:
|
||||||
|
"""
|
||||||
|
Build the migration from database version 12 to 14..
|
||||||
|
|
||||||
|
This migration does the following:
|
||||||
|
- Adds `archived` columns to the board table.
|
||||||
|
"""
|
||||||
|
migration_14 = Migration(
|
||||||
|
from_version=13,
|
||||||
|
to_version=14,
|
||||||
|
callback=Migration14Callback(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return migration_14
|
@ -0,0 +1,45 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||||
|
from invokeai.app.services.style_preset_records.style_preset_records_common import StylePresetRecordDTO, StylePresetWithoutId
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||||
|
WorkflowRecordOrderBy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StylePresetRecordsStorageBase(ABC):
|
||||||
|
"""Base class for style preset storage services."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, id: str) -> StylePresetRecordDTO:
|
||||||
|
"""Get style preset by id."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||||
|
"""Creates a style preset."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||||
|
"""Updates a style preset."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, id: str) -> None:
|
||||||
|
"""Deletes a style preset."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
page: int,
|
||||||
|
per_page: int,
|
||||||
|
order_by: WorkflowRecordOrderBy,
|
||||||
|
direction: SQLiteDirection,
|
||||||
|
query: Optional[str],
|
||||||
|
) -> PaginatedResults[StylePresetRecordDTO]:
|
||||||
|
"""Gets many workflows."""
|
||||||
|
pass
|
@ -0,0 +1,32 @@
|
|||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class StylePresetNotFoundError(Exception):
|
||||||
|
"""Raised when a style preset is not found"""
|
||||||
|
|
||||||
|
|
||||||
|
class PresetData(BaseModel):
|
||||||
|
positive_prompt: str = Field(description="Positive prompt")
|
||||||
|
negative_prompt: str = Field(description="Negative prompt")
|
||||||
|
|
||||||
|
|
||||||
|
PresetDataValidator = TypeAdapter(PresetData)
|
||||||
|
|
||||||
|
|
||||||
|
class StylePresetWithoutId(BaseModel):
|
||||||
|
name: str = Field(description="The name of the style preset.")
|
||||||
|
preset_data: PresetData = Field(description="The preset data")
|
||||||
|
|
||||||
|
|
||||||
|
class StylePresetRecordDTO(StylePresetWithoutId):
|
||||||
|
id: str = Field(description="The style preset ID.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
|
||||||
|
data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", ""))
|
||||||
|
return StylePresetRecordDTOValidator.validate_python(data)
|
||||||
|
|
||||||
|
|
||||||
|
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
|
@ -0,0 +1,169 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||||
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
|
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||||
|
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||||
|
StylePresetNotFoundError,
|
||||||
|
StylePresetRecordDTO,
|
||||||
|
StylePresetWithoutId,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||||
|
WorkflowRecordOrderBy,
|
||||||
|
)
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||||
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._lock = db.lock
|
||||||
|
self._conn = db.conn
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self._invoker = invoker
|
||||||
|
|
||||||
|
def get(self, id: str) -> StylePresetRecordDTO:
|
||||||
|
"""Gets a style preset by ID."""
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM style_presets
|
||||||
|
WHERE id = ?;
|
||||||
|
""",
|
||||||
|
(id,),
|
||||||
|
)
|
||||||
|
row = self._cursor.fetchone()
|
||||||
|
if row is None:
|
||||||
|
raise StylePresetNotFoundError(f"Style preset with id {id} not found")
|
||||||
|
return StylePresetRecordDTO.from_dict(dict(row))
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||||
|
id = uuid_string()
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO style_presets (
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
preset_data,
|
||||||
|
)
|
||||||
|
VALUES (?, ?, ?);
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
id,
|
||||||
|
style_preset.name,
|
||||||
|
style_preset.preset_data,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get(id)
|
||||||
|
|
||||||
|
def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE style_presets
|
||||||
|
SET preset_data = ?
|
||||||
|
WHERE id = ? ;
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
changes.preset_data,
|
||||||
|
id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get(id)
|
||||||
|
|
||||||
|
def delete(self, id: str) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
DELETE from style_presets
|
||||||
|
WHERE id = ? ;
|
||||||
|
""",
|
||||||
|
(id,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
page: int,
|
||||||
|
per_page: int,
|
||||||
|
order_by: WorkflowRecordOrderBy,
|
||||||
|
direction: SQLiteDirection,
|
||||||
|
query: Optional[str] = None,
|
||||||
|
) -> PaginatedResults[StylePresetRecordDTO]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
# sanitize!
|
||||||
|
assert order_by in WorkflowRecordOrderBy
|
||||||
|
assert direction in SQLiteDirection
|
||||||
|
count_query = "SELECT COUNT(*) FROM style_presets"
|
||||||
|
main_query = """
|
||||||
|
SELECT
|
||||||
|
*
|
||||||
|
FROM style_presets
|
||||||
|
"""
|
||||||
|
main_params: list[int | str] = []
|
||||||
|
count_params: list[int | str] = []
|
||||||
|
stripped_query = query.strip() if query else None
|
||||||
|
if stripped_query:
|
||||||
|
wildcard_query = "%" + stripped_query + "%"
|
||||||
|
main_query += " AND name LIKE ? OR description LIKE ? "
|
||||||
|
count_query += " AND name LIKE ? OR description LIKE ?;"
|
||||||
|
main_params.extend([wildcard_query, wildcard_query])
|
||||||
|
count_params.extend([wildcard_query, wildcard_query])
|
||||||
|
|
||||||
|
main_query += f" ORDER BY {order_by.value} {direction.value} LIMIT ? OFFSET ?;"
|
||||||
|
main_params.extend([per_page, page * per_page])
|
||||||
|
self._cursor.execute(main_query, main_params)
|
||||||
|
rows = self._cursor.fetchall()
|
||||||
|
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||||
|
|
||||||
|
self._cursor.execute(count_query, count_params)
|
||||||
|
total = self._cursor.fetchone()[0]
|
||||||
|
pages = total // per_page + (total % per_page > 0)
|
||||||
|
|
||||||
|
return PaginatedResults(
|
||||||
|
items=style_presets,
|
||||||
|
page=page,
|
||||||
|
per_page=per_page,
|
||||||
|
pages=pages,
|
||||||
|
total=total,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
Loading…
Reference in New Issue
Block a user