feat: add workflows table & service

This commit is contained in:
psychedelicious 2023-10-17 17:02:15 +11:00
parent 9195c8c957
commit c2da74c587
11 changed files with 235 additions and 33 deletions

View File

@ -30,6 +30,7 @@ from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.sqlite import SqliteDatabase
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@ -90,6 +91,7 @@ class ApiDependencies:
session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)
services = InvocationServices(
board_image_records=board_image_records,
@ -114,6 +116,7 @@ class ApiDependencies:
session_processor=session_processor,
session_queue=session_queue,
urls=urls,
workflow_records=workflow_records,
)
create_system_graphs(services.graph_library)

View File

@ -0,0 +1,20 @@
from fastapi import APIRouter, Body, Path
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
@workflows_router.get(
"/i/{workflow_id}",
operation_id="get_workflow",
responses={
200: {"model": WorkflowField},
},
)
async def get_workflow(
workflow_id: str = Path(description="The workflow to get"),
) -> WorkflowField:
"""Gets a workflow"""
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)

View File

@ -38,7 +38,17 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
from .api.routers import (
app_info,
board_images,
boards,
images,
models,
sessions,
session_queue,
utilities,
workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
@ -95,18 +105,13 @@ async def shutdown_event() -> None:
app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
@ -166,7 +171,6 @@ def custom_openapi() -> dict[str, Any]:
# print(f"Config with name {name} already defined")
continue
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import json
import re
from abc import ABC, abstractmethod
from enum import Enum
@ -11,12 +10,13 @@ from types import UnionType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
import semver
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.util.misc import uuid_string
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
@ -60,7 +60,7 @@ class FieldDescriptions:
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
core_metadata = "Optional core metadata to be written to image"
workflow = "Optional workflow to be saved with the image"
interp_mode = "Interpolation mode"
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
fp32 = "Whether or not to use full float32 precision"
@ -665,27 +665,7 @@ class BaseInvocation(ABC, BaseModel):
description="Whether or not this is an intermediate invocation.",
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
)
workflow: Optional[str] = Field(
default=None,
description="The workflow to save with the image",
json_schema_extra=dict(ui_type=UIType.WorkflowField),
)
use_cache: Optional[bool] = Field(
default=True,
description="Whether or not to use the cache",
)
@field_validator("workflow", mode="before")
@classmethod
def validate_workflow_is_json(cls, v):
"""We don't have a workflow schema in the backend, so we just check that it's valid JSON"""
if v is None:
return None
try:
json.loads(v)
except json.decoder.JSONDecodeError:
raise ValueError("Workflow must be valid JSON")
return v
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
UIConfig: ClassVar[Type[UIConfigBase]]
@ -824,4 +804,6 @@ def invocation_output(
return wrapper
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
class WithWorkflow(BaseModel):
workflow: Optional[WorkflowField] = InputField(default=None, description=FieldDescriptions.workflow)

View File

@ -27,6 +27,7 @@ if TYPE_CHECKING:
from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState, LibraryGraph
from .urls.urls_base import UrlServiceBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
class InvocationServices:
@ -55,6 +56,7 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase"
names: "NameServiceBase"
urls: "UrlServiceBase"
workflow_records: "WorkflowRecordsStorageBase"
def __init__(
self,
@ -80,6 +82,7 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
urls: "UrlServiceBase",
workflow_records: "WorkflowRecordsStorageBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
@ -103,3 +106,4 @@ class InvocationServices:
self.invocation_cache = invocation_cache
self.names = names
self.urls = urls
self.workflow_records = workflow_records

View File

@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowField
class WorkflowRecordsStorageBase(ABC):
"""Base class for workflow storage services."""
@abstractmethod
def get(self, workflow_id: str) -> WorkflowField:
"""Get workflow by id."""
pass
@abstractmethod
def create(self, workflow: WorkflowField) -> WorkflowField:
"""Creates a workflow."""
pass

View File

@ -0,0 +1,22 @@
from typing import Any
from pydantic import Field, RootModel, TypeAdapter
class WorkflowNotFoundError(Exception):
"""Raised when a workflow is not found"""
class WorkflowField(RootModel):
"""
Pydantic model for workflows with custom root of type dict[str, Any].
Workflows are stored without a strict schema.
"""
root: dict[str, Any] = Field(description="Workflow dict")
def model_dump(self, *args, **kwargs) -> dict[str, Any]:
return super().model_dump(*args, **kwargs)["root"]
type_adapter_WorkflowField = TypeAdapter(WorkflowField)

View File

@ -0,0 +1,148 @@
import sqlite3
import threading
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowField,
WorkflowNotFoundError,
type_adapter_WorkflowField,
)
from invokeai.app.util.misc import uuid_string
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
_invoker: Invoker
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
self._create_tables()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, workflow_id: str) -> WorkflowField:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT workflow
FROM workflows
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = self._cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return type_adapter_WorkflowField.validate_json(row[0])
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def create(self, workflow: WorkflowField) -> WorkflowField:
try:
# workflows do not have ids until they are saved
workflow_id = uuid_string()
workflow.root["id"] = workflow_id
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO workflows(workflow)
VALUES (?);
""",
(workflow.json(),),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(workflow_id)
def _create_tables(self) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflows (
workflow TEXT NOT NULL,
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
);
"""
)
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
AFTER UPDATE
ON workflows FOR EACH ROW
BEGIN
UPDATE workflows
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
# def update(self, workflow_id: str, workflow: Workflow) -> Workflow:
# """Updates a workflow record."""
# try:
# workflow_id = workflow.get("id", None)
# if type(workflow_id) is not str:
# raise WorkflowNotFoundError(f"Workflow does not have a valid id, got {workflow_id}")
# self._lock.acquire()
# self._cursor.execute(
# """--sql
# UPDATE workflows
# SET workflow = ?
# WHERE workflow_id = ?
# """,
# (workflow, workflow_id),
# )
# self._conn.commit()
# except Exception:
# self._conn.rollback()
# raise
# finally:
# self._lock.release()
# return self.get(workflow_id)
# def delete(self, workflow_id: str) -> Workflow:
# """Updates a workflow record."""
# workflow = self.get(workflow_id)
# try:
# self._lock.acquire()
# self._cursor.execute(
# """--sql
# DELETE FROM workflows
# WHERE workflow_id = ?
# """,
# (workflow_id,),
# )
# self._conn.commit()
# except Exception:
# self._conn.rollback()
# raise
# finally:
# self._lock.release()
# return workflow

View File

@ -75,6 +75,7 @@ def mock_services() -> InvocationServices:
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
)

View File

@ -80,6 +80,7 @@ def mock_services() -> InvocationServices:
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
)