From 734e871e8f78a5fd3f684be33d1ddef1cc832a42 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 1 Dec 2023 22:45:17 +1100 Subject: [PATCH] feat(backend): sync system workflows to db --- invokeai/app/api/dependencies.py | 4 +- .../app/services/shared/system_workflows.py | 33 ---------- .../workflow_records/sync_system_workflows.py | 56 +++++++++++++++++ .../system_workflows}/Text_to_Image_SD15.json | 0 .../system_workflows/__init__.py | 0 .../workflow_records/workflow_records_base.py | 19 +++++- .../workflow_records_sqlite.py | 62 +++++++++++++++++-- pyproject.toml | 2 +- 8 files changed, 134 insertions(+), 42 deletions(-) delete mode 100644 invokeai/app/services/shared/system_workflows.py create mode 100644 invokeai/app/services/workflow_records/sync_system_workflows.py rename invokeai/app/{assets/workflows => services/workflow_records/system_workflows}/Text_to_Image_SD15.json (100%) create mode 100644 invokeai/app/services/workflow_records/system_workflows/__init__.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index bceaf43644..1576a2ff90 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -30,8 +30,8 @@ from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.shared.default_graphs import create_system_graphs from ..services.shared.graph import GraphExecutionState, LibraryGraph from ..services.shared.sqlite.sqlite_database import SqliteDatabase -from ..services.shared.system_workflows import create_system_workflows from ..services.urls.urls_default import LocalUrlService +from ..services.workflow_records.sync_system_workflows import sync_system_workflows from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from .events import FastAPIEventService @@ -124,7 +124,7 @@ class ApiDependencies: ) create_system_graphs(services.graph_library) - create_system_workflows(workflow_records=services.workflow_records, logger=logger) + sync_system_workflows(workflow_records=services.workflow_records, logger=logger) db.clean() ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/services/shared/system_workflows.py b/invokeai/app/services/shared/system_workflows.py deleted file mode 100644 index e47a5e58cc..0000000000 --- a/invokeai/app/services/shared/system_workflows.py +++ /dev/null @@ -1,33 +0,0 @@ -from logging import Logger -from pathlib import Path - -import semver - -import invokeai.app.assets.workflows as system_workflows_dir -from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase -from invokeai.app.services.workflow_records.workflow_records_common import WorkflowNotFoundError, WorkflowValidator - -system_workflows = Path(system_workflows_dir.__path__[0]).glob("*.json") - - -def create_system_workflows(workflow_records: WorkflowRecordsStorageBase, logger: Logger) -> None: - """Creates the system workflows.""" - for workflow_filename in system_workflows: - with open(workflow_filename, "rb") as f: - workflow_bytes = f.read() - if workflow_bytes is None: - raise ValueError(f"Could not find system workflow: {workflow_filename}") - - new_workflow = WorkflowValidator.validate_json(workflow_bytes) - - try: - installed_workflow = workflow_records.get(new_workflow.id).workflow - installed_version = semver.Version.parse(installed_workflow.version) - new_version = semver.Version.parse(new_workflow.version) - - if new_version.compare(installed_version) > 0: - logger.info(f"Updating system workflow: {new_workflow.name}") - workflow_records._add_system_workflow(new_workflow) - except WorkflowNotFoundError: - logger.info(f"Installing system workflow: {new_workflow.name}") - workflow_records._add_system_workflow(new_workflow) diff --git a/invokeai/app/services/workflow_records/sync_system_workflows.py b/invokeai/app/services/workflow_records/sync_system_workflows.py new file mode 100644 index 0000000000..61643689c2 --- /dev/null +++ b/invokeai/app/services/workflow_records/sync_system_workflows.py @@ -0,0 +1,56 @@ +import pkgutil +from logging import Logger +from pathlib import Path + +import semver + +from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase +from invokeai.app.services.workflow_records.workflow_records_common import ( + Workflow, + WorkflowValidator, +) + +# TODO: When I remove a workflow from system_workflows/ and do a `pip install --upgrade .`, the file +# is not removed from site-packages! The logic to delete old system workflows below doesn't work +# for normal installs. It does work for editable. Not sure why. + +system_workflows_dir = "system_workflows" + + +def get_system_workflows_from_json() -> list[Workflow]: + app_workflows: list[Workflow] = [] + workflow_paths = (Path(__file__).parent / Path(system_workflows_dir)).glob("*.json") + for workflow_path in workflow_paths: + workflow_bytes = pkgutil.get_data(__name__, f"{system_workflows_dir}/{workflow_path.name}") + if workflow_bytes is None: + raise ValueError(f"Could not load system workflow: {workflow_path.name}") + + app_workflows.append(WorkflowValidator.validate_json(workflow_bytes)) + return app_workflows + + +def sync_system_workflows(workflow_records: WorkflowRecordsStorageBase, logger: Logger) -> None: + """Syncs system workflows in the workflow_library database with the latest system workflows.""" + + system_workflows = get_system_workflows_from_json() + system_workflow_ids = [w.id for w in system_workflows] + installed_workflows = workflow_records._get_all_system_workflows() + installed_workflow_ids = [w.id for w in installed_workflows] + + for workflow in installed_workflows: + if workflow.id not in system_workflow_ids: + workflow_records._delete_system_workflow(workflow.id) + logger.info(f"Deleted system workflow: {workflow.name}") + + for workflow in system_workflows: + if workflow.id not in installed_workflow_ids: + workflow_records._create_system_workflow(workflow) + logger.info(f"Installed system workflow: {workflow.name}") + else: + installed_workflow = workflow_records.get(workflow.id).workflow + installed_version = semver.Version.parse(installed_workflow.version) + new_version = semver.Version.parse(workflow.version) + + if new_version.compare(installed_version) > 0: + workflow_records._update_system_workflow(workflow) + logger.info(f"Updated system workflow: {workflow.name}") diff --git a/invokeai/app/assets/workflows/Text_to_Image_SD15.json b/invokeai/app/services/workflow_records/system_workflows/Text_to_Image_SD15.json similarity index 100% rename from invokeai/app/assets/workflows/Text_to_Image_SD15.json rename to invokeai/app/services/workflow_records/system_workflows/Text_to_Image_SD15.json diff --git a/invokeai/app/services/workflow_records/system_workflows/__init__.py b/invokeai/app/services/workflow_records/system_workflows/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/app/services/workflow_records/workflow_records_base.py b/invokeai/app/services/workflow_records/workflow_records_base.py index ae4b6b37fa..39b87c10a4 100644 --- a/invokeai/app/services/workflow_records/workflow_records_base.py +++ b/invokeai/app/services/workflow_records/workflow_records_base.py @@ -50,6 +50,21 @@ class WorkflowRecordsStorageBase(ABC): pass @abstractmethod - def _add_system_workflow(self, workflow: Workflow) -> None: - """Adds a system workflow. Internal use only.""" + def _create_system_workflow(self, workflow: Workflow) -> None: + """Creates a system workflow. Internal use only.""" + pass + + @abstractmethod + def _update_system_workflow(self, workflow: Workflow) -> None: + """Updates a system workflow. Internal use only.""" + pass + + @abstractmethod + def _delete_system_workflow(self, workflow_id: str) -> None: + """Deletes a system workflow. Internal use only.""" + pass + + @abstractmethod + def _get_all_system_workflows(self) -> list[Workflow]: + """Gets all system workflows. Internal use only.""" pass diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index 8b2c4dfee6..f14d2753a4 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -91,7 +91,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): """--sql UPDATE workflow_library SET workflow = ? - WHERE workflow_id = ? AND category = "user"; + WHERE workflow_id = ? AND category = 'user'; """, (workflow.model_dump_json(), workflow.id), ) @@ -109,7 +109,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): self._cursor.execute( """--sql DELETE from workflow_library - WHERE workflow_id = ? AND category = "user"; + WHERE workflow_id = ? AND category = 'user'; """, (workflow_id,), ) @@ -182,10 +182,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): finally: self._lock.release() - def _add_system_workflow(self, workflow: Workflow) -> None: + def _create_system_workflow(self, workflow: Workflow) -> None: try: self._lock.acquire() - # Only system workflows may be created by this method + # Only system workflows may be managed by this method assert workflow.meta.category is WorkflowCategory.System self._cursor.execute( """--sql @@ -204,6 +204,60 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): finally: self._lock.release() + def _update_system_workflow(self, workflow: Workflow) -> None: + try: + self._lock.acquire() + # Only system workflows may be managed by this method + assert workflow.meta.category is WorkflowCategory.System + self._cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ? + WHERE workflow_id = ? AND category = 'system'; + """, + (workflow.model_dump_json(), workflow.id), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + + def _delete_system_workflow(self, workflow_id: str) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + DELETE FROM workflow_library + WHERE workflow_id = ? AND category = 'system'; + """, + (workflow_id,), + ) + self._conn.commit() + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + + def _get_all_system_workflows(self) -> list[Workflow]: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT workflow FROM workflow_library + WHERE category = 'system'; + """ + ) + rows = self._cursor.fetchall() + return [WorkflowValidator.validate_json(dict(row)["workflow"]) for row in rows] + except Exception: + self._conn.rollback() + raise + finally: + self._lock.release() + def _create_tables(self) -> None: try: self._lock.acquire() diff --git a/pyproject.toml b/pyproject.toml index d28788b31d..e7279f8a4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,7 +165,7 @@ version = { attr = "invokeai.version.__version__" } [tool.setuptools.package-data] "invokeai.app.assets" = ["**/*.png"] -"invokeai.app.assets.workflows" = ["**/*.json"] +"invokeai.app.services.workflow_records.system_workflows" = ["*.json"] "invokeai.assets.fonts" = ["**/*.ttf"] "invokeai.backend" = ["**.png"] "invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]