feat(backend): sync system workflows to db

This commit is contained in:
psychedelicious 2023-12-01 22:45:17 +11:00
parent b0350e9bc8
commit 734e871e8f
8 changed files with 134 additions and 42 deletions

View File

@ -30,8 +30,8 @@ from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.sqlite.sqlite_database import SqliteDatabase
from ..services.shared.system_workflows import create_system_workflows
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.sync_system_workflows import sync_system_workflows
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@ -124,7 +124,7 @@ class ApiDependencies:
)
create_system_graphs(services.graph_library)
create_system_workflows(workflow_records=services.workflow_records, logger=logger)
sync_system_workflows(workflow_records=services.workflow_records, logger=logger)
db.clean()
ApiDependencies.invoker = Invoker(services)

View File

@ -1,33 +0,0 @@
from logging import Logger
from pathlib import Path
import semver
import invokeai.app.assets.workflows as system_workflows_dir
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowNotFoundError, WorkflowValidator
system_workflows = Path(system_workflows_dir.__path__[0]).glob("*.json")
def create_system_workflows(workflow_records: WorkflowRecordsStorageBase, logger: Logger) -> None:
"""Creates the system workflows."""
for workflow_filename in system_workflows:
with open(workflow_filename, "rb") as f:
workflow_bytes = f.read()
if workflow_bytes is None:
raise ValueError(f"Could not find system workflow: {workflow_filename}")
new_workflow = WorkflowValidator.validate_json(workflow_bytes)
try:
installed_workflow = workflow_records.get(new_workflow.id).workflow
installed_version = semver.Version.parse(installed_workflow.version)
new_version = semver.Version.parse(new_workflow.version)
if new_version.compare(installed_version) > 0:
logger.info(f"Updating system workflow: {new_workflow.name}")
workflow_records._add_system_workflow(new_workflow)
except WorkflowNotFoundError:
logger.info(f"Installing system workflow: {new_workflow.name}")
workflow_records._add_system_workflow(new_workflow)

View File

@ -0,0 +1,56 @@
import pkgutil
from logging import Logger
from pathlib import Path
import semver
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_records.workflow_records_common import (
Workflow,
WorkflowValidator,
)
# TODO: When I remove a workflow from system_workflows/ and do a `pip install --upgrade .`, the file
# is not removed from site-packages! The logic to delete old system workflows below doesn't work
# for normal installs. It does work for editable. Not sure why.
system_workflows_dir = "system_workflows"
def get_system_workflows_from_json() -> list[Workflow]:
app_workflows: list[Workflow] = []
workflow_paths = (Path(__file__).parent / Path(system_workflows_dir)).glob("*.json")
for workflow_path in workflow_paths:
workflow_bytes = pkgutil.get_data(__name__, f"{system_workflows_dir}/{workflow_path.name}")
if workflow_bytes is None:
raise ValueError(f"Could not load system workflow: {workflow_path.name}")
app_workflows.append(WorkflowValidator.validate_json(workflow_bytes))
return app_workflows
def sync_system_workflows(workflow_records: WorkflowRecordsStorageBase, logger: Logger) -> None:
"""Syncs system workflows in the workflow_library database with the latest system workflows."""
system_workflows = get_system_workflows_from_json()
system_workflow_ids = [w.id for w in system_workflows]
installed_workflows = workflow_records._get_all_system_workflows()
installed_workflow_ids = [w.id for w in installed_workflows]
for workflow in installed_workflows:
if workflow.id not in system_workflow_ids:
workflow_records._delete_system_workflow(workflow.id)
logger.info(f"Deleted system workflow: {workflow.name}")
for workflow in system_workflows:
if workflow.id not in installed_workflow_ids:
workflow_records._create_system_workflow(workflow)
logger.info(f"Installed system workflow: {workflow.name}")
else:
installed_workflow = workflow_records.get(workflow.id).workflow
installed_version = semver.Version.parse(installed_workflow.version)
new_version = semver.Version.parse(workflow.version)
if new_version.compare(installed_version) > 0:
workflow_records._update_system_workflow(workflow)
logger.info(f"Updated system workflow: {workflow.name}")

View File

@ -50,6 +50,21 @@ class WorkflowRecordsStorageBase(ABC):
pass
@abstractmethod
def _add_system_workflow(self, workflow: Workflow) -> None:
"""Adds a system workflow. Internal use only."""
def _create_system_workflow(self, workflow: Workflow) -> None:
"""Creates a system workflow. Internal use only."""
pass
@abstractmethod
def _update_system_workflow(self, workflow: Workflow) -> None:
"""Updates a system workflow. Internal use only."""
pass
@abstractmethod
def _delete_system_workflow(self, workflow_id: str) -> None:
"""Deletes a system workflow. Internal use only."""
pass
@abstractmethod
def _get_all_system_workflows(self) -> list[Workflow]:
"""Gets all system workflows. Internal use only."""
pass

View File

@ -91,7 +91,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
"""--sql
UPDATE workflow_library
SET workflow = ?
WHERE workflow_id = ? AND category = "user";
WHERE workflow_id = ? AND category = 'user';
""",
(workflow.model_dump_json(), workflow.id),
)
@ -109,7 +109,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self._cursor.execute(
"""--sql
DELETE from workflow_library
WHERE workflow_id = ? AND category = "user";
WHERE workflow_id = ? AND category = 'user';
""",
(workflow_id,),
)
@ -182,10 +182,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
finally:
self._lock.release()
def _add_system_workflow(self, workflow: Workflow) -> None:
def _create_system_workflow(self, workflow: Workflow) -> None:
try:
self._lock.acquire()
# Only system workflows may be created by this method
# Only system workflows may be managed by this method
assert workflow.meta.category is WorkflowCategory.System
self._cursor.execute(
"""--sql
@ -204,6 +204,60 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
finally:
self._lock.release()
def _update_system_workflow(self, workflow: Workflow) -> None:
try:
self._lock.acquire()
# Only system workflows may be managed by this method
assert workflow.meta.category is WorkflowCategory.System
self._cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?
WHERE workflow_id = ? AND category = 'system';
""",
(workflow.model_dump_json(), workflow.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _delete_system_workflow(self, workflow_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM workflow_library
WHERE workflow_id = ? AND category = 'system';
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _get_all_system_workflows(self) -> list[Workflow]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT workflow FROM workflow_library
WHERE category = 'system';
"""
)
rows = self._cursor.fetchall()
return [WorkflowValidator.validate_json(dict(row)["workflow"]) for row in rows]
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _create_tables(self) -> None:
try:
self._lock.acquire()

View File

@ -165,7 +165,7 @@ version = { attr = "invokeai.version.__version__" }
[tool.setuptools.package-data]
"invokeai.app.assets" = ["**/*.png"]
"invokeai.app.assets.workflows" = ["**/*.json"]
"invokeai.app.services.workflow_records.system_workflows" = ["*.json"]
"invokeai.assets.fonts" = ["**/*.ttf"]
"invokeai.backend" = ["**.png"]
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]