From 217fe40d99eb479f827bb5014c2dbed9d388ef33 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Fri, 2 Aug 2024 12:29:54 -0400 Subject: [PATCH] feat(api): add style_presets router, make sure all CRUD is working, add is_default --- invokeai/app/api/dependencies.py | 4 +- invokeai/app/api/routers/style_presets.py | 100 ++++++++++++++++++ invokeai/app/api_app.py | 2 + invokeai/app/services/invocation_services.py | 4 +- .../migrations/migration_14.py | 5 +- .../style_preset_records_base.py | 12 ++- .../style_preset_records_common.py | 20 +++- .../style_preset_records_sqlite.py | 52 +++++---- 8 files changed, 165 insertions(+), 34 deletions(-) create mode 100644 invokeai/app/api/routers/style_presets.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 18ffd39947..d2eeab6219 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -110,7 +110,7 @@ class ApiDependencies: session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() workflow_records = SqliteWorkflowRecordsStorage(db=db) - style_presets = SqliteStylePresetRecordsStorage(db=db) + style_preset_records = SqliteStylePresetRecordsStorage(db=db) services = InvocationServices( board_image_records=board_image_records, @@ -136,7 +136,7 @@ class ApiDependencies: workflow_records=workflow_records, tensors=tensors, conditioning=conditioning, - style_presets=style_presets, + style_preset_records=style_preset_records, ) ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/api/routers/style_presets.py b/invokeai/app/api/routers/style_presets.py new file mode 100644 index 0000000000..ae38db76aa --- /dev/null +++ b/invokeai/app/api/routers/style_presets.py @@ -0,0 +1,100 @@ +from typing import Optional + +from fastapi import APIRouter, Body, HTTPException, Path, Query + +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection +from invokeai.app.services.style_preset_records.style_preset_records_common import ( + StylePresetChanges, + StylePresetNotFoundError, + StylePresetRecordDTO, + StylePresetWithoutId, + StylePresetRecordOrderBy, +) + + +style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"]) + + +@style_presets_router.get( + "/i/{style_preset_id}", + operation_id="get_style_preset", + responses={ + 200: {"model": StylePresetRecordDTO}, + }, +) +async def get_style_preset( + style_preset_id: str = Path(description="The style preset to get"), +) -> StylePresetRecordDTO: + """Gets a style preset""" + try: + return ApiDependencies.invoker.services.style_preset_records.get(style_preset_id) + except StylePresetNotFoundError: + raise HTTPException(status_code=404, detail="Style preset not found") + + +@style_presets_router.patch( + "/i/{style_preset_id}", + operation_id="update_style_preset", + responses={ + 200: {"model": StylePresetRecordDTO}, + }, +) +async def update_style_preset( + style_preset_id: str = Path(description="The id of the style preset to update"), + changes: StylePresetChanges = Body(description="The updated style preset", embed=True), +) -> StylePresetRecordDTO: + """Updates a style preset""" + return ApiDependencies.invoker.services.style_preset_records.update(id=style_preset_id, changes=changes) + + +@style_presets_router.delete( + "/i/{style_preset_id}", + operation_id="delete_style_preset", +) +async def delete_style_preset( + style_preset_id: str = Path(description="The style preset to delete"), +) -> None: + """Deletes a style preset""" + ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id) + + +@style_presets_router.post( + "/", + operation_id="create_style_preset", + responses={ + 200: {"model": StylePresetRecordDTO}, + }, +) +async def create_style_preset( + style_preset: StylePresetWithoutId = Body(description="The style preset to create", embed=True), +) -> StylePresetRecordDTO: + """Creates a style preset""" + return ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset) + + +@style_presets_router.get( + "/", + operation_id="list_style_presets", + responses={ + 200: {"model": PaginatedResults[StylePresetRecordDTO]}, + }, +) +async def list_style_presets( + page: int = Query(default=0, description="The page to get"), + per_page: int = Query(default=10, description="The number of style presets per page"), + order_by: StylePresetRecordOrderBy = Query( + default=StylePresetRecordOrderBy.Name, description="The attribute to order by" + ), + direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"), + query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"), +) -> PaginatedResults[StylePresetRecordDTO]: + """Gets a page of style presets""" + return ApiDependencies.invoker.services.style_preset_records.get_many( + page=page, + per_page=per_page, + order_by=order_by, + direction=direction, + query=query, + ) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 88820a0c4c..492c465573 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -32,6 +32,7 @@ from invokeai.app.api.routers import ( session_queue, utilities, workflows, + style_presets, ) from invokeai.app.api.sockets import SocketIO from invokeai.app.services.config.config_default import get_config @@ -106,6 +107,7 @@ app.include_router(board_images.board_images_router, prefix="/api") app.include_router(app_info.app_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api") app.include_router(workflows.workflows_router, prefix="/api") +app.include_router(style_presets.style_presets_router, prefix="/api") app.openapi = get_openapi_func(app) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 439625e23f..c756eae8e5 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -62,7 +62,7 @@ class InvocationServices: workflow_records: "WorkflowRecordsStorageBase", tensors: "ObjectSerializerBase[torch.Tensor]", conditioning: "ObjectSerializerBase[ConditioningFieldData]", - style_presets: "StylePresetRecordsStorageBase", + style_preset_records: "StylePresetRecordsStorageBase", ): self.board_images = board_images self.board_image_records = board_image_records @@ -87,4 +87,4 @@ class InvocationServices: self.workflow_records = workflow_records self.tensors = tensors self.conditioning = conditioning - self.style_presets = style_presets + self.style_preset_records = style_preset_records diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py index 38a35d547c..e7e77651dd 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_14.py @@ -15,13 +15,14 @@ class Migration14Callback: id TEXT NOT NULL PRIMARY KEY, name TEXT NOT NULL, preset_data TEXT NOT NULL, + is_default BOOLEAN DEFAULT FALSE, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); """ - ] - + ] + # Add trigger for `updated_at`. triggers = [ """--sql diff --git a/invokeai/app/services/style_preset_records/style_preset_records_base.py b/invokeai/app/services/style_preset_records/style_preset_records_base.py index e7e9c6655a..58121c3445 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_base.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_base.py @@ -3,9 +3,11 @@ from typing import Optional from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection -from invokeai.app.services.style_preset_records.style_preset_records_common import StylePresetRecordDTO, StylePresetWithoutId -from invokeai.app.services.workflow_records.workflow_records_common import ( - WorkflowRecordOrderBy, +from invokeai.app.services.style_preset_records.style_preset_records_common import ( + StylePresetChanges, + StylePresetRecordDTO, + StylePresetWithoutId, + StylePresetRecordOrderBy, ) @@ -23,7 +25,7 @@ class StylePresetRecordsStorageBase(ABC): pass @abstractmethod - def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO: + def update(self, id: str, changes: StylePresetChanges) -> StylePresetRecordDTO: """Updates a style preset.""" pass @@ -37,7 +39,7 @@ class StylePresetRecordsStorageBase(ABC): self, page: int, per_page: int, - order_by: WorkflowRecordOrderBy, + order_by: StylePresetRecordOrderBy, direction: SQLiteDirection, query: Optional[str], ) -> PaginatedResults[StylePresetRecordDTO]: diff --git a/invokeai/app/services/style_preset_records/style_preset_records_common.py b/invokeai/app/services/style_preset_records/style_preset_records_common.py index a5b787b9cf..7ce7ccc730 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_common.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_common.py @@ -1,13 +1,23 @@ -from typing import Any, Union +from enum import Enum +from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator +from invokeai.app.util.metaenum import MetaEnum + class StylePresetNotFoundError(Exception): """Raised when a style preset is not found""" -class PresetData(BaseModel): +class StylePresetRecordOrderBy(str, Enum, metaclass=MetaEnum): + """The order by options for workflow records""" + + CreatedAt = "created_at" + Name = "name" + + +class PresetData(BaseModel, extra="forbid"): positive_prompt: str = Field(description="Positive prompt") negative_prompt: str = Field(description="Negative prompt") @@ -15,6 +25,11 @@ class PresetData(BaseModel): PresetDataValidator = TypeAdapter(PresetData) +class StylePresetChanges(BaseModel, extra="forbid"): + name: Optional[str] = Field(default=None, description="The style preset's new name.") + preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.") + + class StylePresetWithoutId(BaseModel): name: str = Field(description="The name of the style preset.") preset_data: PresetData = Field(description="The preset data") @@ -22,6 +37,7 @@ class StylePresetWithoutId(BaseModel): class StylePresetRecordDTO(StylePresetWithoutId): id: str = Field(description="The style preset ID.") + is_default: bool = Field(description="Whether or not the style preset is default") @classmethod def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO": diff --git a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py index 8ad44f4054..e0110afa9a 100644 --- a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py +++ b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py @@ -7,12 +7,11 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase from invokeai.app.services.style_preset_records.style_preset_records_common import ( + StylePresetChanges, StylePresetNotFoundError, StylePresetRecordDTO, StylePresetWithoutId, -) -from invokeai.app.services.workflow_records.workflow_records_common import ( - WorkflowRecordOrderBy, + StylePresetRecordOrderBy, ) from invokeai.app.util.misc import uuid_string @@ -58,14 +57,14 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): INSERT OR IGNORE INTO style_presets ( id, name, - preset_data, + preset_data ) VALUES (?, ?, ?); """, ( id, style_preset.name, - style_preset.preset_data, + style_preset.preset_data.model_dump_json(), ), ) self._conn.commit() @@ -76,20 +75,31 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): self._lock.release() return self.get(id) - def update(self, id: str, changes: StylePresetWithoutId) -> StylePresetRecordDTO: + def update(self, id: str, changes: StylePresetChanges) -> StylePresetRecordDTO: try: self._lock.acquire() - self._cursor.execute( - """--sql - UPDATE style_presets - SET preset_data = ? - WHERE id = ? ; - """, - ( - changes.preset_data, - id, - ), - ) + # Change the name of a style preset + if changes.name is not None: + self._cursor.execute( + """--sql + UPDATE style_presets + SET name = ? + WHERE id = ?; + """, + (changes.name, id), + ) + + # Change the preset data for a style preset + if changes.preset_data is not None: + self._cursor.execute( + """--sql + UPDATE style_presets + SET preset_data = ? + WHERE id = ?; + """, + (changes.preset_data.model_dump_json(), id), + ) + self._conn.commit() except Exception: self._conn.rollback() @@ -120,14 +130,14 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): self, page: int, per_page: int, - order_by: WorkflowRecordOrderBy, + order_by: StylePresetRecordOrderBy, direction: SQLiteDirection, query: Optional[str] = None, ) -> PaginatedResults[StylePresetRecordDTO]: try: self._lock.acquire() # sanitize! - assert order_by in WorkflowRecordOrderBy + assert order_by in StylePresetRecordOrderBy assert direction in SQLiteDirection count_query = "SELECT COUNT(*) FROM style_presets" main_query = """ @@ -140,8 +150,8 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase): stripped_query = query.strip() if query else None if stripped_query: wildcard_query = "%" + stripped_query + "%" - main_query += " AND name LIKE ? OR description LIKE ? " - count_query += " AND name LIKE ? OR description LIKE ?;" + main_query += " AND name LIKE ? " + count_query += " AND name LIKE ?;" main_params.extend([wildcard_query, wildcard_query]) count_params.extend([wildcard_query, wildcard_query])