feat(api): add workflow_images junction table

similar to boards, images and workflows may be associated via junction table
This commit is contained in:
psychedelicious 2023-10-18 18:16:36 +11:00
parent 6d776bad7e
commit 0cda7943fa
11 changed files with 189 additions and 32 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from logging import Logger from logging import Logger
from invokeai.app.services.workflow_image_records.workflow_image_records_sqlite import SqliteWorkflowImageRecordsStorage
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__ from invokeai.version.invokeai_version import __version__
@ -91,6 +92,7 @@ class ApiDependencies:
session_processor = DefaultSessionProcessor() session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db) session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService() urls = LocalUrlService()
workflow_image_records = SqliteWorkflowImageRecordsStorage(db=db)
workflow_records = SqliteWorkflowRecordsStorage(db=db) workflow_records = SqliteWorkflowRecordsStorage(db=db)
services = InvocationServices( services = InvocationServices(
@ -116,6 +118,7 @@ class ApiDependencies:
session_processor=session_processor, session_processor=session_processor,
session_queue=session_queue, session_queue=session_queue,
urls=urls, urls=urls,
workflow_image_records=workflow_image_records,
workflow_records=workflow_records, workflow_records=workflow_records,
) )

View File

@ -1,4 +1,5 @@
import io import io
import traceback
from typing import Optional from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
@ -60,6 +61,7 @@ async def upload_image(
pil_image = pil_image.crop(bbox) pil_image = pil_image.crop(bbox)
except Exception: except Exception:
# Error opening the image # Error opening the image
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image") raise HTTPException(status_code=415, detail="Failed to read image")
# attempt to parse metadata from image # attempt to parse metadata from image
@ -97,6 +99,7 @@ async def upload_image(
return image_dto return image_dto
except Exception: except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail="Failed to create image") raise HTTPException(status_code=500, detail="Failed to create image")

View File

@ -80,7 +80,6 @@ class ImageRecordStorageBase(ABC):
session_id: Optional[str] = None, session_id: Optional[str] = None,
node_id: Optional[str] = None, node_id: Optional[str] = None,
metadata: Optional[MetadataField] = None, metadata: Optional[MetadataField] = None,
workflow_id: Optional[str] = None,
) -> datetime: ) -> datetime:
"""Saves an image record.""" """Saves an image record."""
pass pass

View File

@ -100,7 +100,6 @@ IMAGE_DTO_COLS = ", ".join(
"width", "width",
"height", "height",
"session_id", "session_id",
"workflow_id",
"node_id", "node_id",
"is_intermediate", "is_intermediate",
"created_at", "created_at",
@ -141,11 +140,6 @@ class ImageRecord(BaseModelExcludeNull):
description="The session ID that generated this image, if it is a generated image.", description="The session ID that generated this image, if it is a generated image.",
) )
"""The session ID that generated this image, if it is a generated image.""" """The session ID that generated this image, if it is a generated image."""
workflow_id: Optional[str] = Field(
default=None,
description="The workflow that generated this image.",
)
"""The workflow that generated this image."""
node_id: Optional[str] = Field( node_id: Optional[str] = Field(
default=None, default=None,
description="The node ID that generated this image, if it is a generated image.", description="The node ID that generated this image, if it is a generated image.",
@ -190,7 +184,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
width = image_dict.get("width", 0) width = image_dict.get("width", 0)
height = image_dict.get("height", 0) height = image_dict.get("height", 0)
session_id = image_dict.get("session_id", None) session_id = image_dict.get("session_id", None)
workflow_id = image_dict.get("workflow_id", None)
node_id = image_dict.get("node_id", None) node_id = image_dict.get("node_id", None)
created_at = image_dict.get("created_at", get_iso_timestamp()) created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp()) updated_at = image_dict.get("updated_at", get_iso_timestamp())
@ -205,7 +198,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
width=width, width=width,
height=height, height=height,
session_id=session_id, session_id=session_id,
workflow_id=workflow_id,
node_id=node_id, node_id=node_id,
created_at=created_at, created_at=created_at,
updated_at=updated_at, updated_at=updated_at,

View File

@ -76,16 +76,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
if "workflow_id" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN workflow_id TEXT;
-- TODO: This requires a migration:
-- FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE SET NULL;
"""
)
# Create the `images` table indices. # Create the `images` table indices.
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -423,7 +413,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id: Optional[str] = None, session_id: Optional[str] = None,
node_id: Optional[str] = None, node_id: Optional[str] = None,
metadata: Optional[MetadataField] = None, metadata: Optional[MetadataField] = None,
workflow_id: Optional[str] = None,
) -> datetime: ) -> datetime:
try: try:
metadata_json = metadata.model_dump_json() if metadata is not None else None metadata_json = metadata.model_dump_json() if metadata is not None else None
@ -439,11 +428,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id, node_id,
session_id, session_id,
metadata, metadata,
workflow_id,
is_intermediate, is_intermediate,
starred starred
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""", """,
( (
image_name, image_name,
@ -454,7 +442,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id, node_id,
session_id, session_id,
metadata_json, metadata_json,
workflow_id,
is_intermediate, is_intermediate,
starred, starred,
), ),

View File

@ -24,6 +24,11 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
default=None, description="The id of the board the image belongs to, if one exists." default=None, description="The id of the board the image belongs to, if one exists."
) )
"""The id of the board the image belongs to, if one exists.""" """The id of the board the image belongs to, if one exists."""
workflow_id: Optional[str] = Field(
default=None,
description="The workflow that generated this image.",
)
"""The workflow that generated this image."""
def image_record_to_dto( def image_record_to_dto(
@ -31,6 +36,7 @@ def image_record_to_dto(
image_url: str, image_url: str,
thumbnail_url: str, thumbnail_url: str,
board_id: Optional[str], board_id: Optional[str],
workflow_id: Optional[str],
) -> ImageDTO: ) -> ImageDTO:
"""Converts an image record to an image DTO.""" """Converts an image record to an image DTO."""
return ImageDTO( return ImageDTO(
@ -38,4 +44,5 @@ def image_record_to_dto(
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
board_id=board_id, board_id=board_id,
workflow_id=workflow_id,
) )

View File

@ -74,11 +74,12 @@ class ImageService(ImageServiceABC):
# Nullable fields # Nullable fields
node_id=node_id, node_id=node_id,
metadata=metadata, metadata=metadata,
workflow_id=workflow_id,
session_id=session_id, session_id=session_id,
) )
if board_id is not None: if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name) self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
if workflow_id is not None:
self.__invoker.services.workflow_image_records.create(workflow_id=workflow_id, image_name=image_name)
self.__invoker.services.image_files.save( self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow image_name=image_name, image=image, metadata=metadata, workflow=workflow
) )
@ -138,10 +139,11 @@ class ImageService(ImageServiceABC):
image_record = self.__invoker.services.image_records.get(image_name) image_record = self.__invoker.services.image_records.get(image_name)
image_dto = image_record_to_dto( image_dto = image_record_to_dto(
image_record, image_record=image_record,
self.__invoker.services.urls.get_image_url(image_name), image_url=self.__invoker.services.urls.get_image_url(image_name),
self.__invoker.services.urls.get_image_url(image_name, True), thumbnail_url=self.__invoker.services.urls.get_image_url(image_name, True),
self.__invoker.services.board_image_records.get_board_for_image(image_name), board_id=self.__invoker.services.board_image_records.get_board_for_image(image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name),
) )
return image_dto return image_dto
@ -162,6 +164,19 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image DTO") self.__invoker.services.logger.error("Problem getting image DTO")
raise e raise e
def get_workflow(self, image_name: str) -> Optional[WorkflowField]:
try:
workflow_id = self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name)
if workflow_id is None:
return None
return self.__invoker.services.workflow_records.get(workflow_id)
except ImageRecordNotFoundException:
self.__invoker.services.logger.error("Image record not found")
raise
except Exception as e:
self.__invoker.services.logger.error("Problem getting image DTO")
raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try: try:
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail)) return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))
@ -205,10 +220,11 @@ class ImageService(ImageServiceABC):
image_dtos = list( image_dtos = list(
map( map(
lambda r: image_record_to_dto( lambda r: image_record_to_dto(
r, image_record=r,
self.__invoker.services.urls.get_image_url(r.image_name), image_url=self.__invoker.services.urls.get_image_url(r.image_name),
self.__invoker.services.urls.get_image_url(r.image_name, True), thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
self.__invoker.services.board_image_records.get_board_for_image(r.image_name), board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
), ),
results.items, results.items,
) )

View File

@ -28,6 +28,7 @@ if TYPE_CHECKING:
from .shared.graph import GraphExecutionState, LibraryGraph from .shared.graph import GraphExecutionState, LibraryGraph
from .urls.urls_base import UrlServiceBase from .urls.urls_base import UrlServiceBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from .workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
class InvocationServices: class InvocationServices:
@ -56,6 +57,7 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase" invocation_cache: "InvocationCacheBase"
names: "NameServiceBase" names: "NameServiceBase"
urls: "UrlServiceBase" urls: "UrlServiceBase"
workflow_image_records: "WorkflowImageRecordsStorageBase"
workflow_records: "WorkflowRecordsStorageBase" workflow_records: "WorkflowRecordsStorageBase"
def __init__( def __init__(
@ -82,6 +84,7 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase", invocation_cache: "InvocationCacheBase",
names: "NameServiceBase", names: "NameServiceBase",
urls: "UrlServiceBase", urls: "UrlServiceBase",
workflow_image_records: "WorkflowImageRecordsStorageBase",
workflow_records: "WorkflowRecordsStorageBase", workflow_records: "WorkflowRecordsStorageBase",
): ):
self.board_images = board_images self.board_images = board_images
@ -106,4 +109,5 @@ class InvocationServices:
self.invocation_cache = invocation_cache self.invocation_cache = invocation_cache
self.names = names self.names = names
self.urls = urls self.urls = urls
self.workflow_image_records = workflow_image_records
self.workflow_records = workflow_records self.workflow_records = workflow_records

View File

@ -0,0 +1,23 @@
from abc import ABC, abstractmethod
from typing import Optional
class WorkflowImageRecordsStorageBase(ABC):
"""Abstract base class for the one-to-many workflow-image relationship record storage."""
@abstractmethod
def create(
self,
workflow_id: str,
image_name: str,
) -> None:
"""Creates a workflow-image record."""
pass
@abstractmethod
def get_workflow_for_image(
self,
image_name: str,
) -> Optional[str]:
"""Gets an image's workflow id, if it has one."""
pass

View File

@ -0,0 +1,123 @@
import sqlite3
import threading
from typing import Optional, cast
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
class SqliteWorkflowImageRecordsStorage(WorkflowImageRecordsStorageBase):
"""SQLite implementation of WorkflowImageRecordsStorageBase."""
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
# Create the `workflow_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflow_images (
workflow_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between workflows and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for workflow id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);
"""
)
# Add index for workflow id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
AFTER UPDATE
ON workflow_images FOR EACH ROW
BEGIN
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
END;
"""
)
def create(
self,
workflow_id: str,
image_name: str,
) -> None:
"""Creates a workflow-image record."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO workflow_images (workflow_id, image_name)
VALUES (?, ?)
ON CONFLICT (image_name) DO UPDATE SET workflow_id = ?;
""",
(workflow_id, image_name, workflow_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_workflow_for_image(
self,
image_name: str,
) -> Optional[str]:
"""Gets an image's workflow id, if it has one."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT workflow_id
FROM workflow_images
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()