mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(backend): sync system workflows to db
This commit is contained in:
parent
b0350e9bc8
commit
734e871e8f
@ -30,8 +30,8 @@ from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.shared.default_graphs import create_system_graphs
|
||||
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from ..services.shared.system_workflows import create_system_workflows
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.sync_system_workflows import sync_system_workflows
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
|
||||
@ -124,7 +124,7 @@ class ApiDependencies:
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
create_system_workflows(workflow_records=services.workflow_records, logger=logger)
|
||||
sync_system_workflows(workflow_records=services.workflow_records, logger=logger)
|
||||
|
||||
db.clean()
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
@ -1,33 +0,0 @@
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
import semver
|
||||
|
||||
import invokeai.app.assets.workflows as system_workflows_dir
|
||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowNotFoundError, WorkflowValidator
|
||||
|
||||
system_workflows = Path(system_workflows_dir.__path__[0]).glob("*.json")
|
||||
|
||||
|
||||
def create_system_workflows(workflow_records: WorkflowRecordsStorageBase, logger: Logger) -> None:
|
||||
"""Creates the system workflows."""
|
||||
for workflow_filename in system_workflows:
|
||||
with open(workflow_filename, "rb") as f:
|
||||
workflow_bytes = f.read()
|
||||
if workflow_bytes is None:
|
||||
raise ValueError(f"Could not find system workflow: {workflow_filename}")
|
||||
|
||||
new_workflow = WorkflowValidator.validate_json(workflow_bytes)
|
||||
|
||||
try:
|
||||
installed_workflow = workflow_records.get(new_workflow.id).workflow
|
||||
installed_version = semver.Version.parse(installed_workflow.version)
|
||||
new_version = semver.Version.parse(new_workflow.version)
|
||||
|
||||
if new_version.compare(installed_version) > 0:
|
||||
logger.info(f"Updating system workflow: {new_workflow.name}")
|
||||
workflow_records._add_system_workflow(new_workflow)
|
||||
except WorkflowNotFoundError:
|
||||
logger.info(f"Installing system workflow: {new_workflow.name}")
|
||||
workflow_records._add_system_workflow(new_workflow)
|
@ -0,0 +1,56 @@
|
||||
import pkgutil
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
import semver
|
||||
|
||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
Workflow,
|
||||
WorkflowValidator,
|
||||
)
|
||||
|
||||
# TODO: When I remove a workflow from system_workflows/ and do a `pip install --upgrade .`, the file
|
||||
# is not removed from site-packages! The logic to delete old system workflows below doesn't work
|
||||
# for normal installs. It does work for editable. Not sure why.
|
||||
|
||||
system_workflows_dir = "system_workflows"
|
||||
|
||||
|
||||
def get_system_workflows_from_json() -> list[Workflow]:
|
||||
app_workflows: list[Workflow] = []
|
||||
workflow_paths = (Path(__file__).parent / Path(system_workflows_dir)).glob("*.json")
|
||||
for workflow_path in workflow_paths:
|
||||
workflow_bytes = pkgutil.get_data(__name__, f"{system_workflows_dir}/{workflow_path.name}")
|
||||
if workflow_bytes is None:
|
||||
raise ValueError(f"Could not load system workflow: {workflow_path.name}")
|
||||
|
||||
app_workflows.append(WorkflowValidator.validate_json(workflow_bytes))
|
||||
return app_workflows
|
||||
|
||||
|
||||
def sync_system_workflows(workflow_records: WorkflowRecordsStorageBase, logger: Logger) -> None:
|
||||
"""Syncs system workflows in the workflow_library database with the latest system workflows."""
|
||||
|
||||
system_workflows = get_system_workflows_from_json()
|
||||
system_workflow_ids = [w.id for w in system_workflows]
|
||||
installed_workflows = workflow_records._get_all_system_workflows()
|
||||
installed_workflow_ids = [w.id for w in installed_workflows]
|
||||
|
||||
for workflow in installed_workflows:
|
||||
if workflow.id not in system_workflow_ids:
|
||||
workflow_records._delete_system_workflow(workflow.id)
|
||||
logger.info(f"Deleted system workflow: {workflow.name}")
|
||||
|
||||
for workflow in system_workflows:
|
||||
if workflow.id not in installed_workflow_ids:
|
||||
workflow_records._create_system_workflow(workflow)
|
||||
logger.info(f"Installed system workflow: {workflow.name}")
|
||||
else:
|
||||
installed_workflow = workflow_records.get(workflow.id).workflow
|
||||
installed_version = semver.Version.parse(installed_workflow.version)
|
||||
new_version = semver.Version.parse(workflow.version)
|
||||
|
||||
if new_version.compare(installed_version) > 0:
|
||||
workflow_records._update_system_workflow(workflow)
|
||||
logger.info(f"Updated system workflow: {workflow.name}")
|
@ -50,6 +50,21 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _add_system_workflow(self, workflow: Workflow) -> None:
|
||||
"""Adds a system workflow. Internal use only."""
|
||||
def _create_system_workflow(self, workflow: Workflow) -> None:
|
||||
"""Creates a system workflow. Internal use only."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _update_system_workflow(self, workflow: Workflow) -> None:
|
||||
"""Updates a system workflow. Internal use only."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _delete_system_workflow(self, workflow_id: str) -> None:
|
||||
"""Deletes a system workflow. Internal use only."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_all_system_workflows(self) -> list[Workflow]:
|
||||
"""Gets all system workflows. Internal use only."""
|
||||
pass
|
||||
|
@ -91,7 +91,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
SET workflow = ?
|
||||
WHERE workflow_id = ? AND category = "user";
|
||||
WHERE workflow_id = ? AND category = 'user';
|
||||
""",
|
||||
(workflow.model_dump_json(), workflow.id),
|
||||
)
|
||||
@ -109,7 +109,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE from workflow_library
|
||||
WHERE workflow_id = ? AND category = "user";
|
||||
WHERE workflow_id = ? AND category = 'user';
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
@ -182,10 +182,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _add_system_workflow(self, workflow: Workflow) -> None:
|
||||
def _create_system_workflow(self, workflow: Workflow) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Only system workflows may be created by this method
|
||||
# Only system workflows may be managed by this method
|
||||
assert workflow.meta.category is WorkflowCategory.System
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -204,6 +204,60 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _update_system_workflow(self, workflow: Workflow) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Only system workflows may be managed by this method
|
||||
assert workflow.meta.category is WorkflowCategory.System
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
SET workflow = ?
|
||||
WHERE workflow_id = ? AND category = 'system';
|
||||
""",
|
||||
(workflow.model_dump_json(), workflow.id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _delete_system_workflow(self, workflow_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM workflow_library
|
||||
WHERE workflow_id = ? AND category = 'system';
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _get_all_system_workflows(self) -> list[Workflow]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow FROM workflow_library
|
||||
WHERE category = 'system';
|
||||
"""
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [WorkflowValidator.validate_json(dict(row)["workflow"]) for row in rows]
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
@ -165,7 +165,7 @@ version = { attr = "invokeai.version.__version__" }
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
"invokeai.app.assets" = ["**/*.png"]
|
||||
"invokeai.app.assets.workflows" = ["**/*.json"]
|
||||
"invokeai.app.services.workflow_records.system_workflows" = ["*.json"]
|
||||
"invokeai.assets.fonts" = ["**/*.ttf"]
|
||||
"invokeai.backend" = ["**.png"]
|
||||
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
|
||||
|
Loading…
Reference in New Issue
Block a user