mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(backend): add WorkflowRecordListItemDTO
This is the id, name, description, created at and updated at workflow columns/attrs. Used to display lists of workflowsl
This commit is contained in:
parent
c1bfc1f47b
commit
fcc056fe6a
@ -6,6 +6,8 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
|
|||||||
Workflow,
|
Workflow,
|
||||||
WorkflowNotFoundError,
|
WorkflowNotFoundError,
|
||||||
WorkflowRecordDTO,
|
WorkflowRecordDTO,
|
||||||
|
WorkflowRecordListItemDTO,
|
||||||
|
WorkflowWithoutID,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||||
@ -61,7 +63,7 @@ async def delete_workflow(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def create_workflow(
|
async def create_workflow(
|
||||||
workflow: Workflow = Body(description="The workflow to create", embed=True),
|
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
|
||||||
) -> WorkflowRecordDTO:
|
) -> WorkflowRecordDTO:
|
||||||
"""Creates a workflow"""
|
"""Creates a workflow"""
|
||||||
return ApiDependencies.invoker.services.workflow_records.create(workflow)
|
return ApiDependencies.invoker.services.workflow_records.create(workflow)
|
||||||
@ -71,12 +73,12 @@ async def create_workflow(
|
|||||||
"/",
|
"/",
|
||||||
operation_id="list_workflows",
|
operation_id="list_workflows",
|
||||||
responses={
|
responses={
|
||||||
200: {"model": PaginatedResults[WorkflowRecordDTO]},
|
200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def list_workflows(
|
async def list_workflows(
|
||||||
page: int = Query(default=0, description="The page to get"),
|
page: int = Query(default=0, description="The page to get"),
|
||||||
per_page: int = Query(default=10, description="The number of workflows per page"),
|
per_page: int = Query(default=10, description="The number of workflows per page"),
|
||||||
) -> PaginatedResults[WorkflowRecordDTO]:
|
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||||
"""Deletes a workflow"""
|
"""Gets a page of workflows"""
|
||||||
return ApiDependencies.invoker.services.workflow_records.get_many(page=page, per_page=per_page)
|
return ApiDependencies.invoker.services.workflow_records.get_many(page=page, per_page=per_page)
|
||||||
|
@ -4,6 +4,8 @@ from invokeai.app.services.shared.pagination import PaginatedResults
|
|||||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowRecordDTO,
|
WorkflowRecordDTO,
|
||||||
|
WorkflowRecordListItemDTO,
|
||||||
|
WorkflowWithoutID,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -16,7 +18,7 @@ class WorkflowRecordsStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create(self, workflow: Workflow) -> WorkflowRecordDTO:
|
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
|
||||||
"""Creates a workflow."""
|
"""Creates a workflow."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -31,6 +33,6 @@ class WorkflowRecordsStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordDTO]:
|
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||||
"""Gets many workflows."""
|
"""Gets many workflows."""
|
||||||
pass
|
pass
|
||||||
|
@ -48,7 +48,7 @@ WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID)
|
|||||||
|
|
||||||
|
|
||||||
class Workflow(WorkflowWithoutID):
|
class Workflow(WorkflowWithoutID):
|
||||||
id: str = Field(default_factory=uuid_string, description="The id of the workflow.")
|
workflow_id: str = Field(default_factory=uuid_string, description="The id of the workflow.")
|
||||||
|
|
||||||
|
|
||||||
WorkflowValidator = TypeAdapter(Workflow)
|
WorkflowValidator = TypeAdapter(Workflow)
|
||||||
@ -69,5 +69,16 @@ class WorkflowRecordDTO(BaseModel):
|
|||||||
WorkflowRecordDTOValidator = TypeAdapter(WorkflowRecordDTO)
|
WorkflowRecordDTOValidator = TypeAdapter(WorkflowRecordDTO)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRecordListItemDTO(BaseModel):
|
||||||
|
workflow_id: str = Field(description="The id of the workflow.")
|
||||||
|
name: str = Field(description="The name of the workflow.")
|
||||||
|
description: str = Field(description="The description of the workflow.")
|
||||||
|
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the workflow.")
|
||||||
|
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the workflow.")
|
||||||
|
|
||||||
|
|
||||||
|
WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowNotFoundError(Exception):
|
class WorkflowNotFoundError(Exception):
|
||||||
"""Raised when a workflow is not found"""
|
"""Raised when a workflow is not found"""
|
||||||
|
@ -6,6 +6,10 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
|
|||||||
Workflow,
|
Workflow,
|
||||||
WorkflowNotFoundError,
|
WorkflowNotFoundError,
|
||||||
WorkflowRecordDTO,
|
WorkflowRecordDTO,
|
||||||
|
WorkflowRecordListItemDTO,
|
||||||
|
WorkflowRecordListItemDTOValidator,
|
||||||
|
WorkflowValidator,
|
||||||
|
WorkflowWithoutID,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -41,15 +45,22 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def create(self, workflow: Workflow) -> WorkflowRecordDTO:
|
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
|
||||||
try:
|
try:
|
||||||
|
workflow_with_id = WorkflowValidator.validate_python(workflow.model_dump())
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO workflow_library(workflow)
|
INSERT OR IGNORE INTO workflow_library (
|
||||||
VALUES (?);
|
workflow_id,
|
||||||
|
workflow
|
||||||
|
)
|
||||||
|
VALUES (?, ?);
|
||||||
""",
|
""",
|
||||||
(workflow.model_dump_json(),),
|
(
|
||||||
|
workflow_with_id.workflow_id,
|
||||||
|
workflow_with_id.model_dump_json(),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -57,7 +68,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
return self.get(workflow.id)
|
return self.get(workflow_with_id.workflow_id)
|
||||||
|
|
||||||
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||||
try:
|
try:
|
||||||
@ -68,7 +79,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
SET workflow = ?
|
SET workflow = ?
|
||||||
WHERE workflow_id = ?;
|
WHERE workflow_id = ?;
|
||||||
""",
|
""",
|
||||||
(workflow.model_dump_json(), workflow.id),
|
(workflow.model_dump_json(), workflow.workflow_id),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -76,7 +87,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
return self.get(workflow.id)
|
return self.get(workflow.workflow_id)
|
||||||
|
|
||||||
def delete(self, workflow_id: str) -> None:
|
def delete(self, workflow_id: str) -> None:
|
||||||
try:
|
try:
|
||||||
@ -96,21 +107,26 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordDTO]:
|
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT workflow_id, workflow, created_at, updated_at
|
SELECT
|
||||||
|
workflow_id,
|
||||||
|
json_extract(workflow, '$.name') AS name,
|
||||||
|
json_extract(workflow, '$.description') AS description,
|
||||||
|
created_at,
|
||||||
|
updated_at
|
||||||
FROM workflow_library
|
FROM workflow_library
|
||||||
ORDER BY created_at DESC
|
ORDER BY name ASC
|
||||||
LIMIT ? OFFSET ?;
|
LIMIT ? OFFSET ?;
|
||||||
""",
|
""",
|
||||||
(per_page, page * per_page),
|
(per_page, page * per_page),
|
||||||
)
|
)
|
||||||
rows = self._cursor.fetchall()
|
rows = self._cursor.fetchall()
|
||||||
workflows = [WorkflowRecordDTO.from_dict(dict(row)) for row in rows]
|
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
@ -138,8 +154,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
CREATE TABLE IF NOT EXISTS workflow_library (
|
CREATE TABLE IF NOT EXISTS workflow_library (
|
||||||
|
workflow_id TEXT NOT NULL PRIMARY KEY, -- gets implicit index
|
||||||
workflow TEXT NOT NULL,
|
workflow TEXT NOT NULL,
|
||||||
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
||||||
);
|
);
|
||||||
|
Loading…
Reference in New Issue
Block a user