mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: add workflows table & service
This commit is contained in:
parent
9195c8c957
commit
c2da74c587
@ -30,6 +30,7 @@ from ..services.shared.default_graphs import create_system_graphs
|
||||
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.shared.sqlite import SqliteDatabase
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@ -90,6 +91,7 @@ class ApiDependencies:
|
||||
session_processor = DefaultSessionProcessor()
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@ -114,6 +116,7 @@ class ApiDependencies:
|
||||
session_processor=session_processor,
|
||||
session_queue=session_queue,
|
||||
urls=urls,
|
||||
workflow_records=workflow_records,
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
20
invokeai/app/api/routers/workflows.py
Normal file
20
invokeai/app/api/routers/workflows.py
Normal file
@ -0,0 +1,20 @@
|
||||
from fastapi import APIRouter, Body, Path
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField
|
||||
|
||||
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
|
||||
|
||||
@workflows_router.get(
|
||||
"/i/{workflow_id}",
|
||||
operation_id="get_workflow",
|
||||
responses={
|
||||
200: {"model": WorkflowField},
|
||||
},
|
||||
)
|
||||
async def get_workflow(
|
||||
workflow_id: str = Path(description="The workflow to get"),
|
||||
) -> WorkflowField:
|
||||
"""Gets a workflow"""
|
||||
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
@ -38,7 +38,17 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
||||
from .api.routers import (
|
||||
app_info,
|
||||
board_images,
|
||||
boards,
|
||||
images,
|
||||
models,
|
||||
sessions,
|
||||
session_queue,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||
|
||||
@ -95,18 +105,13 @@ async def shutdown_event() -> None:
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
@ -166,7 +171,6 @@ def custom_openapi() -> dict[str, Any]:
|
||||
# print(f"Config with name {name} already defined")
|
||||
continue
|
||||
|
||||
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
||||
openapi_schema["components"]["schemas"][name] = dict(
|
||||
title=name,
|
||||
description="An enumeration.",
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@ -11,12 +10,13 @@ from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
@ -60,7 +60,7 @@ class FieldDescriptions:
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
core_metadata = "Optional core metadata to be written to image"
|
||||
workflow = "Optional workflow to be saved with the image"
|
||||
interp_mode = "Interpolation mode"
|
||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||
fp32 = "Whether or not to use full float32 precision"
|
||||
@ -665,27 +665,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
description="Whether or not this is an intermediate invocation.",
|
||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
|
||||
)
|
||||
workflow: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The workflow to save with the image",
|
||||
json_schema_extra=dict(ui_type=UIType.WorkflowField),
|
||||
)
|
||||
use_cache: Optional[bool] = Field(
|
||||
default=True,
|
||||
description="Whether or not to use the cache",
|
||||
)
|
||||
|
||||
@field_validator("workflow", mode="before")
|
||||
@classmethod
|
||||
def validate_workflow_is_json(cls, v):
|
||||
"""We don't have a workflow schema in the backend, so we just check that it's valid JSON"""
|
||||
if v is None:
|
||||
return None
|
||||
try:
|
||||
json.loads(v)
|
||||
except json.decoder.JSONDecodeError:
|
||||
raise ValueError("Workflow must be valid JSON")
|
||||
return v
|
||||
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
|
||||
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
|
||||
@ -824,4 +804,6 @@ def invocation_output(
|
||||
return wrapper
|
||||
|
||||
|
||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||
class WithWorkflow(BaseModel):
|
||||
workflow: Optional[WorkflowField] = InputField(default=None, description=FieldDescriptions.workflow)
|
||||
|
||||
|
@ -27,6 +27,7 @@ if TYPE_CHECKING:
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
from .shared.graph import GraphExecutionState, LibraryGraph
|
||||
from .urls.urls_base import UrlServiceBase
|
||||
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
|
||||
|
||||
class InvocationServices:
|
||||
@ -55,6 +56,7 @@ class InvocationServices:
|
||||
invocation_cache: "InvocationCacheBase"
|
||||
names: "NameServiceBase"
|
||||
urls: "UrlServiceBase"
|
||||
workflow_records: "WorkflowRecordsStorageBase"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -80,6 +82,7 @@ class InvocationServices:
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
names: "NameServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@ -103,3 +106,4 @@ class InvocationServices:
|
||||
self.invocation_cache = invocation_cache
|
||||
self.names = names
|
||||
self.urls = urls
|
||||
self.workflow_records = workflow_records
|
||||
|
0
invokeai/app/services/workflow_records/__init__.py
Normal file
0
invokeai/app/services/workflow_records/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField
|
||||
|
||||
|
||||
class WorkflowRecordsStorageBase(ABC):
|
||||
"""Base class for workflow storage services."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, workflow_id: str) -> WorkflowField:
|
||||
"""Get workflow by id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||
"""Creates a workflow."""
|
||||
pass
|
@ -0,0 +1,22 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, RootModel, TypeAdapter
|
||||
|
||||
|
||||
class WorkflowNotFoundError(Exception):
|
||||
"""Raised when a workflow is not found"""
|
||||
|
||||
|
||||
class WorkflowField(RootModel):
|
||||
"""
|
||||
Pydantic model for workflows with custom root of type dict[str, Any].
|
||||
Workflows are stored without a strict schema.
|
||||
"""
|
||||
|
||||
root: dict[str, Any] = Field(description="Workflow dict")
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict[str, Any]:
|
||||
return super().model_dump(*args, **kwargs)["root"]
|
||||
|
||||
|
||||
type_adapter_WorkflowField = TypeAdapter(WorkflowField)
|
@ -0,0 +1,148 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WorkflowField,
|
||||
WorkflowNotFoundError,
|
||||
type_adapter_WorkflowField,
|
||||
)
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
_invoker: Invoker
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
self._create_tables()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowField:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow
|
||||
FROM workflows
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = self._cursor.fetchone()
|
||||
if row is None:
|
||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||
return type_adapter_WorkflowField.validate_json(row[0])
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||
try:
|
||||
# workflows do not have ids until they are saved
|
||||
workflow_id = uuid_string()
|
||||
workflow.root["id"] = workflow_id
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO workflows(workflow)
|
||||
VALUES (?);
|
||||
""",
|
||||
(workflow.json(),),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(workflow_id)
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS workflows (
|
||||
workflow TEXT NOT NULL,
|
||||
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
|
||||
AFTER UPDATE
|
||||
ON workflows FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE workflows
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = old.workflow_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
# def update(self, workflow_id: str, workflow: Workflow) -> Workflow:
|
||||
# """Updates a workflow record."""
|
||||
# try:
|
||||
# workflow_id = workflow.get("id", None)
|
||||
# if type(workflow_id) is not str:
|
||||
# raise WorkflowNotFoundError(f"Workflow does not have a valid id, got {workflow_id}")
|
||||
# self._lock.acquire()
|
||||
# self._cursor.execute(
|
||||
# """--sql
|
||||
# UPDATE workflows
|
||||
# SET workflow = ?
|
||||
# WHERE workflow_id = ?
|
||||
# """,
|
||||
# (workflow, workflow_id),
|
||||
# )
|
||||
# self._conn.commit()
|
||||
# except Exception:
|
||||
# self._conn.rollback()
|
||||
# raise
|
||||
# finally:
|
||||
# self._lock.release()
|
||||
# return self.get(workflow_id)
|
||||
|
||||
# def delete(self, workflow_id: str) -> Workflow:
|
||||
# """Updates a workflow record."""
|
||||
# workflow = self.get(workflow_id)
|
||||
# try:
|
||||
# self._lock.acquire()
|
||||
# self._cursor.execute(
|
||||
# """--sql
|
||||
# DELETE FROM workflows
|
||||
# WHERE workflow_id = ?
|
||||
# """,
|
||||
# (workflow_id,),
|
||||
# )
|
||||
# self._conn.commit()
|
||||
# except Exception:
|
||||
# self._conn.rollback()
|
||||
# raise
|
||||
# finally:
|
||||
# self._lock.release()
|
||||
# return workflow
|
@ -75,6 +75,7 @@ def mock_services() -> InvocationServices:
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
@ -80,6 +80,7 @@ def mock_services() -> InvocationServices:
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user