mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: add workflows table & service
This commit is contained in:
parent
9195c8c957
commit
c2da74c587
@ -30,6 +30,7 @@ from ..services.shared.default_graphs import create_system_graphs
|
|||||||
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
||||||
from ..services.shared.sqlite import SqliteDatabase
|
from ..services.shared.sqlite import SqliteDatabase
|
||||||
from ..services.urls.urls_default import LocalUrlService
|
from ..services.urls.urls_default import LocalUrlService
|
||||||
|
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@ -90,6 +91,7 @@ class ApiDependencies:
|
|||||||
session_processor = DefaultSessionProcessor()
|
session_processor = DefaultSessionProcessor()
|
||||||
session_queue = SqliteSessionQueue(db=db)
|
session_queue = SqliteSessionQueue(db=db)
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
|
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
board_image_records=board_image_records,
|
board_image_records=board_image_records,
|
||||||
@ -114,6 +116,7 @@ class ApiDependencies:
|
|||||||
session_processor=session_processor,
|
session_processor=session_processor,
|
||||||
session_queue=session_queue,
|
session_queue=session_queue,
|
||||||
urls=urls,
|
urls=urls,
|
||||||
|
workflow_records=workflow_records,
|
||||||
)
|
)
|
||||||
|
|
||||||
create_system_graphs(services.graph_library)
|
create_system_graphs(services.graph_library)
|
||||||
|
20
invokeai/app/api/routers/workflows.py
Normal file
20
invokeai/app/api/routers/workflows.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from fastapi import APIRouter, Body, Path
|
||||||
|
|
||||||
|
from invokeai.app.api.dependencies import ApiDependencies
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField
|
||||||
|
|
||||||
|
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||||
|
|
||||||
|
|
||||||
|
@workflows_router.get(
|
||||||
|
"/i/{workflow_id}",
|
||||||
|
operation_id="get_workflow",
|
||||||
|
responses={
|
||||||
|
200: {"model": WorkflowField},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_workflow(
|
||||||
|
workflow_id: str = Path(description="The workflow to get"),
|
||||||
|
) -> WorkflowField:
|
||||||
|
"""Gets a workflow"""
|
||||||
|
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
@ -38,7 +38,17 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
from .api.routers import (
|
||||||
|
app_info,
|
||||||
|
board_images,
|
||||||
|
boards,
|
||||||
|
images,
|
||||||
|
models,
|
||||||
|
sessions,
|
||||||
|
session_queue,
|
||||||
|
utilities,
|
||||||
|
workflows,
|
||||||
|
)
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||||
|
|
||||||
@ -95,18 +105,13 @@ async def shutdown_event() -> None:
|
|||||||
app.include_router(sessions.session_router, prefix="/api")
|
app.include_router(sessions.session_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(utilities.utilities_router, prefix="/api")
|
app.include_router(utilities.utilities_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(models.models_router, prefix="/api")
|
app.include_router(models.models_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(images.images_router, prefix="/api")
|
app.include_router(images.images_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(boards.boards_router, prefix="/api")
|
app.include_router(boards.boards_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(board_images.board_images_router, prefix="/api")
|
app.include_router(board_images.board_images_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(app_info.app_router, prefix="/api")
|
app.include_router(app_info.app_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||||
|
app.include_router(workflows.workflows_router, prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
# Build a custom OpenAPI to include all outputs
|
# Build a custom OpenAPI to include all outputs
|
||||||
@ -166,7 +171,6 @@ def custom_openapi() -> dict[str, Any]:
|
|||||||
# print(f"Config with name {name} already defined")
|
# print(f"Config with name {name} already defined")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
|
||||||
openapi_schema["components"]["schemas"][name] = dict(
|
openapi_schema["components"]["schemas"][name] = dict(
|
||||||
title=name,
|
title=name,
|
||||||
description="An enumeration.",
|
description="An enumeration.",
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -11,12 +10,13 @@ from types import UnionType
|
|||||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
import semver
|
import semver
|
||||||
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||||
from pydantic.fields import _Unset
|
from pydantic.fields import _Unset
|
||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
@ -60,7 +60,7 @@ class FieldDescriptions:
|
|||||||
denoised_latents = "Denoised latents tensor"
|
denoised_latents = "Denoised latents tensor"
|
||||||
latents = "Latents tensor"
|
latents = "Latents tensor"
|
||||||
strength = "Strength of denoising (proportional to steps)"
|
strength = "Strength of denoising (proportional to steps)"
|
||||||
core_metadata = "Optional core metadata to be written to image"
|
workflow = "Optional workflow to be saved with the image"
|
||||||
interp_mode = "Interpolation mode"
|
interp_mode = "Interpolation mode"
|
||||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||||
fp32 = "Whether or not to use full float32 precision"
|
fp32 = "Whether or not to use full float32 precision"
|
||||||
@ -665,27 +665,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
description="Whether or not this is an intermediate invocation.",
|
description="Whether or not this is an intermediate invocation.",
|
||||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
|
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
|
||||||
)
|
)
|
||||||
workflow: Optional[str] = Field(
|
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
|
||||||
default=None,
|
|
||||||
description="The workflow to save with the image",
|
|
||||||
json_schema_extra=dict(ui_type=UIType.WorkflowField),
|
|
||||||
)
|
|
||||||
use_cache: Optional[bool] = Field(
|
|
||||||
default=True,
|
|
||||||
description="Whether or not to use the cache",
|
|
||||||
)
|
|
||||||
|
|
||||||
@field_validator("workflow", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_workflow_is_json(cls, v):
|
|
||||||
"""We don't have a workflow schema in the backend, so we just check that it's valid JSON"""
|
|
||||||
if v is None:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
json.loads(v)
|
|
||||||
except json.decoder.JSONDecodeError:
|
|
||||||
raise ValueError("Workflow must be valid JSON")
|
|
||||||
return v
|
|
||||||
|
|
||||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||||
|
|
||||||
@ -824,4 +804,6 @@ def invocation_output(
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
class WithWorkflow(BaseModel):
|
||||||
|
workflow: Optional[WorkflowField] = InputField(default=None, description=FieldDescriptions.workflow)
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ if TYPE_CHECKING:
|
|||||||
from .session_queue.session_queue_base import SessionQueueBase
|
from .session_queue.session_queue_base import SessionQueueBase
|
||||||
from .shared.graph import GraphExecutionState, LibraryGraph
|
from .shared.graph import GraphExecutionState, LibraryGraph
|
||||||
from .urls.urls_base import UrlServiceBase
|
from .urls.urls_base import UrlServiceBase
|
||||||
|
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||||
|
|
||||||
|
|
||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
@ -55,6 +56,7 @@ class InvocationServices:
|
|||||||
invocation_cache: "InvocationCacheBase"
|
invocation_cache: "InvocationCacheBase"
|
||||||
names: "NameServiceBase"
|
names: "NameServiceBase"
|
||||||
urls: "UrlServiceBase"
|
urls: "UrlServiceBase"
|
||||||
|
workflow_records: "WorkflowRecordsStorageBase"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -80,6 +82,7 @@ class InvocationServices:
|
|||||||
invocation_cache: "InvocationCacheBase",
|
invocation_cache: "InvocationCacheBase",
|
||||||
names: "NameServiceBase",
|
names: "NameServiceBase",
|
||||||
urls: "UrlServiceBase",
|
urls: "UrlServiceBase",
|
||||||
|
workflow_records: "WorkflowRecordsStorageBase",
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.board_image_records = board_image_records
|
self.board_image_records = board_image_records
|
||||||
@ -103,3 +106,4 @@ class InvocationServices:
|
|||||||
self.invocation_cache = invocation_cache
|
self.invocation_cache = invocation_cache
|
||||||
self.names = names
|
self.names = names
|
||||||
self.urls = urls
|
self.urls = urls
|
||||||
|
self.workflow_records = workflow_records
|
||||||
|
0
invokeai/app/services/workflow_records/__init__.py
Normal file
0
invokeai/app/services/workflow_records/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRecordsStorageBase(ABC):
|
||||||
|
"""Base class for workflow storage services."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, workflow_id: str) -> WorkflowField:
|
||||||
|
"""Get workflow by id."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||||
|
"""Creates a workflow."""
|
||||||
|
pass
|
@ -0,0 +1,22 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import Field, RootModel, TypeAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNotFoundError(Exception):
|
||||||
|
"""Raised when a workflow is not found"""
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowField(RootModel):
|
||||||
|
"""
|
||||||
|
Pydantic model for workflows with custom root of type dict[str, Any].
|
||||||
|
Workflows are stored without a strict schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
root: dict[str, Any] = Field(description="Workflow dict")
|
||||||
|
|
||||||
|
def model_dump(self, *args, **kwargs) -> dict[str, Any]:
|
||||||
|
return super().model_dump(*args, **kwargs)["root"]
|
||||||
|
|
||||||
|
|
||||||
|
type_adapter_WorkflowField = TypeAdapter(WorkflowField)
|
@ -0,0 +1,148 @@
|
|||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||||
|
WorkflowField,
|
||||||
|
WorkflowNotFoundError,
|
||||||
|
type_adapter_WorkflowField,
|
||||||
|
)
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||||
|
_invoker: Invoker
|
||||||
|
_conn: sqlite3.Connection
|
||||||
|
_cursor: sqlite3.Cursor
|
||||||
|
_lock: threading.RLock
|
||||||
|
|
||||||
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._lock = db.lock
|
||||||
|
self._conn = db.conn
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
self._create_tables()
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self._invoker = invoker
|
||||||
|
|
||||||
|
def get(self, workflow_id: str) -> WorkflowField:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT workflow
|
||||||
|
FROM workflows
|
||||||
|
WHERE workflow_id = ?;
|
||||||
|
""",
|
||||||
|
(workflow_id,),
|
||||||
|
)
|
||||||
|
row = self._cursor.fetchone()
|
||||||
|
if row is None:
|
||||||
|
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||||
|
return type_adapter_WorkflowField.validate_json(row[0])
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||||
|
try:
|
||||||
|
# workflows do not have ids until they are saved
|
||||||
|
workflow_id = uuid_string()
|
||||||
|
workflow.root["id"] = workflow_id
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT INTO workflows(workflow)
|
||||||
|
VALUES (?);
|
||||||
|
""",
|
||||||
|
(workflow.json(),),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get(workflow_id)
|
||||||
|
|
||||||
|
def _create_tables(self) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS workflows (
|
||||||
|
workflow TEXT NOT NULL,
|
||||||
|
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON workflows FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE workflows
|
||||||
|
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE workflow_id = old.workflow_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
# def update(self, workflow_id: str, workflow: Workflow) -> Workflow:
|
||||||
|
# """Updates a workflow record."""
|
||||||
|
# try:
|
||||||
|
# workflow_id = workflow.get("id", None)
|
||||||
|
# if type(workflow_id) is not str:
|
||||||
|
# raise WorkflowNotFoundError(f"Workflow does not have a valid id, got {workflow_id}")
|
||||||
|
# self._lock.acquire()
|
||||||
|
# self._cursor.execute(
|
||||||
|
# """--sql
|
||||||
|
# UPDATE workflows
|
||||||
|
# SET workflow = ?
|
||||||
|
# WHERE workflow_id = ?
|
||||||
|
# """,
|
||||||
|
# (workflow, workflow_id),
|
||||||
|
# )
|
||||||
|
# self._conn.commit()
|
||||||
|
# except Exception:
|
||||||
|
# self._conn.rollback()
|
||||||
|
# raise
|
||||||
|
# finally:
|
||||||
|
# self._lock.release()
|
||||||
|
# return self.get(workflow_id)
|
||||||
|
|
||||||
|
# def delete(self, workflow_id: str) -> Workflow:
|
||||||
|
# """Updates a workflow record."""
|
||||||
|
# workflow = self.get(workflow_id)
|
||||||
|
# try:
|
||||||
|
# self._lock.acquire()
|
||||||
|
# self._cursor.execute(
|
||||||
|
# """--sql
|
||||||
|
# DELETE FROM workflows
|
||||||
|
# WHERE workflow_id = ?
|
||||||
|
# """,
|
||||||
|
# (workflow_id,),
|
||||||
|
# )
|
||||||
|
# self._conn.commit()
|
||||||
|
# except Exception:
|
||||||
|
# self._conn.rollback()
|
||||||
|
# raise
|
||||||
|
# finally:
|
||||||
|
# self._lock.release()
|
||||||
|
# return workflow
|
@ -75,6 +75,7 @@ def mock_services() -> InvocationServices:
|
|||||||
session_processor=None, # type: ignore
|
session_processor=None, # type: ignore
|
||||||
session_queue=None, # type: ignore
|
session_queue=None, # type: ignore
|
||||||
urls=None, # type: ignore
|
urls=None, # type: ignore
|
||||||
|
workflow_records=None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,6 +80,7 @@ def mock_services() -> InvocationServices:
|
|||||||
session_processor=None, # type: ignore
|
session_processor=None, # type: ignore
|
||||||
session_queue=None, # type: ignore
|
session_queue=None, # type: ignore
|
||||||
urls=None, # type: ignore
|
urls=None, # type: ignore
|
||||||
|
workflow_records=None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user