From fcc056fe6ac3a303fc287e911960622994ce29b4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 29 Nov 2023 23:37:11 +1100 Subject: [PATCH] feat(backend): add WorkflowRecordListItemDTO This is the id, name, description, created at and updated at workflow columns/attrs. Used to display lists of workflowsl --- invokeai/app/api/routers/workflows.py | 10 +++-- .../workflow_records/workflow_records_base.py | 6 ++- .../workflow_records_common.py | 13 +++++- .../workflow_records_sqlite.py | 44 +++++++++++++------ 4 files changed, 52 insertions(+), 21 deletions(-) diff --git a/invokeai/app/api/routers/workflows.py b/invokeai/app/api/routers/workflows.py index 830ce2bc16..7ee8300648 100644 --- a/invokeai/app/api/routers/workflows.py +++ b/invokeai/app/api/routers/workflows.py @@ -6,6 +6,8 @@ from invokeai.app.services.workflow_records.workflow_records_common import ( Workflow, WorkflowNotFoundError, WorkflowRecordDTO, + WorkflowRecordListItemDTO, + WorkflowWithoutID, ) workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"]) @@ -61,7 +63,7 @@ async def delete_workflow( }, ) async def create_workflow( - workflow: Workflow = Body(description="The workflow to create", embed=True), + workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True), ) -> WorkflowRecordDTO: """Creates a workflow""" return ApiDependencies.invoker.services.workflow_records.create(workflow) @@ -71,12 +73,12 @@ async def create_workflow( "/", operation_id="list_workflows", responses={ - 200: {"model": PaginatedResults[WorkflowRecordDTO]}, + 200: {"model": PaginatedResults[WorkflowRecordListItemDTO]}, }, ) async def list_workflows( page: int = Query(default=0, description="The page to get"), per_page: int = Query(default=10, description="The number of workflows per page"), -) -> PaginatedResults[WorkflowRecordDTO]: - """Deletes a workflow""" +) -> PaginatedResults[WorkflowRecordListItemDTO]: + """Gets a page of workflows""" return ApiDependencies.invoker.services.workflow_records.get_many(page=page, per_page=per_page) diff --git a/invokeai/app/services/workflow_records/workflow_records_base.py b/invokeai/app/services/workflow_records/workflow_records_base.py index 807152576f..240165417f 100644 --- a/invokeai/app/services/workflow_records/workflow_records_base.py +++ b/invokeai/app/services/workflow_records/workflow_records_base.py @@ -4,6 +4,8 @@ from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import ( Workflow, WorkflowRecordDTO, + WorkflowRecordListItemDTO, + WorkflowWithoutID, ) @@ -16,7 +18,7 @@ class WorkflowRecordsStorageBase(ABC): pass @abstractmethod - def create(self, workflow: Workflow) -> WorkflowRecordDTO: + def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: """Creates a workflow.""" pass @@ -31,6 +33,6 @@ class WorkflowRecordsStorageBase(ABC): pass @abstractmethod - def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordDTO]: + def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordListItemDTO]: """Gets many workflows.""" pass diff --git a/invokeai/app/services/workflow_records/workflow_records_common.py b/invokeai/app/services/workflow_records/workflow_records_common.py index be27263b6d..4af8f456cc 100644 --- a/invokeai/app/services/workflow_records/workflow_records_common.py +++ b/invokeai/app/services/workflow_records/workflow_records_common.py @@ -48,7 +48,7 @@ WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID) class Workflow(WorkflowWithoutID): - id: str = Field(default_factory=uuid_string, description="The id of the workflow.") + workflow_id: str = Field(default_factory=uuid_string, description="The id of the workflow.") WorkflowValidator = TypeAdapter(Workflow) @@ -69,5 +69,16 @@ class WorkflowRecordDTO(BaseModel): WorkflowRecordDTOValidator = TypeAdapter(WorkflowRecordDTO) +class WorkflowRecordListItemDTO(BaseModel): + workflow_id: str = Field(description="The id of the workflow.") + name: str = Field(description="The name of the workflow.") + description: str = Field(description="The description of the workflow.") + created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the workflow.") + updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the workflow.") + + +WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO) + + class WorkflowNotFoundError(Exception): """Raised when a workflow is not found""" diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index a97b0a5cd8..3e9727dd57 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -6,6 +6,10 @@ from invokeai.app.services.workflow_records.workflow_records_common import ( Workflow, WorkflowNotFoundError, WorkflowRecordDTO, + WorkflowRecordListItemDTO, + WorkflowRecordListItemDTOValidator, + WorkflowValidator, + WorkflowWithoutID, ) @@ -41,15 +45,22 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): finally: self._lock.release() - def create(self, workflow: Workflow) -> WorkflowRecordDTO: + def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: try: + workflow_with_id = WorkflowValidator.validate_python(workflow.model_dump()) self._lock.acquire() self._cursor.execute( """--sql - INSERT OR IGNORE INTO workflow_library(workflow) - VALUES (?); + INSERT OR IGNORE INTO workflow_library ( + workflow_id, + workflow + ) + VALUES (?, ?); """, - (workflow.model_dump_json(),), + ( + workflow_with_id.workflow_id, + workflow_with_id.model_dump_json(), + ), ) self._conn.commit() except Exception: @@ -57,7 +68,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): raise finally: self._lock.release() - return self.get(workflow.id) + return self.get(workflow_with_id.workflow_id) def update(self, workflow: Workflow) -> WorkflowRecordDTO: try: @@ -68,7 +79,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): SET workflow = ? WHERE workflow_id = ?; """, - (workflow.model_dump_json(), workflow.id), + (workflow.model_dump_json(), workflow.workflow_id), ) self._conn.commit() except Exception: @@ -76,7 +87,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): raise finally: self._lock.release() - return self.get(workflow.id) + return self.get(workflow.workflow_id) def delete(self, workflow_id: str) -> None: try: @@ -96,26 +107,31 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): self._lock.release() return None - def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordDTO]: + def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordListItemDTO]: try: self._lock.acquire() self._cursor.execute( """--sql - SELECT workflow_id, workflow, created_at, updated_at + SELECT + workflow_id, + json_extract(workflow, '$.name') AS name, + json_extract(workflow, '$.description') AS description, + created_at, + updated_at FROM workflow_library - ORDER BY created_at DESC + ORDER BY name ASC LIMIT ? OFFSET ?; - """, + """, (per_page, page * per_page), ) rows = self._cursor.fetchall() - workflows = [WorkflowRecordDTO.from_dict(dict(row)) for row in rows] + workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows] self._cursor.execute( """--sql SELECT COUNT(*) FROM workflow_library; - """ + """ ) total = self._cursor.fetchone()[0] pages = int(total / per_page) + 1 @@ -138,8 +154,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): self._cursor.execute( """--sql CREATE TABLE IF NOT EXISTS workflow_library ( + workflow_id TEXT NOT NULL PRIMARY KEY, -- gets implicit index workflow TEXT NOT NULL, - workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger );