feat(api): add workflow_images junction table

similar to boards, images and workflows may be associated via junction table
This commit is contained in:
psychedelicious 2023-10-18 18:16:36 +11:00
parent 6d776bad7e
commit 0cda7943fa
11 changed files with 189 additions and 32 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from logging import Logger
from invokeai.app.services.workflow_image_records.workflow_image_records_sqlite import SqliteWorkflowImageRecordsStorage
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -91,6 +92,7 @@ class ApiDependencies:
session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_image_records = SqliteWorkflowImageRecordsStorage(db=db)
workflow_records = SqliteWorkflowRecordsStorage(db=db)
services = InvocationServices(
@ -116,6 +118,7 @@ class ApiDependencies:
session_processor=session_processor,
session_queue=session_queue,
urls=urls,
workflow_image_records=workflow_image_records,
workflow_records=workflow_records,
)

View File

@ -1,4 +1,5 @@
import io
import traceback
from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
@ -60,6 +61,7 @@ async def upload_image(
pil_image = pil_image.crop(bbox)
except Exception:
# Error opening the image
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
# attempt to parse metadata from image
@ -97,6 +99,7 @@ async def upload_image(
return image_dto
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail="Failed to create image")

View File

@ -80,7 +80,6 @@ class ImageRecordStorageBase(ABC):
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[MetadataField] = None,
workflow_id: Optional[str] = None,
) -> datetime:
"""Saves an image record."""
pass

View File

@ -100,7 +100,6 @@ IMAGE_DTO_COLS = ", ".join(
"width",
"height",
"session_id",
"workflow_id",
"node_id",
"is_intermediate",
"created_at",
@ -141,11 +140,6 @@ class ImageRecord(BaseModelExcludeNull):
description="The session ID that generated this image, if it is a generated image.",
)
"""The session ID that generated this image, if it is a generated image."""
workflow_id: Optional[str] = Field(
default=None,
description="The workflow that generated this image.",
)
"""The workflow that generated this image."""
node_id: Optional[str] = Field(
default=None,
description="The node ID that generated this image, if it is a generated image.",
@ -190,7 +184,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
width = image_dict.get("width", 0)
height = image_dict.get("height", 0)
session_id = image_dict.get("session_id", None)
workflow_id = image_dict.get("workflow_id", None)
node_id = image_dict.get("node_id", None)
created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp())
@ -205,7 +198,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
width=width,
height=height,
session_id=session_id,
workflow_id=workflow_id,
node_id=node_id,
created_at=created_at,
updated_at=updated_at,

View File

@ -76,16 +76,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
if "workflow_id" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN workflow_id TEXT;
-- TODO: This requires a migration:
-- FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE SET NULL;
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
@ -423,7 +413,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[MetadataField] = None,
workflow_id: Optional[str] = None,
) -> datetime:
try:
metadata_json = metadata.model_dump_json() if metadata is not None else None
@ -439,11 +428,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id,
session_id,
metadata,
workflow_id,
is_intermediate,
starred
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
@ -454,7 +442,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id,
session_id,
metadata_json,
workflow_id,
is_intermediate,
starred,
),

View File

@ -24,6 +24,11 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
default=None, description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
workflow_id: Optional[str] = Field(
default=None,
description="The workflow that generated this image.",
)
"""The workflow that generated this image."""
def image_record_to_dto(
@ -31,6 +36,7 @@ def image_record_to_dto(
image_url: str,
thumbnail_url: str,
board_id: Optional[str],
workflow_id: Optional[str],
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
@ -38,4 +44,5 @@ def image_record_to_dto(
image_url=image_url,
thumbnail_url=thumbnail_url,
board_id=board_id,
workflow_id=workflow_id,
)

View File

@ -74,11 +74,12 @@ class ImageService(ImageServiceABC):
# Nullable fields
node_id=node_id,
metadata=metadata,
workflow_id=workflow_id,
session_id=session_id,
)
if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
if workflow_id is not None:
self.__invoker.services.workflow_image_records.create(workflow_id=workflow_id, image_name=image_name)
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow
)
@ -138,10 +139,11 @@ class ImageService(ImageServiceABC):
image_record = self.__invoker.services.image_records.get(image_name)
image_dto = image_record_to_dto(
image_record,
self.__invoker.services.urls.get_image_url(image_name),
self.__invoker.services.urls.get_image_url(image_name, True),
self.__invoker.services.board_image_records.get_board_for_image(image_name),
image_record=image_record,
image_url=self.__invoker.services.urls.get_image_url(image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name),
)
return image_dto
@ -162,6 +164,19 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image DTO")
raise e
def get_workflow(self, image_name: str) -> Optional[WorkflowField]:
try:
workflow_id = self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name)
if workflow_id is None:
return None
return self.__invoker.services.workflow_records.get(workflow_id)
except ImageRecordNotFoundException:
self.__invoker.services.logger.error("Image record not found")
raise
except Exception as e:
self.__invoker.services.logger.error("Problem getting image DTO")
raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))
@ -205,10 +220,11 @@ class ImageService(ImageServiceABC):
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self.__invoker.services.urls.get_image_url(r.image_name),
self.__invoker.services.urls.get_image_url(r.image_name, True),
self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
),
results.items,
)

View File

@ -28,6 +28,7 @@ if TYPE_CHECKING:
from .shared.graph import GraphExecutionState, LibraryGraph
from .urls.urls_base import UrlServiceBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from .workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
class InvocationServices:
@ -56,6 +57,7 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase"
names: "NameServiceBase"
urls: "UrlServiceBase"
workflow_image_records: "WorkflowImageRecordsStorageBase"
workflow_records: "WorkflowRecordsStorageBase"
def __init__(
@ -82,6 +84,7 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
urls: "UrlServiceBase",
workflow_image_records: "WorkflowImageRecordsStorageBase",
workflow_records: "WorkflowRecordsStorageBase",
):
self.board_images = board_images
@ -106,4 +109,5 @@ class InvocationServices:
self.invocation_cache = invocation_cache
self.names = names
self.urls = urls
self.workflow_image_records = workflow_image_records
self.workflow_records = workflow_records

View File

@ -0,0 +1,23 @@
from abc import ABC, abstractmethod
from typing import Optional
class WorkflowImageRecordsStorageBase(ABC):
"""Abstract base class for the one-to-many workflow-image relationship record storage."""
@abstractmethod
def create(
self,
workflow_id: str,
image_name: str,
) -> None:
"""Creates a workflow-image record."""
pass
@abstractmethod
def get_workflow_for_image(
self,
image_name: str,
) -> Optional[str]:
"""Gets an image's workflow id, if it has one."""
pass

View File

@ -0,0 +1,123 @@
import sqlite3
import threading
from typing import Optional, cast
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
class SqliteWorkflowImageRecordsStorage(WorkflowImageRecordsStorageBase):
"""SQLite implementation of WorkflowImageRecordsStorageBase."""
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
# Create the `workflow_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflow_images (
workflow_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between workflows and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for workflow id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);
"""
)
# Add index for workflow id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
AFTER UPDATE
ON workflow_images FOR EACH ROW
BEGIN
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
END;
"""
)
def create(
self,
workflow_id: str,
image_name: str,
) -> None:
"""Creates a workflow-image record."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO workflow_images (workflow_id, image_name)
VALUES (?, ?)
ON CONFLICT (image_name) DO UPDATE SET workflow_id = ?;
""",
(workflow_id, image_name, workflow_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_workflow_for_image(
self,
image_name: str,
) -> Optional[str]:
"""Gets an image's workflow id, if it has one."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT workflow_id
FROM workflow_images
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()