feat(backend): add WorkflowRecordListItemDTO

This is the id, name, description, created at and updated at workflow columns/attrs. Used to display lists of workflowsl
This commit is contained in:
psychedelicious 2023-11-29 23:37:11 +11:00
parent c1bfc1f47b
commit fcc056fe6a
4 changed files with 52 additions and 21 deletions

View File

@ -6,6 +6,8 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
Workflow,
WorkflowNotFoundError,
WorkflowRecordDTO,
WorkflowRecordListItemDTO,
WorkflowWithoutID,
)
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
@ -61,7 +63,7 @@ async def delete_workflow(
},
)
async def create_workflow(
workflow: Workflow = Body(description="The workflow to create", embed=True),
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
) -> WorkflowRecordDTO:
"""Creates a workflow"""
return ApiDependencies.invoker.services.workflow_records.create(workflow)
@ -71,12 +73,12 @@ async def create_workflow(
"/",
operation_id="list_workflows",
responses={
200: {"model": PaginatedResults[WorkflowRecordDTO]},
200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
},
)
async def list_workflows(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of workflows per page"),
) -> PaginatedResults[WorkflowRecordDTO]:
"""Deletes a workflow"""
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets a page of workflows"""
return ApiDependencies.invoker.services.workflow_records.get_many(page=page, per_page=per_page)

View File

@ -4,6 +4,8 @@ from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import (
Workflow,
WorkflowRecordDTO,
WorkflowRecordListItemDTO,
WorkflowWithoutID,
)
@ -16,7 +18,7 @@ class WorkflowRecordsStorageBase(ABC):
pass
@abstractmethod
def create(self, workflow: Workflow) -> WorkflowRecordDTO:
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
"""Creates a workflow."""
pass
@ -31,6 +33,6 @@ class WorkflowRecordsStorageBase(ABC):
pass
@abstractmethod
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordDTO]:
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""
pass

View File

@ -48,7 +48,7 @@ WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID)
class Workflow(WorkflowWithoutID):
id: str = Field(default_factory=uuid_string, description="The id of the workflow.")
workflow_id: str = Field(default_factory=uuid_string, description="The id of the workflow.")
WorkflowValidator = TypeAdapter(Workflow)
@ -69,5 +69,16 @@ class WorkflowRecordDTO(BaseModel):
WorkflowRecordDTOValidator = TypeAdapter(WorkflowRecordDTO)
class WorkflowRecordListItemDTO(BaseModel):
workflow_id: str = Field(description="The id of the workflow.")
name: str = Field(description="The name of the workflow.")
description: str = Field(description="The description of the workflow.")
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the workflow.")
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the workflow.")
WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)
class WorkflowNotFoundError(Exception):
"""Raised when a workflow is not found"""

View File

@ -6,6 +6,10 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
Workflow,
WorkflowNotFoundError,
WorkflowRecordDTO,
WorkflowRecordListItemDTO,
WorkflowRecordListItemDTOValidator,
WorkflowValidator,
WorkflowWithoutID,
)
@ -41,15 +45,22 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
finally:
self._lock.release()
def create(self, workflow: Workflow) -> WorkflowRecordDTO:
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
try:
workflow_with_id = WorkflowValidator.validate_python(workflow.model_dump())
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library(workflow)
VALUES (?);
INSERT OR IGNORE INTO workflow_library (
workflow_id,
workflow
)
VALUES (?, ?);
""",
(workflow.model_dump_json(),),
(
workflow_with_id.workflow_id,
workflow_with_id.model_dump_json(),
),
)
self._conn.commit()
except Exception:
@ -57,7 +68,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
raise
finally:
self._lock.release()
return self.get(workflow.id)
return self.get(workflow_with_id.workflow_id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
try:
@ -68,7 +79,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
SET workflow = ?
WHERE workflow_id = ?;
""",
(workflow.model_dump_json(), workflow.id),
(workflow.model_dump_json(), workflow.workflow_id),
)
self._conn.commit()
except Exception:
@ -76,7 +87,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
raise
finally:
self._lock.release()
return self.get(workflow.id)
return self.get(workflow.workflow_id)
def delete(self, workflow_id: str) -> None:
try:
@ -96,21 +107,26 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self._lock.release()
return None
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordDTO]:
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordListItemDTO]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT workflow_id, workflow, created_at, updated_at
SELECT
workflow_id,
json_extract(workflow, '$.name') AS name,
json_extract(workflow, '$.description') AS description,
created_at,
updated_at
FROM workflow_library
ORDER BY created_at DESC
ORDER BY name ASC
LIMIT ? OFFSET ?;
""",
(per_page, page * per_page),
)
rows = self._cursor.fetchall()
workflows = [WorkflowRecordDTO.from_dict(dict(row)) for row in rows]
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
self._cursor.execute(
"""--sql
SELECT COUNT(*)
@ -138,8 +154,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY, -- gets implicit index
workflow TEXT NOT NULL,
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
);