From b76bf50b9306c56e7ca940dbebd02f476554a239 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 1 Aug 2024 22:20:11 -0400 Subject: [PATCH] feat(db,api): create new table for style presets, build out record storage service for style presets --- invokeai/app/api/dependencies.py | 3 + invokeai/app/services/invocation_services.py | 3 + .../app/services/shared/sqlite/sqlite_util.py | 2 + .../migrations/migration_14.py | 60 +++++++ .../services/style_preset_records/__init__.py | 0 .../style_preset_records_base.py | 45 +++++ .../style_preset_records_common.py | 32 ++++ .../style_preset_records_sqlite.py | 169 ++++++++++++++++++ 8 files changed, 314 insertions(+) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py create mode 100644 invokeai/app/services/style_preset_records/__init__.py create mode 100644 invokeai/app/services/style_preset_records/style_preset_records_base.py create mode 100644 invokeai/app/services/style_preset_records/style_preset_records_common.py create mode 100644 invokeai/app/services/style_preset_records/style_preset_records_sqlite.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 6e049399db..18ffd39947 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -31,6 +31,7 @@ from invokeai.app.services.session_processor.session_processor_default import ( ) from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue from invokeai.app.services.shared.sqlite.sqlite_util import init_db +from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage from invokeai.app.services.urls.urls_default import LocalUrlService from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -109,6 +110,7 @@ class ApiDependencies: session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() workflow_records = SqliteWorkflowRecordsStorage(db=db) + style_presets = SqliteStylePresetRecordsStorage(db=db) services = InvocationServices( board_image_records=board_image_records, @@ -134,6 +136,7 @@ class ApiDependencies: workflow_records=workflow_records, tensors=tensors, conditioning=conditioning, + style_presets=style_presets, ) ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 90ca613074..439625e23f 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -4,6 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase +from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase if TYPE_CHECKING: from logging import Logger @@ -61,6 +62,7 @@ class InvocationServices: workflow_records: "WorkflowRecordsStorageBase", tensors: "ObjectSerializerBase[torch.Tensor]", conditioning: "ObjectSerializerBase[ConditioningFieldData]", + style_presets: "StylePresetRecordsStorageBase", ): self.board_images = board_images self.board_image_records = board_image_records @@ -85,3 +87,4 @@ class InvocationServices: self.workflow_records = workflow_records self.tensors = tensors self.conditioning = conditioning + self.style_presets = style_presets diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 49fd337da2..e35c351ff0 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -16,6 +16,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -49,6 +50,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_11(app_config=config, logger=logger)) migrator.register_migration(build_migration_12(app_config=config)) migrator.register_migration(build_migration_13()) + migrator.register_migration(build_migration_14()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py new file mode 100644 index 0000000000..38a35d547c --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py @@ -0,0 +1,60 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration14Callback: + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._create_style_presets(cursor) + + def _create_style_presets(self, cursor: sqlite3.Cursor) -> None: + """Create the table used to store model metadata downloaded from remote sources.""" + tables = [ + """--sql + CREATE TABLE IF NOT EXISTS style_presets ( + id TEXT NOT NULL PRIMARY KEY, + name TEXT NOT NULL, + preset_data TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) + ); + """ + ] + + # Add trigger for `updated_at`. + triggers = [ + """--sql + CREATE TRIGGER IF NOT EXISTS style_presets + AFTER UPDATE + ON style_presets FOR EACH ROW + BEGIN + UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ] + + # Add indexes for searchable fields + indices = [ + "CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);", + ] + + for stmt in tables + indices + triggers: + cursor.execute(stmt) + + +def build_migration_14() -> Migration: + """ + Build the migration from database version 12 to 14.. + + This migration does the following: + - Adds `archived` columns to the board table. + """ + migration_14 = Migration( + from_version=13, + to_version=14, + callback=Migration14Callback(), + ) + + return migration_14 diff --git a/invokeai/app/services/style_preset_records/__init__.py b/invokeai/app/services/style_preset_records/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/app/services/style_preset_records/style_preset_records_base.py b/invokeai/app/services/style_preset_records/style_preset_records_base.py new file mode 100644 index 0000000000..e7e9c6655a --- /dev/null +++ b/invokeai/app/services/style_preset_records/style_preset_records_base.py @@ -0,0 +1,45 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection +from invokeai.app.services.style_preset_records.style_preset_records_common import StylePresetRecordDTO, StylePresetWithoutId +from invokeai.app.services.workflow_records.workflow_records_common import ( + WorkflowRecordOrderBy, +) + + +class StylePresetRecordsStorageBase(ABC): + """Base class for style preset storage services.""" + + @abstractmethod + def get(self, id: str) -> StylePresetRecordDTO: + """Get style preset by id.""" + pass + + @abstractmethod + def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO: + """Creates a style preset.""" + pass + + @abstractmethod + def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO: + """Updates a style preset.""" + pass + + @abstractmethod + def delete(self, id: str) -> None: + """Deletes a style preset.""" + pass + + @abstractmethod + def get_many( + self, + page: int, + per_page: int, + order_by: WorkflowRecordOrderBy, + direction: SQLiteDirection, + query: Optional[str], + ) -> PaginatedResults[StylePresetRecordDTO]: + """Gets many workflows.""" + pass diff --git a/invokeai/app/services/style_preset_records/style_preset_records_common.py b/invokeai/app/services/style_preset_records/style_preset_records_common.py new file mode 100644 index 0000000000..a5b787b9cf --- /dev/null +++ b/invokeai/app/services/style_preset_records/style_preset_records_common.py @@ -0,0 +1,32 @@ +from typing import Any, Union + +from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator + + +class StylePresetNotFoundError(Exception): + """Raised when a style preset is not found""" + + +class PresetData(BaseModel): + positive_prompt: str = Field(description="Positive prompt") + negative_prompt: str = Field(description="Negative prompt") + + +PresetDataValidator = TypeAdapter(PresetData) + + +class StylePresetWithoutId(BaseModel): + name: str = Field(description="The name of the style preset.") + preset_data: PresetData = Field(description="The preset data") + + +class StylePresetRecordDTO(StylePresetWithoutId): + id: str = Field(description="The style preset ID.") + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO": + data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", "")) + return StylePresetRecordDTOValidator.validate_python(data) + + +StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO) diff --git a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py new file mode 100644 index 0000000000..8ad44f4054 --- /dev/null +++ b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py @@ -0,0 +1,169 @@ +from pathlib import Path +from typing import Optional + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase +from invokeai.app.services.style_preset_records.style_preset_records_common import ( + StylePresetNotFoundError, + StylePresetRecordDTO, + StylePresetWithoutId, +) +from invokeai.app.services.workflow_records.workflow_records_common import ( + WorkflowRecordOrderBy, +) +from invokeai.app.util.misc import uuid_string + + +class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): + def __init__(self, db: SqliteDatabase) -> None: + super().__init__() + self._lock = db.lock + self._conn = db.conn + self._cursor = self._conn.cursor() + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + + def get(self, id: str) -> StylePresetRecordDTO: + """Gets a style preset by ID.""" + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT * + FROM style_presets + WHERE id = ?; + """, + (id,), + ) + row = self._cursor.fetchone() + if row is None: + raise StylePresetNotFoundError(f"Style preset with id {id} not found") + return StylePresetRecordDTO.from_dict(dict(row)) + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + + def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO: + id = uuid_string() + try: + self._lock.acquire() + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO style_presets ( + id, + name, + preset_data, + ) + VALUES (?, ?, ?); + """, + ( + id, + style_preset.name, + style_preset.preset_data, + ), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + return self.get(id) + + def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + UPDATE style_presets + SET preset_data = ? + WHERE id = ? ; + """, + ( + changes.preset_data, + id, + ), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + return self.get(id) + + def delete(self, id: str) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + DELETE from style_presets + WHERE id = ? ; + """, + (id,), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + return None + + def get_many( + self, + page: int, + per_page: int, + order_by: WorkflowRecordOrderBy, + direction: SQLiteDirection, + query: Optional[str] = None, + ) -> PaginatedResults[StylePresetRecordDTO]: + try: + self._lock.acquire() + # sanitize! + assert order_by in WorkflowRecordOrderBy + assert direction in SQLiteDirection + count_query = "SELECT COUNT(*) FROM style_presets" + main_query = """ + SELECT + * + FROM style_presets + """ + main_params: list[int | str] = [] + count_params: list[int | str] = [] + stripped_query = query.strip() if query else None + if stripped_query: + wildcard_query = "%" + stripped_query + "%" + main_query += " AND name LIKE ? OR description LIKE ? " + count_query += " AND name LIKE ? OR description LIKE ?;" + main_params.extend([wildcard_query, wildcard_query]) + count_params.extend([wildcard_query, wildcard_query]) + + main_query += f" ORDER BY {order_by.value} {direction.value} LIMIT ? OFFSET ?;" + main_params.extend([per_page, page * per_page]) + self._cursor.execute(main_query, main_params) + rows = self._cursor.fetchall() + style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows] + + self._cursor.execute(count_query, count_params) + total = self._cursor.fetchone()[0] + pages = total // per_page + (total % per_page > 0) + + return PaginatedResults( + items=style_presets, + page=page, + per_page=per_page, + pages=pages, + total=total, + ) + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release()