mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(backend): sync system workflows to db
This commit is contained in:
parent
b0350e9bc8
commit
734e871e8f
@ -30,8 +30,8 @@ from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
|||||||
from ..services.shared.default_graphs import create_system_graphs
|
from ..services.shared.default_graphs import create_system_graphs
|
||||||
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
||||||
from ..services.shared.sqlite.sqlite_database import SqliteDatabase
|
from ..services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from ..services.shared.system_workflows import create_system_workflows
|
|
||||||
from ..services.urls.urls_default import LocalUrlService
|
from ..services.urls.urls_default import LocalUrlService
|
||||||
|
from ..services.workflow_records.sync_system_workflows import sync_system_workflows
|
||||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
@ -124,7 +124,7 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
|
|
||||||
create_system_graphs(services.graph_library)
|
create_system_graphs(services.graph_library)
|
||||||
create_system_workflows(workflow_records=services.workflow_records, logger=logger)
|
sync_system_workflows(workflow_records=services.workflow_records, logger=logger)
|
||||||
|
|
||||||
db.clean()
|
db.clean()
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
from logging import Logger
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import semver
|
|
||||||
|
|
||||||
import invokeai.app.assets.workflows as system_workflows_dir
|
|
||||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
|
||||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowNotFoundError, WorkflowValidator
|
|
||||||
|
|
||||||
system_workflows = Path(system_workflows_dir.__path__[0]).glob("*.json")
|
|
||||||
|
|
||||||
|
|
||||||
def create_system_workflows(workflow_records: WorkflowRecordsStorageBase, logger: Logger) -> None:
|
|
||||||
"""Creates the system workflows."""
|
|
||||||
for workflow_filename in system_workflows:
|
|
||||||
with open(workflow_filename, "rb") as f:
|
|
||||||
workflow_bytes = f.read()
|
|
||||||
if workflow_bytes is None:
|
|
||||||
raise ValueError(f"Could not find system workflow: {workflow_filename}")
|
|
||||||
|
|
||||||
new_workflow = WorkflowValidator.validate_json(workflow_bytes)
|
|
||||||
|
|
||||||
try:
|
|
||||||
installed_workflow = workflow_records.get(new_workflow.id).workflow
|
|
||||||
installed_version = semver.Version.parse(installed_workflow.version)
|
|
||||||
new_version = semver.Version.parse(new_workflow.version)
|
|
||||||
|
|
||||||
if new_version.compare(installed_version) > 0:
|
|
||||||
logger.info(f"Updating system workflow: {new_workflow.name}")
|
|
||||||
workflow_records._add_system_workflow(new_workflow)
|
|
||||||
except WorkflowNotFoundError:
|
|
||||||
logger.info(f"Installing system workflow: {new_workflow.name}")
|
|
||||||
workflow_records._add_system_workflow(new_workflow)
|
|
@ -0,0 +1,56 @@
|
|||||||
|
import pkgutil
|
||||||
|
from logging import Logger
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import semver
|
||||||
|
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||||
|
Workflow,
|
||||||
|
WorkflowValidator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: When I remove a workflow from system_workflows/ and do a `pip install --upgrade .`, the file
|
||||||
|
# is not removed from site-packages! The logic to delete old system workflows below doesn't work
|
||||||
|
# for normal installs. It does work for editable. Not sure why.
|
||||||
|
|
||||||
|
system_workflows_dir = "system_workflows"
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_workflows_from_json() -> list[Workflow]:
|
||||||
|
app_workflows: list[Workflow] = []
|
||||||
|
workflow_paths = (Path(__file__).parent / Path(system_workflows_dir)).glob("*.json")
|
||||||
|
for workflow_path in workflow_paths:
|
||||||
|
workflow_bytes = pkgutil.get_data(__name__, f"{system_workflows_dir}/{workflow_path.name}")
|
||||||
|
if workflow_bytes is None:
|
||||||
|
raise ValueError(f"Could not load system workflow: {workflow_path.name}")
|
||||||
|
|
||||||
|
app_workflows.append(WorkflowValidator.validate_json(workflow_bytes))
|
||||||
|
return app_workflows
|
||||||
|
|
||||||
|
|
||||||
|
def sync_system_workflows(workflow_records: WorkflowRecordsStorageBase, logger: Logger) -> None:
|
||||||
|
"""Syncs system workflows in the workflow_library database with the latest system workflows."""
|
||||||
|
|
||||||
|
system_workflows = get_system_workflows_from_json()
|
||||||
|
system_workflow_ids = [w.id for w in system_workflows]
|
||||||
|
installed_workflows = workflow_records._get_all_system_workflows()
|
||||||
|
installed_workflow_ids = [w.id for w in installed_workflows]
|
||||||
|
|
||||||
|
for workflow in installed_workflows:
|
||||||
|
if workflow.id not in system_workflow_ids:
|
||||||
|
workflow_records._delete_system_workflow(workflow.id)
|
||||||
|
logger.info(f"Deleted system workflow: {workflow.name}")
|
||||||
|
|
||||||
|
for workflow in system_workflows:
|
||||||
|
if workflow.id not in installed_workflow_ids:
|
||||||
|
workflow_records._create_system_workflow(workflow)
|
||||||
|
logger.info(f"Installed system workflow: {workflow.name}")
|
||||||
|
else:
|
||||||
|
installed_workflow = workflow_records.get(workflow.id).workflow
|
||||||
|
installed_version = semver.Version.parse(installed_workflow.version)
|
||||||
|
new_version = semver.Version.parse(workflow.version)
|
||||||
|
|
||||||
|
if new_version.compare(installed_version) > 0:
|
||||||
|
workflow_records._update_system_workflow(workflow)
|
||||||
|
logger.info(f"Updated system workflow: {workflow.name}")
|
@ -50,6 +50,21 @@ class WorkflowRecordsStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _add_system_workflow(self, workflow: Workflow) -> None:
|
def _create_system_workflow(self, workflow: Workflow) -> None:
|
||||||
"""Adds a system workflow. Internal use only."""
|
"""Creates a system workflow. Internal use only."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _update_system_workflow(self, workflow: Workflow) -> None:
|
||||||
|
"""Updates a system workflow. Internal use only."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _delete_system_workflow(self, workflow_id: str) -> None:
|
||||||
|
"""Deletes a system workflow. Internal use only."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_all_system_workflows(self) -> list[Workflow]:
|
||||||
|
"""Gets all system workflows. Internal use only."""
|
||||||
pass
|
pass
|
||||||
|
@ -91,7 +91,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
"""--sql
|
"""--sql
|
||||||
UPDATE workflow_library
|
UPDATE workflow_library
|
||||||
SET workflow = ?
|
SET workflow = ?
|
||||||
WHERE workflow_id = ? AND category = "user";
|
WHERE workflow_id = ? AND category = 'user';
|
||||||
""",
|
""",
|
||||||
(workflow.model_dump_json(), workflow.id),
|
(workflow.model_dump_json(), workflow.id),
|
||||||
)
|
)
|
||||||
@ -109,7 +109,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
DELETE from workflow_library
|
DELETE from workflow_library
|
||||||
WHERE workflow_id = ? AND category = "user";
|
WHERE workflow_id = ? AND category = 'user';
|
||||||
""",
|
""",
|
||||||
(workflow_id,),
|
(workflow_id,),
|
||||||
)
|
)
|
||||||
@ -182,10 +182,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def _add_system_workflow(self, workflow: Workflow) -> None:
|
def _create_system_workflow(self, workflow: Workflow) -> None:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
# Only system workflows may be created by this method
|
# Only system workflows may be managed by this method
|
||||||
assert workflow.meta.category is WorkflowCategory.System
|
assert workflow.meta.category is WorkflowCategory.System
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
@ -204,6 +204,60 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
|
def _update_system_workflow(self, workflow: Workflow) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
# Only system workflows may be managed by this method
|
||||||
|
assert workflow.meta.category is WorkflowCategory.System
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE workflow_library
|
||||||
|
SET workflow = ?
|
||||||
|
WHERE workflow_id = ? AND category = 'system';
|
||||||
|
""",
|
||||||
|
(workflow.model_dump_json(), workflow.id),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def _delete_system_workflow(self, workflow_id: str) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
DELETE FROM workflow_library
|
||||||
|
WHERE workflow_id = ? AND category = 'system';
|
||||||
|
""",
|
||||||
|
(workflow_id,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def _get_all_system_workflows(self) -> list[Workflow]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT workflow FROM workflow_library
|
||||||
|
WHERE category = 'system';
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
rows = self._cursor.fetchall()
|
||||||
|
return [WorkflowValidator.validate_json(dict(row)["workflow"]) for row in rows]
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
def _create_tables(self) -> None:
|
def _create_tables(self) -> None:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
@ -165,7 +165,7 @@ version = { attr = "invokeai.version.__version__" }
|
|||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
"invokeai.app.assets" = ["**/*.png"]
|
"invokeai.app.assets" = ["**/*.png"]
|
||||||
"invokeai.app.assets.workflows" = ["**/*.json"]
|
"invokeai.app.services.workflow_records.system_workflows" = ["*.json"]
|
||||||
"invokeai.assets.fonts" = ["**/*.ttf"]
|
"invokeai.assets.fonts" = ["**/*.ttf"]
|
||||||
"invokeai.backend" = ["**.png"]
|
"invokeai.backend" = ["**.png"]
|
||||||
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
|
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
|
||||||
|
Loading…
Reference in New Issue
Block a user