From b76bf50b9306c56e7ca940dbebd02f476554a239 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 1 Aug 2024 22:20:11 -0400 Subject: [PATCH 01/48] feat(db,api): create new table for style presets, build out record storage service for style presets --- invokeai/app/api/dependencies.py | 3 + invokeai/app/services/invocation_services.py | 3 + .../app/services/shared/sqlite/sqlite_util.py | 2 + .../migrations/migration_14.py | 60 +++++++ .../services/style_preset_records/__init__.py | 0 .../style_preset_records_base.py | 45 +++++ .../style_preset_records_common.py | 32 ++++ .../style_preset_records_sqlite.py | 169 ++++++++++++++++++ 8 files changed, 314 insertions(+) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py create mode 100644 invokeai/app/services/style_preset_records/__init__.py create mode 100644 invokeai/app/services/style_preset_records/style_preset_records_base.py create mode 100644 invokeai/app/services/style_preset_records/style_preset_records_common.py create mode 100644 invokeai/app/services/style_preset_records/style_preset_records_sqlite.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 6e049399db..18ffd39947 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -31,6 +31,7 @@ from invokeai.app.services.session_processor.session_processor_default import ( ) from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue from invokeai.app.services.shared.sqlite.sqlite_util import init_db +from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage from invokeai.app.services.urls.urls_default import LocalUrlService from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -109,6 +110,7 @@ class ApiDependencies: session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() workflow_records = SqliteWorkflowRecordsStorage(db=db) + style_presets = SqliteStylePresetRecordsStorage(db=db) services = InvocationServices( board_image_records=board_image_records, @@ -134,6 +136,7 @@ class ApiDependencies: workflow_records=workflow_records, tensors=tensors, conditioning=conditioning, + style_presets=style_presets, ) ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 90ca613074..439625e23f 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -4,6 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase +from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase if TYPE_CHECKING: from logging import Logger @@ -61,6 +62,7 @@ class InvocationServices: workflow_records: "WorkflowRecordsStorageBase", tensors: "ObjectSerializerBase[torch.Tensor]", conditioning: "ObjectSerializerBase[ConditioningFieldData]", + style_presets: "StylePresetRecordsStorageBase", ): self.board_images = board_images self.board_image_records = board_image_records @@ -85,3 +87,4 @@ class InvocationServices: self.workflow_records = workflow_records self.tensors = tensors self.conditioning = conditioning + self.style_presets = style_presets diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 49fd337da2..e35c351ff0 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -16,6 +16,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -49,6 +50,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_11(app_config=config, logger=logger)) migrator.register_migration(build_migration_12(app_config=config)) migrator.register_migration(build_migration_13()) + migrator.register_migration(build_migration_14()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py new file mode 100644 index 0000000000..38a35d547c --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py @@ -0,0 +1,60 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration14Callback: + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._create_style_presets(cursor) + + def _create_style_presets(self, cursor: sqlite3.Cursor) -> None: + """Create the table used to store model metadata downloaded from remote sources.""" + tables = [ + """--sql + CREATE TABLE IF NOT EXISTS style_presets ( + id TEXT NOT NULL PRIMARY KEY, + name TEXT NOT NULL, + preset_data TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) + ); + """ + ] + + # Add trigger for `updated_at`. + triggers = [ + """--sql + CREATE TRIGGER IF NOT EXISTS style_presets + AFTER UPDATE + ON style_presets FOR EACH ROW + BEGIN + UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ] + + # Add indexes for searchable fields + indices = [ + "CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);", + ] + + for stmt in tables + indices + triggers: + cursor.execute(stmt) + + +def build_migration_14() -> Migration: + """ + Build the migration from database version 12 to 14.. + + This migration does the following: + - Adds `archived` columns to the board table. + """ + migration_14 = Migration( + from_version=13, + to_version=14, + callback=Migration14Callback(), + ) + + return migration_14 diff --git a/invokeai/app/services/style_preset_records/__init__.py b/invokeai/app/services/style_preset_records/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/app/services/style_preset_records/style_preset_records_base.py b/invokeai/app/services/style_preset_records/style_preset_records_base.py new file mode 100644 index 0000000000..e7e9c6655a --- /dev/null +++ b/invokeai/app/services/style_preset_records/style_preset_records_base.py @@ -0,0 +1,45 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection +from invokeai.app.services.style_preset_records.style_preset_records_common import StylePresetRecordDTO, StylePresetWithoutId +from invokeai.app.services.workflow_records.workflow_records_common import ( + WorkflowRecordOrderBy, +) + + +class StylePresetRecordsStorageBase(ABC): + """Base class for style preset storage services.""" + + @abstractmethod + def get(self, id: str) -> StylePresetRecordDTO: + """Get style preset by id.""" + pass + + @abstractmethod + def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO: + """Creates a style preset.""" + pass + + @abstractmethod + def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO: + """Updates a style preset.""" + pass + + @abstractmethod + def delete(self, id: str) -> None: + """Deletes a style preset.""" + pass + + @abstractmethod + def get_many( + self, + page: int, + per_page: int, + order_by: WorkflowRecordOrderBy, + direction: SQLiteDirection, + query: Optional[str], + ) -> PaginatedResults[StylePresetRecordDTO]: + """Gets many workflows.""" + pass diff --git a/invokeai/app/services/style_preset_records/style_preset_records_common.py b/invokeai/app/services/style_preset_records/style_preset_records_common.py new file mode 100644 index 0000000000..a5b787b9cf --- /dev/null +++ b/invokeai/app/services/style_preset_records/style_preset_records_common.py @@ -0,0 +1,32 @@ +from typing import Any, Union + +from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator + + +class StylePresetNotFoundError(Exception): + """Raised when a style preset is not found""" + + +class PresetData(BaseModel): + positive_prompt: str = Field(description="Positive prompt") + negative_prompt: str = Field(description="Negative prompt") + + +PresetDataValidator = TypeAdapter(PresetData) + + +class StylePresetWithoutId(BaseModel): + name: str = Field(description="The name of the style preset.") + preset_data: PresetData = Field(description="The preset data") + + +class StylePresetRecordDTO(StylePresetWithoutId): + id: str = Field(description="The style preset ID.") + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO": + data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", "")) + return StylePresetRecordDTOValidator.validate_python(data) + + +StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO) diff --git a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py new file mode 100644 index 0000000000..8ad44f4054 --- /dev/null +++ b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py @@ -0,0 +1,169 @@ +from pathlib import Path +from typing import Optional + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase +from invokeai.app.services.style_preset_records.style_preset_records_common import ( + StylePresetNotFoundError, + StylePresetRecordDTO, + StylePresetWithoutId, +) +from invokeai.app.services.workflow_records.workflow_records_common import ( + WorkflowRecordOrderBy, +) +from invokeai.app.util.misc import uuid_string + + +class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): + def __init__(self, db: SqliteDatabase) -> None: + super().__init__() + self._lock = db.lock + self._conn = db.conn + self._cursor = self._conn.cursor() + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + + def get(self, id: str) -> StylePresetRecordDTO: + """Gets a style preset by ID.""" + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT * + FROM style_presets + WHERE id = ?; + """, + (id,), + ) + row = self._cursor.fetchone() + if row is None: + raise StylePresetNotFoundError(f"Style preset with id {id} not found") + return StylePresetRecordDTO.from_dict(dict(row)) + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + + def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO: + id = uuid_string() + try: + self._lock.acquire() + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO style_presets ( + id, + name, + preset_data, + ) + VALUES (?, ?, ?); + """, + ( + id, + style_preset.name, + style_preset.preset_data, + ), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + return self.get(id) + + def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + UPDATE style_presets + SET preset_data = ? + WHERE id = ? ; + """, + ( + changes.preset_data, + id, + ), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + return self.get(id) + + def delete(self, id: str) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + DELETE from style_presets + WHERE id = ? ; + """, + (id,), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + return None + + def get_many( + self, + page: int, + per_page: int, + order_by: WorkflowRecordOrderBy, + direction: SQLiteDirection, + query: Optional[str] = None, + ) -> PaginatedResults[StylePresetRecordDTO]: + try: + self._lock.acquire() + # sanitize! + assert order_by in WorkflowRecordOrderBy + assert direction in SQLiteDirection + count_query = "SELECT COUNT(*) FROM style_presets" + main_query = """ + SELECT + * + FROM style_presets + """ + main_params: list[int | str] = [] + count_params: list[int | str] = [] + stripped_query = query.strip() if query else None + if stripped_query: + wildcard_query = "%" + stripped_query + "%" + main_query += " AND name LIKE ? OR description LIKE ? " + count_query += " AND name LIKE ? OR description LIKE ?;" + main_params.extend([wildcard_query, wildcard_query]) + count_params.extend([wildcard_query, wildcard_query]) + + main_query += f" ORDER BY {order_by.value} {direction.value} LIMIT ? OFFSET ?;" + main_params.extend([per_page, page * per_page]) + self._cursor.execute(main_query, main_params) + rows = self._cursor.fetchall() + style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows] + + self._cursor.execute(count_query, count_params) + total = self._cursor.fetchone()[0] + pages = total // per_page + (total % per_page > 0) + + return PaginatedResults( + items=style_presets, + page=page, + per_page=per_page, + pages=pages, + total=total, + ) + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() From 217fe40d99eb479f827bb5014c2dbed9d388ef33 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Fri, 2 Aug 2024 12:29:54 -0400 Subject: [PATCH 02/48] feat(api): add style_presets router, make sure all CRUD is working, add is_default --- invokeai/app/api/dependencies.py | 4 +- invokeai/app/api/routers/style_presets.py | 100 ++++++++++++++++++ invokeai/app/api_app.py | 2 + invokeai/app/services/invocation_services.py | 4 +- .../migrations/migration_14.py | 5 +- .../style_preset_records_base.py | 12 ++- .../style_preset_records_common.py | 20 +++- .../style_preset_records_sqlite.py | 52 +++++---- 8 files changed, 165 insertions(+), 34 deletions(-) create mode 100644 invokeai/app/api/routers/style_presets.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 18ffd39947..d2eeab6219 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -110,7 +110,7 @@ class ApiDependencies: session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() workflow_records = SqliteWorkflowRecordsStorage(db=db) - style_presets = SqliteStylePresetRecordsStorage(db=db) + style_preset_records = SqliteStylePresetRecordsStorage(db=db) services = InvocationServices( board_image_records=board_image_records, @@ -136,7 +136,7 @@ class ApiDependencies: workflow_records=workflow_records, tensors=tensors, conditioning=conditioning, - style_presets=style_presets, + style_preset_records=style_preset_records, ) ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/api/routers/style_presets.py b/invokeai/app/api/routers/style_presets.py new file mode 100644 index 0000000000..ae38db76aa --- /dev/null +++ b/invokeai/app/api/routers/style_presets.py @@ -0,0 +1,100 @@ +from typing import Optional + +from fastapi import APIRouter, Body, HTTPException, Path, Query + +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection +from invokeai.app.services.style_preset_records.style_preset_records_common import ( + StylePresetChanges, + StylePresetNotFoundError, + StylePresetRecordDTO, + StylePresetWithoutId, + StylePresetRecordOrderBy, +) + + +style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"]) + + +@style_presets_router.get( + "/i/{style_preset_id}", + operation_id="get_style_preset", + responses={ + 200: {"model": StylePresetRecordDTO}, + }, +) +async def get_style_preset( + style_preset_id: str = Path(description="The style preset to get"), +) -> StylePresetRecordDTO: + """Gets a style preset""" + try: + return ApiDependencies.invoker.services.style_preset_records.get(style_preset_id) + except StylePresetNotFoundError: + raise HTTPException(status_code=404, detail="Style preset not found") + + +@style_presets_router.patch( + "/i/{style_preset_id}", + operation_id="update_style_preset", + responses={ + 200: {"model": StylePresetRecordDTO}, + }, +) +async def update_style_preset( + style_preset_id: str = Path(description="The id of the style preset to update"), + changes: StylePresetChanges = Body(description="The updated style preset", embed=True), +) -> StylePresetRecordDTO: + """Updates a style preset""" + return ApiDependencies.invoker.services.style_preset_records.update(id=style_preset_id, changes=changes) + + +@style_presets_router.delete( + "/i/{style_preset_id}", + operation_id="delete_style_preset", +) +async def delete_style_preset( + style_preset_id: str = Path(description="The style preset to delete"), +) -> None: + """Deletes a style preset""" + ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id) + + +@style_presets_router.post( + "/", + operation_id="create_style_preset", + responses={ + 200: {"model": StylePresetRecordDTO}, + }, +) +async def create_style_preset( + style_preset: StylePresetWithoutId = Body(description="The style preset to create", embed=True), +) -> StylePresetRecordDTO: + """Creates a style preset""" + return ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset) + + +@style_presets_router.get( + "/", + operation_id="list_style_presets", + responses={ + 200: {"model": PaginatedResults[StylePresetRecordDTO]}, + }, +) +async def list_style_presets( + page: int = Query(default=0, description="The page to get"), + per_page: int = Query(default=10, description="The number of style presets per page"), + order_by: StylePresetRecordOrderBy = Query( + default=StylePresetRecordOrderBy.Name, description="The attribute to order by" + ), + direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"), + query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"), +) -> PaginatedResults[StylePresetRecordDTO]: + """Gets a page of style presets""" + return ApiDependencies.invoker.services.style_preset_records.get_many( + page=page, + per_page=per_page, + order_by=order_by, + direction=direction, + query=query, + ) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 88820a0c4c..492c465573 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -32,6 +32,7 @@ from invokeai.app.api.routers import ( session_queue, utilities, workflows, + style_presets, ) from invokeai.app.api.sockets import SocketIO from invokeai.app.services.config.config_default import get_config @@ -106,6 +107,7 @@ app.include_router(board_images.board_images_router, prefix="/api") app.include_router(app_info.app_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api") app.include_router(workflows.workflows_router, prefix="/api") +app.include_router(style_presets.style_presets_router, prefix="/api") app.openapi = get_openapi_func(app) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 439625e23f..c756eae8e5 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -62,7 +62,7 @@ class InvocationServices: workflow_records: "WorkflowRecordsStorageBase", tensors: "ObjectSerializerBase[torch.Tensor]", conditioning: "ObjectSerializerBase[ConditioningFieldData]", - style_presets: "StylePresetRecordsStorageBase", + style_preset_records: "StylePresetRecordsStorageBase", ): self.board_images = board_images self.board_image_records = board_image_records @@ -87,4 +87,4 @@ class InvocationServices: self.workflow_records = workflow_records self.tensors = tensors self.conditioning = conditioning - self.style_presets = style_presets + self.style_preset_records = style_preset_records diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py index 38a35d547c..e7e77651dd 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py @@ -15,13 +15,14 @@ class Migration14Callback: id TEXT NOT NULL PRIMARY KEY, name TEXT NOT NULL, preset_data TEXT NOT NULL, + is_default BOOLEAN DEFAULT FALSE, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); """ - ] - + ] + # Add trigger for `updated_at`. triggers = [ """--sql diff --git a/invokeai/app/services/style_preset_records/style_preset_records_base.py b/invokeai/app/services/style_preset_records/style_preset_records_base.py index e7e9c6655a..58121c3445 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_base.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_base.py @@ -3,9 +3,11 @@ from typing import Optional from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection -from invokeai.app.services.style_preset_records.style_preset_records_common import StylePresetRecordDTO, StylePresetWithoutId -from invokeai.app.services.workflow_records.workflow_records_common import ( - WorkflowRecordOrderBy, +from invokeai.app.services.style_preset_records.style_preset_records_common import ( + StylePresetChanges, + StylePresetRecordDTO, + StylePresetWithoutId, + StylePresetRecordOrderBy, ) @@ -23,7 +25,7 @@ class StylePresetRecordsStorageBase(ABC): pass @abstractmethod - def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO: + def update(self, id: str, changes: StylePresetChanges) -> StylePresetRecordDTO: """Updates a style preset.""" pass @@ -37,7 +39,7 @@ class StylePresetRecordsStorageBase(ABC): self, page: int, per_page: int, - order_by: WorkflowRecordOrderBy, + order_by: StylePresetRecordOrderBy, direction: SQLiteDirection, query: Optional[str], ) -> PaginatedResults[StylePresetRecordDTO]: diff --git a/invokeai/app/services/style_preset_records/style_preset_records_common.py b/invokeai/app/services/style_preset_records/style_preset_records_common.py index a5b787b9cf..7ce7ccc730 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_common.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_common.py @@ -1,13 +1,23 @@ -from typing import Any, Union +from enum import Enum +from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator +from invokeai.app.util.metaenum import MetaEnum + class StylePresetNotFoundError(Exception): """Raised when a style preset is not found""" -class PresetData(BaseModel): +class StylePresetRecordOrderBy(str, Enum, metaclass=MetaEnum): + """The order by options for workflow records""" + + CreatedAt = "created_at" + Name = "name" + + +class PresetData(BaseModel, extra="forbid"): positive_prompt: str = Field(description="Positive prompt") negative_prompt: str = Field(description="Negative prompt") @@ -15,6 +25,11 @@ class PresetData(BaseModel): PresetDataValidator = TypeAdapter(PresetData) +class StylePresetChanges(BaseModel, extra="forbid"): + name: Optional[str] = Field(default=None, description="The style preset's new name.") + preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.") + + class StylePresetWithoutId(BaseModel): name: str = Field(description="The name of the style preset.") preset_data: PresetData = Field(description="The preset data") @@ -22,6 +37,7 @@ class StylePresetWithoutId(BaseModel): class StylePresetRecordDTO(StylePresetWithoutId): id: str = Field(description="The style preset ID.") + is_default: bool = Field(description="Whether or not the style preset is default") @classmethod def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO": diff --git a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py index 8ad44f4054..e0110afa9a 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py @@ -7,12 +7,11 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase from invokeai.app.services.style_preset_records.style_preset_records_common import ( + StylePresetChanges, StylePresetNotFoundError, StylePresetRecordDTO, StylePresetWithoutId, -) -from invokeai.app.services.workflow_records.workflow_records_common import ( - WorkflowRecordOrderBy, + StylePresetRecordOrderBy, ) from invokeai.app.util.misc import uuid_string @@ -58,14 +57,14 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): INSERT OR IGNORE INTO style_presets ( id, name, - preset_data, + preset_data ) VALUES (?, ?, ?); """, ( id, style_preset.name, - style_preset.preset_data, + style_preset.preset_data.model_dump_json(), ), ) self._conn.commit() @@ -76,20 +75,31 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): self._lock.release() return self.get(id) - def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO: + def update(self, id: str, changes: StylePresetChanges) -> StylePresetRecordDTO: try: self._lock.acquire() - self._cursor.execute( - """--sql - UPDATE style_presets - SET preset_data = ? - WHERE id = ? ; - """, - ( - changes.preset_data, - id, - ), - ) + # Change the name of a style preset + if changes.name is not None: + self._cursor.execute( + """--sql + UPDATE style_presets + SET name = ? + WHERE id = ?; + """, + (changes.name, id), + ) + + # Change the preset data for a style preset + if changes.preset_data is not None: + self._cursor.execute( + """--sql + UPDATE style_presets + SET preset_data = ? + WHERE id = ?; + """, + (changes.preset_data.model_dump_json(), id), + ) + self._conn.commit() except Exception: self._conn.rollback() @@ -120,14 +130,14 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): self, page: int, per_page: int, - order_by: WorkflowRecordOrderBy, + order_by: StylePresetRecordOrderBy, direction: SQLiteDirection, query: Optional[str] = None, ) -> PaginatedResults[StylePresetRecordDTO]: try: self._lock.acquire() # sanitize! - assert order_by in WorkflowRecordOrderBy + assert order_by in StylePresetRecordOrderBy assert direction in SQLiteDirection count_query = "SELECT COUNT(*) FROM style_presets" main_query = """ @@ -140,8 +150,8 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): stripped_query = query.strip() if query else None if stripped_query: wildcard_query = "%" + stripped_query + "%" - main_query += " AND name LIKE ? OR description LIKE ? " - count_query += " AND name LIKE ? OR description LIKE ?;" + main_query += " AND name LIKE ? " + count_query += " AND name LIKE ?;" main_params.extend([wildcard_query, wildcard_query]) count_params.extend([wildcard_query, wildcard_query]) From e05cc62e5f8c5f24bfe9757e54832b1770c7f626 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Mon, 5 Aug 2024 13:37:07 -0400 Subject: [PATCH 03/48] add style presets API layer to UI --- .../services/api/endpoints/stylePresets.ts | 79 +++ .../frontend/web/src/services/api/index.ts | 1 + .../frontend/web/src/services/api/schema.ts | 530 +++++++++++++----- 3 files changed, 479 insertions(+), 131 deletions(-) create mode 100644 invokeai/frontend/web/src/services/api/endpoints/stylePresets.ts diff --git a/invokeai/frontend/web/src/services/api/endpoints/stylePresets.ts b/invokeai/frontend/web/src/services/api/endpoints/stylePresets.ts new file mode 100644 index 0000000000..7aa4d8478c --- /dev/null +++ b/invokeai/frontend/web/src/services/api/endpoints/stylePresets.ts @@ -0,0 +1,79 @@ +import type { paths } from 'services/api/schema'; + +import { api, buildV1Url, LIST_TAG } from '..'; + +/** + * Builds an endpoint URL for the style_presets router + * @example + * buildStylePresetsUrl('some-path') + * // '/api/v1/style_presets/some-path' + */ +const buildStylePresetsUrl = (path: string = '') => buildV1Url(`style_presets/${path}`); + +export const stylePresetsApi = api.injectEndpoints({ + endpoints: (build) => ({ + getStylePreset: build.query< + paths['/api/v1/style_presets/i/{style_preset_id}']['get']['responses']['200']['content']['application/json'], + string + >({ + query: (style_preset_id) => buildStylePresetsUrl(`i/${style_preset_id}`), + providesTags: (result, error, style_preset_id) => [{ type: 'StylePreset', id: style_preset_id }, 'FetchOnReconnect'], + }), + deleteStylePreset: build.mutation({ + query: (style_preset_id) => ({ + url: buildStylePresetsUrl(`i/${style_preset_id}`), + method: 'DELETE', + }), + invalidatesTags: (result, error, style_preset_id) => [ + { type: 'StylePreset', id: LIST_TAG }, + { type: 'StylePreset', id: style_preset_id }, + ], + }), + createStylePreset: build.mutation< + paths['/api/v1/style_presets/']['post']['responses']['200']['content']['application/json'], + paths['/api/v1/style_presets/']['post']['requestBody']['content']['application/json']['style_preset'] + >({ + query: (style_preset) => ({ + url: buildStylePresetsUrl(), + method: 'POST', + body: { style_preset }, + }), + invalidatesTags: [ + { type: 'StylePreset', id: LIST_TAG }, + { type: 'StylePreset', id: LIST_TAG }, + ], + }), + updateStylePreset: build.mutation< + paths['/api/v1/style_presets/i/{style_preset_id}']['patch']['responses']['200']['content']['application/json'], + { id: string, changes: paths['/api/v1/style_presets/i/{style_preset_id}']['patch']['requestBody']['content']['application/json']['changes'] } + >({ + query: ({ id, changes }) => ({ + url: buildStylePresetsUrl(`i/${id}`), + method: 'PATCH', + body: { changes }, + }), + invalidatesTags: (response, error, { id, changes }) => [ + { type: 'StylePreset', id: LIST_TAG }, + { type: 'StylePreset', id: id }, + ], + }), + listStylePresets: build.query< + paths['/api/v1/style_presets/']['get']['responses']['200']['content']['application/json'], + NonNullable + >({ + query: (params) => ({ + url: buildStylePresetsUrl(), + params, + }), + providesTags: ['FetchOnReconnect', { type: 'StylePreset', id: LIST_TAG }], + }), + }), +}); + +export const { + useGetStylePresetQuery, + useCreateStylePresetMutation, + useDeleteStylePresetMutation, + useUpdateStylePresetMutation, + useListStylePresetsQuery, +} = stylePresetsApi; diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 79ea662717..828388b079 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -40,6 +40,7 @@ const tagTypes = [ 'SDXLRefinerModel', 'Workflow', 'WorkflowsRecent', + 'StylePreset', 'Schema', // This is invalidated on reconnect. It should be used for queries that have changing data, // especially related to the queue and generation. diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 79b82a23fa..ae93497bbf 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -532,6 +532,35 @@ export type paths = { */ post: operations["create_workflow"]; }; + "/api/v1/style_presets/i/{style_preset_id}": { + /** + * Get Style Preset + * @description Gets a style preset + */ + get: operations["get_style_preset"]; + /** + * Delete Style Preset + * @description Deletes a style preset + */ + delete: operations["delete_style_preset"]; + /** + * Update Style Preset + * @description Updates a style preset + */ + patch: operations["update_style_preset"]; + }; + "/api/v1/style_presets/": { + /** + * List Style Presets + * @description Gets a page of style presets + */ + get: operations["list_style_presets"]; + /** + * Create Style Preset + * @description Creates a style preset + */ + post: operations["create_style_preset"]; + }; }; export type webhooks = Record; @@ -1115,6 +1144,11 @@ export type components = { */ batch_ids: string[]; }; + /** Body_create_style_preset */ + Body_create_style_preset: { + /** @description The style preset to create */ + style_preset: components["schemas"]["StylePresetWithoutId"]; + }; /** Body_create_workflow */ Body_create_workflow: { /** @description The workflow to create */ @@ -1237,6 +1271,11 @@ export type components = { */ image: Blob; }; + /** Body_update_style_preset */ + Body_update_style_preset: { + /** @description The updated style preset */ + changes: components["schemas"]["StylePresetChanges"]; + }; /** Body_update_workflow */ Body_update_workflow: { /** @description The updated workflow */ @@ -7306,147 +7345,147 @@ export type components = { project_id: string | null; }; InvocationOutputMap: { - dynamic_prompt: components["schemas"]["StringCollectionOutput"]; - img_ilerp: components["schemas"]["ImageOutput"]; - random_range: components["schemas"]["IntegerCollectionOutput"]; - ideal_size: components["schemas"]["IdealSizeOutput"]; - lblend: components["schemas"]["LatentsOutput"]; - tiled_multi_diffusion_denoise_latents: components["schemas"]["LatentsOutput"]; - color_correct: components["schemas"]["ImageOutput"]; - sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; - img_hue_adjust: components["schemas"]["ImageOutput"]; - lscale: components["schemas"]["LatentsOutput"]; - infill_patchmatch: components["schemas"]["ImageOutput"]; - mask_edge: components["schemas"]["ImageOutput"]; - spandrel_image_to_image: components["schemas"]["ImageOutput"]; - round_float: components["schemas"]["FloatOutput"]; - color: components["schemas"]["ColorOutput"]; - img_channel_offset: components["schemas"]["ImageOutput"]; - create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; - string_join_three: components["schemas"]["StringOutput"]; - unsharp_mask: components["schemas"]["ImageOutput"]; - canvas_paste_back: components["schemas"]["ImageOutput"]; - string_collection: components["schemas"]["StringCollectionOutput"]; - face_mask_detection: components["schemas"]["FaceMaskOutput"]; + crop_latents: components["schemas"]["LatentsOutput"]; + sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; color_map_image_processor: components["schemas"]["ImageOutput"]; - normalbae_image_processor: components["schemas"]["ImageOutput"]; - create_gradient_mask: components["schemas"]["GradientMaskOutput"]; - vae_loader: components["schemas"]["VAEOutput"]; - integer: components["schemas"]["IntegerOutput"]; - image_mask_to_tensor: components["schemas"]["MaskOutput"]; - esrgan: components["schemas"]["ImageOutput"]; - mlsd_image_processor: components["schemas"]["ImageOutput"]; - blank_image: components["schemas"]["ImageOutput"]; - zoe_depth_image_processor: components["schemas"]["ImageOutput"]; - merge_metadata: components["schemas"]["MetadataOutput"]; - infill_rgba: components["schemas"]["ImageOutput"]; - mask_from_id: components["schemas"]["ImageOutput"]; - string_split: components["schemas"]["String2Output"]; - boolean_collection: components["schemas"]["BooleanCollectionOutput"]; - img_resize: components["schemas"]["ImageOutput"]; - float_collection: components["schemas"]["FloatCollectionOutput"]; - img_conv: components["schemas"]["ImageOutput"]; - boolean: components["schemas"]["BooleanOutput"]; - mul: components["schemas"]["IntegerOutput"]; - img_watermark: components["schemas"]["ImageOutput"]; - heuristic_resize: components["schemas"]["ImageOutput"]; - tile_image_processor: components["schemas"]["ImageOutput"]; - face_identifier: components["schemas"]["ImageOutput"]; + img_paste: components["schemas"]["ImageOutput"]; + image_collection: components["schemas"]["ImageCollectionOutput"]; + img_channel_offset: components["schemas"]["ImageOutput"]; + compel: components["schemas"]["ConditioningOutput"]; denoise_latents: components["schemas"]["LatentsOutput"]; - img_lerp: components["schemas"]["ImageOutput"]; + t2i_adapter: components["schemas"]["T2IAdapterOutput"]; pair_tile_image: components["schemas"]["PairTileImageOutput"]; - add: components["schemas"]["IntegerOutput"]; - range_of_size: components["schemas"]["IntegerCollectionOutput"]; + invert_tensor_mask: components["schemas"]["MaskOutput"]; + img_conv: components["schemas"]["ImageOutput"]; + img_hue_adjust: components["schemas"]["ImageOutput"]; + sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; + conditioning: components["schemas"]["ConditioningOutput"]; + mlsd_image_processor: components["schemas"]["ImageOutput"]; + image_mask_to_tensor: components["schemas"]["MaskOutput"]; + float_to_int: components["schemas"]["IntegerOutput"]; img_scale: components["schemas"]["ImageOutput"]; + color: components["schemas"]["ColorOutput"]; + midas_depth_image_processor: components["schemas"]["ImageOutput"]; + mul: components["schemas"]["IntegerOutput"]; + pidi_image_processor: components["schemas"]["ImageOutput"]; + img_crop: components["schemas"]["ImageOutput"]; + metadata_item: components["schemas"]["MetadataItemOutput"]; + sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + metadata: components["schemas"]["MetadataOutput"]; + dynamic_prompt: components["schemas"]["StringCollectionOutput"]; + lblend: components["schemas"]["LatentsOutput"]; + canvas_paste_back: components["schemas"]["ImageOutput"]; + lresize: components["schemas"]["LatentsOutput"]; + sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; + normalbae_image_processor: components["schemas"]["ImageOutput"]; + random_range: components["schemas"]["IntegerCollectionOutput"]; + img_chan: components["schemas"]["ImageOutput"]; + alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; + img_nsfw: components["schemas"]["ImageOutput"]; + mask_combine: components["schemas"]["ImageOutput"]; + color_correct: components["schemas"]["ImageOutput"]; + latents: components["schemas"]["LatentsOutput"]; + seamless: components["schemas"]["SeamlessModeOutput"]; + lscale: components["schemas"]["LatentsOutput"]; + string_join_three: components["schemas"]["StringOutput"]; + lineart_image_processor: components["schemas"]["ImageOutput"]; + lineart_anime_image_processor: components["schemas"]["ImageOutput"]; + face_mask_detection: components["schemas"]["FaceMaskOutput"]; + esrgan: components["schemas"]["ImageOutput"]; + mask_edge: components["schemas"]["ImageOutput"]; + segment_anything_processor: components["schemas"]["ImageOutput"]; + cv_inpaint: components["schemas"]["ImageOutput"]; + tomask: components["schemas"]["ImageOutput"]; + lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; + infill_cv2: components["schemas"]["ImageOutput"]; + spandrel_image_to_image: components["schemas"]["ImageOutput"]; + conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; + img_resize: components["schemas"]["ImageOutput"]; + create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; + controlnet: components["schemas"]["ControlOutput"]; + merge_tiles_to_image: components["schemas"]["ImageOutput"]; + mediapipe_face_processor: components["schemas"]["ImageOutput"]; + img_mul: components["schemas"]["ImageOutput"]; + image: components["schemas"]["ImageOutput"]; + merge_metadata: components["schemas"]["MetadataOutput"]; + add: components["schemas"]["IntegerOutput"]; + mask_from_id: components["schemas"]["ImageOutput"]; + vae_loader: components["schemas"]["VAEOutput"]; + ideal_size: components["schemas"]["IdealSizeOutput"]; + noise: components["schemas"]["NoiseOutput"]; + sub: components["schemas"]["IntegerOutput"]; + freeu: components["schemas"]["UNetOutput"]; + lora_selector: components["schemas"]["LoRASelectorOutput"]; + integer: components["schemas"]["IntegerOutput"]; + l2i: components["schemas"]["ImageOutput"]; + integer_collection: components["schemas"]["IntegerCollectionOutput"]; + rand_float: components["schemas"]["FloatOutput"]; + iterate: components["schemas"]["IterateInvocationOutput"]; + sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; + content_shuffle_image_processor: components["schemas"]["ImageOutput"]; + dw_openpose_image_processor: components["schemas"]["ImageOutput"]; + zoe_depth_image_processor: components["schemas"]["ImageOutput"]; + hed_image_processor: components["schemas"]["ImageOutput"]; + infill_lama: components["schemas"]["ImageOutput"]; + boolean: components["schemas"]["BooleanOutput"]; + float: components["schemas"]["FloatOutput"]; + range_of_size: components["schemas"]["IntegerCollectionOutput"]; + float_collection: components["schemas"]["FloatCollectionOutput"]; + img_channel_multiply: components["schemas"]["ImageOutput"]; + prompt_from_file: components["schemas"]["StringCollectionOutput"]; + infill_rgba: components["schemas"]["ImageOutput"]; + rectangle_mask: components["schemas"]["MaskOutput"]; + depth_anything_image_processor: components["schemas"]["ImageOutput"]; + boolean_collection: components["schemas"]["BooleanCollectionOutput"]; + rand_int: components["schemas"]["IntegerOutput"]; latents_collection: components["schemas"]["LatentsCollectionOutput"]; + string_split: components["schemas"]["String2Output"]; + round_float: components["schemas"]["FloatOutput"]; canny_image_processor: components["schemas"]["ImageOutput"]; + calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; + float_range: components["schemas"]["FloatCollectionOutput"]; + img_ilerp: components["schemas"]["ImageOutput"]; + unsharp_mask: components["schemas"]["ImageOutput"]; + img_watermark: components["schemas"]["ImageOutput"]; + model_identifier: components["schemas"]["ModelIdentifierOutput"]; + string_split_neg: components["schemas"]["StringPosNegOutput"]; + calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; + collect: components["schemas"]["CollectInvocationOutput"]; + core_metadata: components["schemas"]["MetadataOutput"]; + img_pad_crop: components["schemas"]["ImageOutput"]; + i2l: components["schemas"]["LatentsOutput"]; + tiled_multi_diffusion_denoise_latents: components["schemas"]["LatentsOutput"]; + save_image: components["schemas"]["ImageOutput"]; + show_image: components["schemas"]["ImageOutput"]; + string_collection: components["schemas"]["StringCollectionOutput"]; + calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; + infill_patchmatch: components["schemas"]["ImageOutput"]; + main_model_loader: components["schemas"]["ModelLoaderOutput"]; + lora_loader: components["schemas"]["LoRALoaderOutput"]; + string: components["schemas"]["StringOutput"]; + heuristic_resize: components["schemas"]["ImageOutput"]; + create_gradient_mask: components["schemas"]["GradientMaskOutput"]; + div: components["schemas"]["IntegerOutput"]; + blank_image: components["schemas"]["ImageOutput"]; float_math: components["schemas"]["FloatOutput"]; infill_tile: components["schemas"]["ImageOutput"]; - controlnet: components["schemas"]["ControlOutput"]; - dw_openpose_image_processor: components["schemas"]["ImageOutput"]; - img_chan: components["schemas"]["ImageOutput"]; - conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; - lresize: components["schemas"]["LatentsOutput"]; - sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; - noise: components["schemas"]["NoiseOutput"]; - ip_adapter: components["schemas"]["IPAdapterOutput"]; - sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; - sub: components["schemas"]["IntegerOutput"]; - seamless: components["schemas"]["SeamlessModeOutput"]; - div: components["schemas"]["IntegerOutput"]; - img_channel_multiply: components["schemas"]["ImageOutput"]; - compel: components["schemas"]["ConditioningOutput"]; - freeu: components["schemas"]["UNetOutput"]; - conditioning: components["schemas"]["ConditioningOutput"]; - rand_int: components["schemas"]["IntegerOutput"]; - rectangle_mask: components["schemas"]["MaskOutput"]; - string_split_neg: components["schemas"]["StringPosNegOutput"]; - collect: components["schemas"]["CollectInvocationOutput"]; - float_range: components["schemas"]["FloatCollectionOutput"]; - string_join: components["schemas"]["StringOutput"]; - metadata_item: components["schemas"]["MetadataItemOutput"]; - l2i: components["schemas"]["ImageOutput"]; - img_paste: components["schemas"]["ImageOutput"]; - calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; - calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; - show_image: components["schemas"]["ImageOutput"]; - invert_tensor_mask: components["schemas"]["MaskOutput"]; - spandrel_image_to_image_autoscale: components["schemas"]["ImageOutput"]; - face_off: components["schemas"]["FaceOffOutput"]; - image_collection: components["schemas"]["ImageCollectionOutput"]; - i2l: components["schemas"]["LatentsOutput"]; - mediapipe_face_processor: components["schemas"]["ImageOutput"]; - scheduler: components["schemas"]["SchedulerOutput"]; - content_shuffle_image_processor: components["schemas"]["ImageOutput"]; - lineart_anime_image_processor: components["schemas"]["ImageOutput"]; - float_to_int: components["schemas"]["IntegerOutput"]; - pidi_image_processor: components["schemas"]["ImageOutput"]; - leres_image_processor: components["schemas"]["ImageOutput"]; - lora_selector: components["schemas"]["LoRASelectorOutput"]; - depth_anything_image_processor: components["schemas"]["ImageOutput"]; - img_pad_crop: components["schemas"]["ImageOutput"]; - calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; - prompt_from_file: components["schemas"]["StringCollectionOutput"]; - t2i_adapter: components["schemas"]["T2IAdapterOutput"]; - lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; - img_crop: components["schemas"]["ImageOutput"]; - img_mul: components["schemas"]["ImageOutput"]; - sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; - tomask: components["schemas"]["ImageOutput"]; - tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; - string_replace: components["schemas"]["StringOutput"]; - merge_tiles_to_image: components["schemas"]["ImageOutput"]; - integer_collection: components["schemas"]["IntegerCollectionOutput"]; - crop_latents: components["schemas"]["LatentsOutput"]; - lineart_image_processor: components["schemas"]["ImageOutput"]; - midas_depth_image_processor: components["schemas"]["ImageOutput"]; - segment_anything_processor: components["schemas"]["ImageOutput"]; - save_image: components["schemas"]["ImageOutput"]; - infill_cv2: components["schemas"]["ImageOutput"]; - cv_inpaint: components["schemas"]["ImageOutput"]; - lora_loader: components["schemas"]["LoRALoaderOutput"]; - latents: components["schemas"]["LatentsOutput"]; - model_identifier: components["schemas"]["ModelIdentifierOutput"]; - step_param_easing: components["schemas"]["FloatCollectionOutput"]; - metadata: components["schemas"]["MetadataOutput"]; - range: components["schemas"]["IntegerCollectionOutput"]; - img_blur: components["schemas"]["ImageOutput"]; - core_metadata: components["schemas"]["MetadataOutput"]; - mask_combine: components["schemas"]["ImageOutput"]; sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - infill_lama: components["schemas"]["ImageOutput"]; - iterate: components["schemas"]["IterateInvocationOutput"]; - hed_image_processor: components["schemas"]["ImageOutput"]; - img_nsfw: components["schemas"]["ImageOutput"]; - main_model_loader: components["schemas"]["ModelLoaderOutput"]; - string: components["schemas"]["StringOutput"]; - sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - image: components["schemas"]["ImageOutput"]; - rand_float: components["schemas"]["FloatOutput"]; - float: components["schemas"]["FloatOutput"]; - alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; - integer_math: components["schemas"]["IntegerOutput"]; + face_off: components["schemas"]["FaceOffOutput"]; + tile_image_processor: components["schemas"]["ImageOutput"]; + leres_image_processor: components["schemas"]["ImageOutput"]; + face_identifier: components["schemas"]["ImageOutput"]; clip_skip: components["schemas"]["CLIPSkipInvocationOutput"]; + scheduler: components["schemas"]["SchedulerOutput"]; + img_lerp: components["schemas"]["ImageOutput"]; + string_replace: components["schemas"]["StringOutput"]; + step_param_easing: components["schemas"]["FloatCollectionOutput"]; + spandrel_image_to_image_autoscale: components["schemas"]["ImageOutput"]; + ip_adapter: components["schemas"]["IPAdapterOutput"]; + tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; + integer_math: components["schemas"]["IntegerOutput"]; + range: components["schemas"]["IntegerCollectionOutput"]; + string_join: components["schemas"]["StringOutput"]; + img_blur: components["schemas"]["ImageOutput"]; }; /** * InvocationStartedEvent @@ -10103,6 +10142,34 @@ export type components = { /** Ui Order */ ui_order: number | null; }; + /** PaginatedResults[StylePresetRecordDTO] */ + PaginatedResults_StylePresetRecordDTO_: { + /** + * Page + * @description Current Page + */ + page: number; + /** + * Pages + * @description Total number of pages + */ + pages: number; + /** + * Per Page + * @description Number of items per page + */ + per_page: number; + /** + * Total + * @description Total number of items in result + */ + total: number; + /** + * Items + * @description Items + */ + items: components["schemas"]["StylePresetRecordDTO"][]; + }; /** PaginatedResults[WorkflowRecordListItemDTO] */ PaginatedResults_WorkflowRecordListItemDTO_: { /** @@ -10252,6 +10319,19 @@ export type components = { */ type: "pidi_image_processor"; }; + /** PresetData */ + PresetData: { + /** + * Positive Prompt + * @description Positive prompt + */ + positive_prompt: string; + /** + * Negative Prompt + * @description Negative prompt + */ + negative_prompt: string; + }; /** * ProgressImage * @description The progress image sent intermittently during processing @@ -12590,6 +12670,52 @@ export type components = { */ type: "string_split_neg"; }; + /** StylePresetChanges */ + StylePresetChanges: { + /** + * Name + * @description The style preset's new name. + */ + name?: string | null; + /** @description The updated data for style preset. */ + preset_data?: components["schemas"]["PresetData"] | null; + }; + /** StylePresetRecordDTO */ + StylePresetRecordDTO: { + /** + * Name + * @description The name of the style preset. + */ + name: string; + /** @description The preset data */ + preset_data: components["schemas"]["PresetData"]; + /** + * Id + * @description The style preset ID. + */ + id: string; + /** + * Is Default + * @description Whether or not the style preset is default + */ + is_default: boolean; + }; + /** + * StylePresetRecordOrderBy + * @description The order by options for workflow records + * @enum {string} + */ + StylePresetRecordOrderBy: "created_at" | "name"; + /** StylePresetWithoutId */ + StylePresetWithoutId: { + /** + * Name + * @description The name of the style preset. + */ + name: string; + /** @description The preset data */ + preset_data: components["schemas"]["PresetData"]; + }; /** * SubModelType * @description Submodel type. @@ -16117,4 +16243,146 @@ export type operations = { }; }; }; + /** + * Get Style Preset + * @description Gets a style preset + */ + get_style_preset: { + parameters: { + path: { + /** @description The style preset to get */ + style_preset_id: string; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": components["schemas"]["StylePresetRecordDTO"]; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + /** + * Delete Style Preset + * @description Deletes a style preset + */ + delete_style_preset: { + parameters: { + path: { + /** @description The style preset to delete */ + style_preset_id: string; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": unknown; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + /** + * Update Style Preset + * @description Updates a style preset + */ + update_style_preset: { + parameters: { + path: { + /** @description The id of the style preset to update */ + style_preset_id: string; + }; + }; + requestBody: { + content: { + "application/json": components["schemas"]["Body_update_style_preset"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": components["schemas"]["StylePresetRecordDTO"]; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + /** + * List Style Presets + * @description Gets a page of style presets + */ + list_style_presets: { + parameters: { + query?: { + /** @description The page to get */ + page?: number; + /** @description The number of style presets per page */ + per_page?: number; + /** @description The attribute to order by */ + order_by?: components["schemas"]["StylePresetRecordOrderBy"]; + /** @description The direction to order by */ + direction?: components["schemas"]["SQLiteDirection"]; + /** @description The text to query by (matches name and description) */ + query?: string | null; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": components["schemas"]["PaginatedResults_StylePresetRecordDTO_"]; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + /** + * Create Style Preset + * @description Creates a style preset + */ + create_style_preset: { + requestBody: { + content: { + "application/json": components["schemas"]["Body_create_style_preset"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": components["schemas"]["StylePresetRecordDTO"]; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; }; From a61209206b0c395ddaa7702374d70afb001c53f4 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Mon, 5 Aug 2024 13:40:46 -0400 Subject: [PATCH 04/48] remove custom SDXL prompts component --- .../frontend/web/src/app/components/App.tsx | 11 +++++++++- .../web/src/app/components/InvokeAIUI.tsx | 5 ++++- .../frontend/web/src/app/types/invokeai.ts | 1 - .../parameters/components/Prompts/Prompts.tsx | 16 ++++++++++++++ .../SDXLPrompts/SDXLPrompts.stories.tsx | 20 ----------------- .../components/SDXLPrompts/SDXLPrompts.tsx | 22 ------------------- .../ParametersPanelCanvas.tsx | 3 +-- .../ParametersPanelTextToImage.tsx | 3 +-- .../ParametersPanelUpscale.tsx | 6 +---- 9 files changed, 33 insertions(+), 54 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/SDXLPrompts.stories.tsx delete mode 100644 invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/SDXLPrompts.tsx diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 2d878d96e7..760eddbee8 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -16,6 +16,8 @@ import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterM import { configChanged } from 'features/system/store/configSlice'; import { languageSelector } from 'features/system/store/systemSelectors'; import InvokeTabs from 'features/ui/components/InvokeTabs'; +import type { InvokeTabName } from 'features/ui/store/tabMap'; +import { setActiveTab } from 'features/ui/store/uiSlice'; import { AnimatePresence } from 'framer-motion'; import i18n from 'i18n'; import { size } from 'lodash-es'; @@ -34,9 +36,10 @@ interface Props { imageName: string; action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters'; }; + destination?: InvokeTabName | undefined; } -const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { +const App = ({ config = DEFAULT_CONFIG, selectedImage, destination }: Props) => { const language = useAppSelector(languageSelector); const logger = useLogger('system'); const dispatch = useAppDispatch(); @@ -67,6 +70,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { } }, [dispatch, config, logger]); + useEffect(() => { + if (destination) { + dispatch(setActiveTab(destination)); + } + }, [dispatch, destination]); + useEffect(() => { dispatch(appStarted()); }, [dispatch]); diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 1dd1a265fb..0a80b7e92d 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -19,6 +19,7 @@ import type { PartialAppConfig } from 'app/types/invokeai'; import Loading from 'common/components/Loading/Loading'; import AppDndContext from 'features/dnd/components/AppDndContext'; import type { WorkflowCategory } from 'features/nodes/types/workflow'; +import type { InvokeTabName } from 'features/ui/store/tabMap'; import type { PropsWithChildren, ReactNode } from 'react'; import React, { lazy, memo, useEffect, useMemo } from 'react'; import { Provider } from 'react-redux'; @@ -43,6 +44,7 @@ interface Props extends PropsWithChildren { imageName: string; action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters'; }; + destination?: InvokeTabName; customStarUi?: CustomStarUi; socketOptions?: Partial; isDebugging?: boolean; @@ -62,6 +64,7 @@ const InvokeAIUI = ({ projectUrl, queueId, selectedImage, + destination, customStarUi, socketOptions, isDebugging = false, @@ -218,7 +221,7 @@ const InvokeAIUI = ({ }> - + diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index bb27cf58a8..8cc8422c24 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -69,7 +69,6 @@ export type AppConfig = { disabledTabs: InvokeTabName[]; disabledFeatures: AppFeature[]; disabledSDFeatures: SDFeature[]; - canRestoreDeletedImagesFromBin: boolean; nodesAllowlist: string[] | undefined; nodesDenylist: string[] | undefined; metadataFetchDebounce?: number; diff --git a/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx b/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx index 320c15f7df..767ee97459 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx @@ -1,13 +1,29 @@ import { Flex } from '@invoke-ai/ui-library'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectControlLayersSlice } from 'features/controlLayers/store/controlLayersSlice'; import { ParamNegativePrompt } from 'features/parameters/components/Core/ParamNegativePrompt'; import { ParamPositivePrompt } from 'features/parameters/components/Core/ParamPositivePrompt'; +import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; +import { ParamSDXLNegativeStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt'; +import { ParamSDXLPositiveStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLPositiveStylePrompt'; import { memo } from 'react'; +const concatPromptsSelector = createSelector( + [selectGenerationSlice, selectControlLayersSlice], + (generation, controlLayers) => { + return generation.model?.base === 'sdxl' && controlLayers.present.shouldConcatPrompts; + } +); + export const Prompts = memo(() => { + const shouldConcatPrompts = useAppSelector(concatPromptsSelector); return ( + {!shouldConcatPrompts && } + {!shouldConcatPrompts && } ); }); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/SDXLPrompts.stories.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/SDXLPrompts.stories.tsx deleted file mode 100644 index 75c288f7d1..0000000000 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/SDXLPrompts.stories.tsx +++ /dev/null @@ -1,20 +0,0 @@ -import type { Meta, StoryObj } from '@storybook/react'; - -import { SDXLPrompts } from './SDXLPrompts'; - -const meta: Meta = { - title: 'Feature/Prompt/SDXLPrompts', - tags: ['autodocs'], - component: SDXLPrompts, -}; - -export default meta; -type Story = StoryObj; - -const Component = () => { - return ; -}; - -export const Default: Story = { - render: Component, -}; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/SDXLPrompts.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/SDXLPrompts.tsx deleted file mode 100644 index b585e92a5f..0000000000 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/SDXLPrompts.tsx +++ /dev/null @@ -1,22 +0,0 @@ -import { Flex } from '@invoke-ai/ui-library'; -import { useAppSelector } from 'app/store/storeHooks'; -import { ParamNegativePrompt } from 'features/parameters/components/Core/ParamNegativePrompt'; -import { ParamPositivePrompt } from 'features/parameters/components/Core/ParamPositivePrompt'; -import { memo } from 'react'; - -import { ParamSDXLNegativeStylePrompt } from './ParamSDXLNegativeStylePrompt'; -import { ParamSDXLPositiveStylePrompt } from './ParamSDXLPositiveStylePrompt'; - -export const SDXLPrompts = memo(() => { - const shouldConcatPrompts = useAppSelector((s) => s.controlLayers.present.shouldConcatPrompts); - return ( - - - {!shouldConcatPrompts && } - - {!shouldConcatPrompts && } - - ); -}); - -SDXLPrompts.displayName = 'SDXLPrompts'; diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelCanvas.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelCanvas.tsx index 622ed96696..a2ffe8d497 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelCanvas.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelCanvas.tsx @@ -3,7 +3,6 @@ import { useAppSelector } from 'app/store/storeHooks'; import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants'; import { Prompts } from 'features/parameters/components/Prompts/Prompts'; import QueueControls from 'features/queue/components/QueueControls'; -import { SDXLPrompts } from 'features/sdxl/components/SDXLPrompts/SDXLPrompts'; import { AdvancedSettingsAccordion } from 'features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion'; import { CompositingSettingsAccordion } from 'features/settingsAccordions/components/CompositingSettingsAccordion/CompositingSettingsAccordion'; import { ControlSettingsAccordion } from 'features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion'; @@ -29,7 +28,7 @@ const ParametersPanelCanvas = () => { - {isSDXL ? : } + diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelTextToImage.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelTextToImage.tsx index 3c58a08e4c..05c6d06216 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelTextToImage.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanels/ParametersPanelTextToImage.tsx @@ -7,7 +7,6 @@ import { $isPreviewVisible } from 'features/controlLayers/store/controlLayersSli import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { Prompts } from 'features/parameters/components/Prompts/Prompts'; import QueueControls from 'features/queue/components/QueueControls'; -import { SDXLPrompts } from 'features/sdxl/components/SDXLPrompts/SDXLPrompts'; import { AdvancedSettingsAccordion } from 'features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion'; import { CompositingSettingsAccordion } from 'features/settingsAccordions/components/CompositingSettingsAccordion/CompositingSettingsAccordion'; import { ControlSettingsAccordion } from 'features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion'; @@ -66,7 +65,7 @@ const ParametersPanelTextToImage = () => { - {isSDXL ? : } + { - const isSDXL = useAppSelector((s) => s.generation.model?.base === 'sdxl'); - return ( @@ -26,7 +22,7 @@ const ParametersPanelUpscale = () => { - {isSDXL ? : } + From af9110e964a6d6e90b7e0116c007d515f9f2e841 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Mon, 5 Aug 2024 13:42:28 -0400 Subject: [PATCH 05/48] fix prompt concat logic --- .../web/src/features/parameters/components/Prompts/Prompts.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx b/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx index 767ee97459..c41f929ae9 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx @@ -12,7 +12,7 @@ import { memo } from 'react'; const concatPromptsSelector = createSelector( [selectGenerationSlice, selectControlLayersSlice], (generation, controlLayers) => { - return generation.model?.base === 'sdxl' && controlLayers.present.shouldConcatPrompts; + return generation.model?.base !== 'sdxl' || controlLayers.present.shouldConcatPrompts; } ); From fd7a635777fba264a8e2faba5694726ecd6cd517 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Mon, 5 Aug 2024 15:48:23 -0400 Subject: [PATCH 06/48] (ui) the most basic crud ui: view list of presets, create a new preset, edit/delete existing presets --- .../frontend/web/src/app/components/App.tsx | 2 + invokeai/frontend/web/src/app/store/store.ts | 2 + .../parameters/components/Prompts/Prompts.tsx | 2 + .../components/StylePresetForm.tsx | 88 +++++++++++++++++++ .../components/StylePresetListItem.tsx | 46 ++++++++++ .../components/StylePresetMenu.tsx | 34 +++++++ .../components/StylePresetMenuTrigger.tsx | 24 +++++ .../components/StylePresetModal.tsx | 43 +++++++++ .../src/features/stylePresets/store/slice.ts | 30 +++++++ .../src/features/stylePresets/store/types.ts | 8 ++ .../services/api/endpoints/stylePresets.ts | 14 ++- 11 files changed, 290 insertions(+), 3 deletions(-) create mode 100644 invokeai/frontend/web/src/features/stylePresets/components/StylePresetForm.tsx create mode 100644 invokeai/frontend/web/src/features/stylePresets/components/StylePresetListItem.tsx create mode 100644 invokeai/frontend/web/src/features/stylePresets/components/StylePresetMenu.tsx create mode 100644 invokeai/frontend/web/src/features/stylePresets/components/StylePresetMenuTrigger.tsx create mode 100644 invokeai/frontend/web/src/features/stylePresets/components/StylePresetModal.tsx create mode 100644 invokeai/frontend/web/src/features/stylePresets/store/slice.ts create mode 100644 invokeai/frontend/web/src/features/stylePresets/store/types.ts diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 760eddbee8..07b959e684 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -13,6 +13,7 @@ import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardMo import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal'; import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal'; import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast'; +import { StylePresetModal } from 'features/stylePresets/components/StylePresetModal'; import { configChanged } from 'features/system/store/configSlice'; import { languageSelector } from 'features/system/store/systemSelectors'; import InvokeTabs from 'features/ui/components/InvokeTabs'; @@ -104,6 +105,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, destination }: Props) => + ); diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 6ae2011355..f061d0e59f 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -28,6 +28,7 @@ import { generationPersistConfig, generationSlice } from 'features/parameters/st import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice'; import { queueSlice } from 'features/queue/store/queueSlice'; import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice'; +import { stylePresetModalSlice } from 'features/stylePresets/store/slice'; import { configSlice } from 'features/system/store/configSlice'; import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice'; import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice'; @@ -69,6 +70,7 @@ const allReducers = { [workflowSettingsSlice.name]: workflowSettingsSlice.reducer, [api.reducerPath]: api.reducer, [upscaleSlice.name]: upscaleSlice.reducer, + [stylePresetModalSlice.name]: stylePresetModalSlice.reducer }; const rootReducer = combineReducers(allReducers); diff --git a/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx b/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx index c41f929ae9..39746c9ba6 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Prompts/Prompts.tsx @@ -7,6 +7,7 @@ import { ParamPositivePrompt } from 'features/parameters/components/Core/ParamPo import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { ParamSDXLNegativeStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt'; import { ParamSDXLPositiveStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLPositiveStylePrompt'; +import { StylePresetMenuTrigger } from 'features/stylePresets/components/StylePresetMenuTrigger'; import { memo } from 'react'; const concatPromptsSelector = createSelector( @@ -20,6 +21,7 @@ export const Prompts = memo(() => { const shouldConcatPrompts = useAppSelector(concatPromptsSelector); return ( + {!shouldConcatPrompts && } diff --git a/invokeai/frontend/web/src/features/stylePresets/components/StylePresetForm.tsx b/invokeai/frontend/web/src/features/stylePresets/components/StylePresetForm.tsx new file mode 100644 index 0000000000..201c1e27cc --- /dev/null +++ b/invokeai/frontend/web/src/features/stylePresets/components/StylePresetForm.tsx @@ -0,0 +1,88 @@ +import { Button, Flex, FormControl, FormLabel, Input, Textarea } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { isModalOpenChanged,updatingStylePresetChanged } from 'features/stylePresets/store/slice'; +import { toast } from 'features/toast/toast'; +import type { ChangeEventHandler} from 'react'; +import { useCallback, useEffect, useState } from 'react'; +import type { + StylePresetRecordDTO} from 'services/api/endpoints/stylePresets'; +import { + useCreateStylePresetMutation, + useUpdateStylePresetMutation, +} from 'services/api/endpoints/stylePresets'; + +export const StylePresetForm = ({ updatingPreset }: { updatingPreset: StylePresetRecordDTO | null }) => { + const [createStylePreset] = useCreateStylePresetMutation(); + const [updateStylePreset] = useUpdateStylePresetMutation(); + const dispatch = useAppDispatch(); + + const [name, setName] = useState(updatingPreset ? updatingPreset.name : ''); + const [posPrompt, setPosPrompt] = useState(updatingPreset ? updatingPreset.preset_data.positive_prompt : ''); + const [negPrompt, setNegPrompt] = useState(updatingPreset ? updatingPreset.preset_data.negative_prompt : ''); + + const handleChangeName = useCallback>((e) => { + setName(e.target.value); + }, []); + + const handleChangePosPrompt = useCallback>((e) => { + setPosPrompt(e.target.value); + }, []); + + const handleChangeNegPrompt = useCallback>((e) => { + setNegPrompt(e.target.value); + }, []); + + useEffect(() => { + if (updatingPreset) { + setName(updatingPreset.name); + setPosPrompt(updatingPreset.preset_data.positive_prompt); + setNegPrompt(updatingPreset.preset_data.negative_prompt); + } else { + setName(''); + setPosPrompt(''); + setNegPrompt(''); + } + }, [updatingPreset]); + + const handleClickSave = useCallback(async () => { + try { + if (updatingPreset) { + await updateStylePreset({ + id: updatingPreset.id, + changes: { name, preset_data: { positive_prompt: posPrompt, negative_prompt: negPrompt } }, + }).unwrap(); + } else { + await createStylePreset({ + name: name, + preset_data: { positive_prompt: posPrompt, negative_prompt: negPrompt }, + }).unwrap(); + } + } catch (error) { + toast({ + status: 'error', + title: 'Failed to save style preset', + }); + } + + dispatch(updatingStylePresetChanged(null)); + dispatch(isModalOpenChanged(false)); + }, [dispatch, updatingPreset, name, posPrompt, negPrompt, updateStylePreset, createStylePreset]); + + return ( + + + Name + + + + Positive Prompt +