diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index c9a2f0a843..ae4882c0d0 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -30,6 +30,7 @@ from ..services.shared.default_graphs import create_system_graphs from ..services.shared.graph import GraphExecutionState, LibraryGraph from ..services.shared.sqlite import SqliteDatabase from ..services.urls.urls_default import LocalUrlService +from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from .events import FastAPIEventService @@ -90,6 +91,7 @@ class ApiDependencies: session_processor = DefaultSessionProcessor() session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() + workflow_records = SqliteWorkflowRecordsStorage(db=db) services = InvocationServices( board_image_records=board_image_records, @@ -114,6 +116,7 @@ class ApiDependencies: session_processor=session_processor, session_queue=session_queue, urls=urls, + workflow_records=workflow_records, ) create_system_graphs(services.graph_library) diff --git a/invokeai/app/api/routers/workflows.py b/invokeai/app/api/routers/workflows.py new file mode 100644 index 0000000000..814123fc81 --- /dev/null +++ b/invokeai/app/api/routers/workflows.py @@ -0,0 +1,20 @@ +from fastapi import APIRouter, Body, Path + +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField + +workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"]) + + +@workflows_router.get( + "/i/{workflow_id}", + operation_id="get_workflow", + responses={ + 200: {"model": WorkflowField}, + }, +) +async def get_workflow( + workflow_id: str = Path(description="The workflow to get"), +) -> WorkflowField: + """Gets a workflow""" + return ApiDependencies.invoker.services.workflow_records.get(workflow_id) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 866a6665c8..e04cf564ab 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -38,7 +38,17 @@ if True: # hack to make flake8 happy with imports coming after setting up the c from ..backend.util.logging import InvokeAILogger from .api.dependencies import ApiDependencies - from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities + from .api.routers import ( + app_info, + board_images, + boards, + images, + models, + sessions, + session_queue, + utilities, + workflows, + ) from .api.sockets import SocketIO from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField @@ -95,18 +105,13 @@ async def shutdown_event() -> None: app.include_router(sessions.session_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api") - app.include_router(models.models_router, prefix="/api") - app.include_router(images.images_router, prefix="/api") - app.include_router(boards.boards_router, prefix="/api") - app.include_router(board_images.board_images_router, prefix="/api") - app.include_router(app_info.app_router, prefix="/api") - app.include_router(session_queue.session_queue_router, prefix="/api") +app.include_router(workflows.workflows_router, prefix="/api") # Build a custom OpenAPI to include all outputs @@ -166,7 +171,6 @@ def custom_openapi() -> dict[str, Any]: # print(f"Config with name {name} already defined") continue - # "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"} openapi_schema["components"]["schemas"][name] = dict( title=name, description="An enumeration.", diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index ba94e7c440..f2e5f33e6e 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -2,7 +2,6 @@ from __future__ import annotations -import json import re from abc import ABC, abstractmethod from enum import Enum @@ -11,12 +10,13 @@ from types import UnionType from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union import semver -from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator +from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model from pydantic.fields import _Unset from pydantic_core import PydanticUndefined from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.util.misc import uuid_string +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField if TYPE_CHECKING: from ..services.invocation_services import InvocationServices @@ -60,7 +60,7 @@ class FieldDescriptions: denoised_latents = "Denoised latents tensor" latents = "Latents tensor" strength = "Strength of denoising (proportional to steps)" - core_metadata = "Optional core metadata to be written to image" + workflow = "Optional workflow to be saved with the image" interp_mode = "Interpolation mode" torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)" fp32 = "Whether or not to use full float32 precision" @@ -665,27 +665,7 @@ class BaseInvocation(ABC, BaseModel): description="Whether or not this is an intermediate invocation.", json_schema_extra=dict(ui_type=UIType.IsIntermediate), ) - workflow: Optional[str] = Field( - default=None, - description="The workflow to save with the image", - json_schema_extra=dict(ui_type=UIType.WorkflowField), - ) - use_cache: Optional[bool] = Field( - default=True, - description="Whether or not to use the cache", - ) - - @field_validator("workflow", mode="before") - @classmethod - def validate_workflow_is_json(cls, v): - """We don't have a workflow schema in the backend, so we just check that it's valid JSON""" - if v is None: - return None - try: - json.loads(v) - except json.decoder.JSONDecodeError: - raise ValueError("Workflow must be valid JSON") - return v + use_cache: bool = InputField(default=True, description="Whether or not to use the cache") UIConfig: ClassVar[Type[UIConfigBase]] @@ -824,4 +804,6 @@ def invocation_output( return wrapper -GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel) +class WithWorkflow(BaseModel): + workflow: Optional[WorkflowField] = InputField(default=None, description=FieldDescriptions.workflow) + diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index ba53ea50cf..94db75d810 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from .session_queue.session_queue_base import SessionQueueBase from .shared.graph import GraphExecutionState, LibraryGraph from .urls.urls_base import UrlServiceBase + from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase class InvocationServices: @@ -55,6 +56,7 @@ class InvocationServices: invocation_cache: "InvocationCacheBase" names: "NameServiceBase" urls: "UrlServiceBase" + workflow_records: "WorkflowRecordsStorageBase" def __init__( self, @@ -80,6 +82,7 @@ class InvocationServices: invocation_cache: "InvocationCacheBase", names: "NameServiceBase", urls: "UrlServiceBase", + workflow_records: "WorkflowRecordsStorageBase", ): self.board_images = board_images self.board_image_records = board_image_records @@ -103,3 +106,4 @@ class InvocationServices: self.invocation_cache = invocation_cache self.names = names self.urls = urls + self.workflow_records = workflow_records diff --git a/invokeai/app/services/workflow_records/__init__.py b/invokeai/app/services/workflow_records/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/app/services/workflow_records/workflow_records_base.py b/invokeai/app/services/workflow_records/workflow_records_base.py new file mode 100644 index 0000000000..97f7cfe3c0 --- /dev/null +++ b/invokeai/app/services/workflow_records/workflow_records_base.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod + +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField + + +class WorkflowRecordsStorageBase(ABC): + """Base class for workflow storage services.""" + + @abstractmethod + def get(self, workflow_id: str) -> WorkflowField: + """Get workflow by id.""" + pass + + @abstractmethod + def create(self, workflow: WorkflowField) -> WorkflowField: + """Creates a workflow.""" + pass diff --git a/invokeai/app/services/workflow_records/workflow_records_common.py b/invokeai/app/services/workflow_records/workflow_records_common.py new file mode 100644 index 0000000000..d548656dab --- /dev/null +++ b/invokeai/app/services/workflow_records/workflow_records_common.py @@ -0,0 +1,22 @@ +from typing import Any + +from pydantic import Field, RootModel, TypeAdapter + + +class WorkflowNotFoundError(Exception): + """Raised when a workflow is not found""" + + +class WorkflowField(RootModel): + """ + Pydantic model for workflows with custom root of type dict[str, Any]. + Workflows are stored without a strict schema. + """ + + root: dict[str, Any] = Field(description="Workflow dict") + + def model_dump(self, *args, **kwargs) -> dict[str, Any]: + return super().model_dump(*args, **kwargs)["root"] + + +type_adapter_WorkflowField = TypeAdapter(WorkflowField) diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py new file mode 100644 index 0000000000..2b284ac03f --- /dev/null +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -0,0 +1,148 @@ +import sqlite3 +import threading + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.sqlite import SqliteDatabase +from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase +from invokeai.app.services.workflow_records.workflow_records_common import ( + WorkflowField, + WorkflowNotFoundError, + type_adapter_WorkflowField, +) +from invokeai.app.util.misc import uuid_string + + +class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): + _invoker: Invoker + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _lock: threading.RLock + + def __init__(self, db: SqliteDatabase) -> None: + super().__init__() + self._lock = db.lock + self._conn = db.conn + self._cursor = self._conn.cursor() + self._create_tables() + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + + def get(self, workflow_id: str) -> WorkflowField: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT workflow + FROM workflows + WHERE workflow_id = ?; + """, + (workflow_id,), + ) + row = self._cursor.fetchone() + if row is None: + raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found") + return type_adapter_WorkflowField.validate_json(row[0]) + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + + def create(self, workflow: WorkflowField) -> WorkflowField: + try: + # workflows do not have ids until they are saved + workflow_id = uuid_string() + workflow.root["id"] = workflow_id + self._lock.acquire() + self._cursor.execute( + """--sql + INSERT INTO workflows(workflow) + VALUES (?); + """, + (workflow.json(),), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + return self.get(workflow_id) + + def _create_tables(self) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS workflows ( + workflow TEXT NOT NULL, + workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger + ); + """ + ) + + self._cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at + AFTER UPDATE + ON workflows FOR EACH ROW + BEGIN + UPDATE workflows + SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE workflow_id = old.workflow_id; + END; + """ + ) + + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + + # def update(self, workflow_id: str, workflow: Workflow) -> Workflow: + # """Updates a workflow record.""" + # try: + # workflow_id = workflow.get("id", None) + # if type(workflow_id) is not str: + # raise WorkflowNotFoundError(f"Workflow does not have a valid id, got {workflow_id}") + # self._lock.acquire() + # self._cursor.execute( + # """--sql + # UPDATE workflows + # SET workflow = ? + # WHERE workflow_id = ? + # """, + # (workflow, workflow_id), + # ) + # self._conn.commit() + # except Exception: + # self._conn.rollback() + # raise + # finally: + # self._lock.release() + # return self.get(workflow_id) + + # def delete(self, workflow_id: str) -> Workflow: + # """Updates a workflow record.""" + # workflow = self.get(workflow_id) + # try: + # self._lock.acquire() + # self._cursor.execute( + # """--sql + # DELETE FROM workflows + # WHERE workflow_id = ? + # """, + # (workflow_id,), + # ) + # self._conn.commit() + # except Exception: + # self._conn.rollback() + # raise + # finally: + # self._lock.release() + # return workflow diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 27b8a58bea..e2d435e621 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -75,6 +75,7 @@ def mock_services() -> InvocationServices: session_processor=None, # type: ignore session_queue=None, # type: ignore urls=None, # type: ignore + workflow_records=None, # type: ignore ) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 105f7417cd..9774f07fdd 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -80,6 +80,7 @@ def mock_services() -> InvocationServices: session_processor=None, # type: ignore session_queue=None, # type: ignore urls=None, # type: ignore + workflow_records=None, # type: ignore )