mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(backend): add WorkflowRecordListItemDTO
This is the id, name, description, created at and updated at workflow columns/attrs. Used to display lists of workflowsl
This commit is contained in:
parent
c1bfc1f47b
commit
fcc056fe6a
@ -6,6 +6,8 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
Workflow,
|
||||
WorkflowNotFoundError,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowWithoutID,
|
||||
)
|
||||
|
||||
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
@ -61,7 +63,7 @@ async def delete_workflow(
|
||||
},
|
||||
)
|
||||
async def create_workflow(
|
||||
workflow: Workflow = Body(description="The workflow to create", embed=True),
|
||||
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Creates a workflow"""
|
||||
return ApiDependencies.invoker.services.workflow_records.create(workflow)
|
||||
@ -71,12 +73,12 @@ async def create_workflow(
|
||||
"/",
|
||||
operation_id="list_workflows",
|
||||
responses={
|
||||
200: {"model": PaginatedResults[WorkflowRecordDTO]},
|
||||
200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
|
||||
},
|
||||
)
|
||||
async def list_workflows(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of workflows per page"),
|
||||
) -> PaginatedResults[WorkflowRecordDTO]:
|
||||
"""Deletes a workflow"""
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
"""Gets a page of workflows"""
|
||||
return ApiDependencies.invoker.services.workflow_records.get_many(page=page, per_page=per_page)
|
||||
|
@ -4,6 +4,8 @@ from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
Workflow,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowWithoutID,
|
||||
)
|
||||
|
||||
|
||||
@ -16,7 +18,7 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
|
||||
"""Creates a workflow."""
|
||||
pass
|
||||
|
||||
@ -31,6 +33,6 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordDTO]:
|
||||
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
"""Gets many workflows."""
|
||||
pass
|
||||
|
@ -48,7 +48,7 @@ WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID)
|
||||
|
||||
|
||||
class Workflow(WorkflowWithoutID):
|
||||
id: str = Field(default_factory=uuid_string, description="The id of the workflow.")
|
||||
workflow_id: str = Field(default_factory=uuid_string, description="The id of the workflow.")
|
||||
|
||||
|
||||
WorkflowValidator = TypeAdapter(Workflow)
|
||||
@ -69,5 +69,16 @@ class WorkflowRecordDTO(BaseModel):
|
||||
WorkflowRecordDTOValidator = TypeAdapter(WorkflowRecordDTO)
|
||||
|
||||
|
||||
class WorkflowRecordListItemDTO(BaseModel):
|
||||
workflow_id: str = Field(description="The id of the workflow.")
|
||||
name: str = Field(description="The name of the workflow.")
|
||||
description: str = Field(description="The description of the workflow.")
|
||||
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the workflow.")
|
||||
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the workflow.")
|
||||
|
||||
|
||||
WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)
|
||||
|
||||
|
||||
class WorkflowNotFoundError(Exception):
|
||||
"""Raised when a workflow is not found"""
|
||||
|
@ -6,6 +6,10 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
Workflow,
|
||||
WorkflowNotFoundError,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowRecordListItemDTOValidator,
|
||||
WorkflowValidator,
|
||||
WorkflowWithoutID,
|
||||
)
|
||||
|
||||
|
||||
@ -41,15 +45,22 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
|
||||
try:
|
||||
workflow_with_id = WorkflowValidator.validate_python(workflow.model_dump())
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO workflow_library(workflow)
|
||||
VALUES (?);
|
||||
INSERT OR IGNORE INTO workflow_library (
|
||||
workflow_id,
|
||||
workflow
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(workflow.model_dump_json(),),
|
||||
(
|
||||
workflow_with_id.workflow_id,
|
||||
workflow_with_id.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
@ -57,7 +68,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(workflow.id)
|
||||
return self.get(workflow_with_id.workflow_id)
|
||||
|
||||
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
try:
|
||||
@ -68,7 +79,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
SET workflow = ?
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow.model_dump_json(), workflow.id),
|
||||
(workflow.model_dump_json(), workflow.workflow_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
@ -76,7 +87,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(workflow.id)
|
||||
return self.get(workflow.workflow_id)
|
||||
|
||||
def delete(self, workflow_id: str) -> None:
|
||||
try:
|
||||
@ -96,26 +107,31 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordDTO]:
|
||||
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, created_at, updated_at
|
||||
SELECT
|
||||
workflow_id,
|
||||
json_extract(workflow, '$.name') AS name,
|
||||
json_extract(workflow, '$.description') AS description,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM workflow_library
|
||||
ORDER BY created_at DESC
|
||||
ORDER BY name ASC
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
""",
|
||||
(per_page, page * per_page),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
workflows = [WorkflowRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library;
|
||||
"""
|
||||
"""
|
||||
)
|
||||
total = self._cursor.fetchone()[0]
|
||||
pages = int(total / per_page) + 1
|
||||
@ -138,8 +154,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS workflow_library (
|
||||
workflow_id TEXT NOT NULL PRIMARY KEY, -- gets implicit index
|
||||
workflow TEXT NOT NULL,
|
||||
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user