feat(backend): update workflows handling

Update workflows handling for Workflow Library.

**Updated Workflow Storage**

"Embedded Workflows" are workflows associated with images, and are now only stored in the image files. "Library Workflows" are not associated with images, and are stored only in DB.

This works out nicely. We have always saved workflows to files, but recently began saving them to the DB in addition to in image files. When that happened, we stopped reading workflows from files, so all the workflows that only existed in images were inaccessible. With this change, access to those workflows is restored, and no workflows are lost.

**Updated Workflow Handling in Nodes**

Prior to this change, workflows were embedded in images by passing the whole workflow JSON to a special workflow field on a node. In the node's `invoke()` function, the node was able to access this workflow and save it with the image. This (inaccurately) models workflows as a property of an image and is rather awkward technically.

A workflow is now a property of a batch/session queue item. It is available in the InvocationContext and therefore available to all nodes during `invoke()`.

**Database Migrations**

Added a `SQLiteMigrator` class to handle database migrations. Migrations were needed to accomodate the DB-related changes in this PR. See the code for details.

The `images`, `workflows` and `session_queue` tables required migrations for this PR, and are using the new migrator. Other tables/services are still creating tables themselves. A followup PR will adapt them to use the migrator.

**Other/Support Changes**

- Add a `has_workflow` column to `images` table to indicate that the image has an embedded workflow.
- Add handling for retrieving the workflow from an image in python. The image file must be fetched, the workflow extracted, and then sent to client, avoiding needing the browser to parse the image file. With the `has_workflow` column, the UI knows if there is a workflow to be fetched, and only fetches when the user requests to load the workflow.
- Add route to get the workflow from an image
- Add CRUD service/routes for the library workflows
- `workflow_images` table and services removed (no longer needed now that embedded workflows are not in the DB)
This commit is contained in:
psychedelicious 2023-11-29 00:16:39 +11:00
parent 8cf2806489
commit a514c9e28b
55 changed files with 1209 additions and 626 deletions

View File

@ -2,7 +2,6 @@
from logging import Logger from logging import Logger
from invokeai.app.services.workflow_image_records.workflow_image_records_sqlite import SqliteWorkflowImageRecordsStorage
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__ from invokeai.version.invokeai_version import __version__
@ -30,7 +29,7 @@ from ..services.session_processor.session_processor_default import DefaultSessio
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.default_graphs import create_system_graphs from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.sqlite import SqliteDatabase from ..services.shared.sqlite.sqlite_database import SqliteDatabase
from ..services.urls.urls_default import LocalUrlService from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService from .events import FastAPIEventService
@ -94,7 +93,6 @@ class ApiDependencies:
session_processor = DefaultSessionProcessor() session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db) session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService() urls = LocalUrlService()
workflow_image_records = SqliteWorkflowImageRecordsStorage(db=db)
workflow_records = SqliteWorkflowRecordsStorage(db=db) workflow_records = SqliteWorkflowRecordsStorage(db=db)
services = InvocationServices( services = InvocationServices(
@ -121,15 +119,14 @@ class ApiDependencies:
session_processor=session_processor, session_processor=session_processor,
session_queue=session_queue, session_queue=session_queue,
urls=urls, urls=urls,
workflow_image_records=workflow_image_records,
workflow_records=workflow_records, workflow_records=workflow_records,
) )
create_system_graphs(services.graph_library) create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services) db.run_migrations()
db.clean() db.clean()
ApiDependencies.invoker = Invoker(services)
@staticmethod @staticmethod
def shutdown(): def shutdown():

View File

@ -8,10 +8,11 @@ from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator, WorkflowFieldValidator from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -73,7 +74,7 @@ async def upload_image(
workflow_raw = pil_image.info.get("invokeai_workflow", None) workflow_raw = pil_image.info.get("invokeai_workflow", None)
if workflow_raw is not None: if workflow_raw is not None:
try: try:
workflow = WorkflowFieldValidator.validate_json(workflow_raw) workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw)
except ValidationError: except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image") ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass pass
@ -184,6 +185,18 @@ async def get_image_metadata(
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@images_router.get(
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
)
async def get_image_workflow(
image_name: str = Path(description="The name of image whose workflow to get"),
) -> Optional[WorkflowWithoutID]:
try:
return ApiDependencies.invoker.services.images.get_workflow(image_name)
except Exception:
raise HTTPException(status_code=404)
@images_router.api_route( @images_router.api_route(
"/i/{image_name}/full", "/i/{image_name}/full",
methods=["GET", "HEAD"], methods=["GET", "HEAD"],

View File

@ -1,7 +1,12 @@
from fastapi import APIRouter, Path from fastapi import APIRouter, Body, HTTPException, Path, Query
from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.baseinvocation import WorkflowField from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import (
Workflow,
WorkflowNotFoundError,
WorkflowRecordDTO,
)
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"]) workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
@ -10,11 +15,68 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
"/i/{workflow_id}", "/i/{workflow_id}",
operation_id="get_workflow", operation_id="get_workflow",
responses={ responses={
200: {"model": WorkflowField}, 200: {"model": WorkflowRecordDTO},
}, },
) )
async def get_workflow( async def get_workflow(
workflow_id: str = Path(description="The workflow to get"), workflow_id: str = Path(description="The workflow to get"),
) -> WorkflowField: ) -> WorkflowRecordDTO:
"""Gets a workflow""" """Gets a workflow"""
return ApiDependencies.invoker.services.workflow_records.get(workflow_id) try:
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
@workflows_router.patch(
"/i/{workflow_id}",
operation_id="update_workflow",
responses={
200: {"model": Workflow},
},
)
async def update_workflow(
workflow: Workflow = Body(description="The updated workflow", embed=True),
) -> WorkflowRecordDTO:
"""Updates a workflow"""
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
@workflows_router.delete(
"/i/{workflow_id}",
operation_id="delete_workflow",
)
async def delete_workflow(
workflow_id: str = Path(description="The workflow to delete"),
) -> None:
"""Deletes a workflow"""
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
@workflows_router.post(
"/",
operation_id="create_workflow",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def create_workflow(
workflow: Workflow = Body(description="The workflow to create", embed=True),
) -> WorkflowRecordDTO:
"""Creates a workflow"""
return ApiDependencies.invoker.services.workflow_records.create(workflow)
@workflows_router.get(
"/",
operation_id="list_workflows",
responses={
200: {"model": PaginatedResults[WorkflowRecordDTO]},
},
)
async def list_workflows(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of workflows per page"),
) -> PaginatedResults[WorkflowRecordDTO]:
"""Deletes a workflow"""
return ApiDependencies.invoker.services.workflow_records.get_many(page=page, per_page=per_page)

View File

@ -16,6 +16,7 @@ from pydantic.fields import FieldInfo, _Unset
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
@ -452,6 +453,7 @@ class InvocationContext:
queue_id: str queue_id: str
queue_item_id: int queue_item_id: int
queue_batch_id: str queue_batch_id: str
workflow: Optional[WorkflowWithoutID]
def __init__( def __init__(
self, self,
@ -460,12 +462,14 @@ class InvocationContext:
queue_item_id: int, queue_item_id: int,
queue_batch_id: str, queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
workflow: Optional[WorkflowWithoutID],
): ):
self.services = services self.services = services
self.graph_execution_state_id = graph_execution_state_id self.graph_execution_state_id = graph_execution_state_id
self.queue_id = queue_id self.queue_id = queue_id
self.queue_item_id = queue_item_id self.queue_item_id = queue_item_id
self.queue_batch_id = queue_batch_id self.queue_batch_id = queue_batch_id
self.workflow = workflow
class BaseInvocationOutput(BaseModel): class BaseInvocationOutput(BaseModel):
@ -903,24 +907,6 @@ def invocation_output(
return wrapper return wrapper
class WorkflowField(RootModel):
"""
Pydantic model for workflows with custom root of type dict[str, Any].
Workflows are stored without a strict schema.
"""
root: dict[str, Any] = Field(description="The workflow")
WorkflowFieldValidator = TypeAdapter(WorkflowField)
class WithWorkflow(BaseModel):
workflow: Optional[WorkflowField] = Field(
default=None, description=FieldDescriptions.workflow, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)
class MetadataField(RootModel): class MetadataField(RootModel):
""" """
Pydantic model for metadata with custom root of type dict[str, Any]. Pydantic model for metadata with custom root of type dict[str, Any].

View File

@ -39,7 +39,6 @@ from .baseinvocation import (
InvocationContext, InvocationContext,
OutputField, OutputField,
WithMetadata, WithMetadata,
WithWorkflow,
invocation, invocation,
invocation_output, invocation_output,
) )
@ -129,7 +128,7 @@ class ControlNetInvocation(BaseInvocation):
# This invocation exists for other invocations to subclass it - do not register with @invocation! # This invocation exists for other invocations to subclass it - do not register with @invocation!
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow): class ImageProcessorInvocation(BaseInvocation, WithMetadata):
"""Base class for invocations that preprocess images for ControlNet""" """Base class for invocations that preprocess images for ControlNet"""
image: ImageField = InputField(description="The image to process") image: ImageField = InputField(description="The image to process")
@ -153,7 +152,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
"""Builds an ImageOutput and its ImageField""" """Builds an ImageOutput and its ImageField"""
@ -173,7 +172,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
title="Canny Processor", title="Canny Processor",
tags=["controlnet", "canny"], tags=["controlnet", "canny"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class CannyImageProcessorInvocation(ImageProcessorInvocation): class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet""" """Canny edge detection for ControlNet"""
@ -196,7 +195,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
title="HED (softedge) Processor", title="HED (softedge) Processor",
tags=["controlnet", "hed", "softedge"], tags=["controlnet", "hed", "softedge"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class HedImageProcessorInvocation(ImageProcessorInvocation): class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image""" """Applies HED edge detection to image"""
@ -225,7 +224,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Processor", title="Lineart Processor",
tags=["controlnet", "lineart"], tags=["controlnet", "lineart"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class LineartImageProcessorInvocation(ImageProcessorInvocation): class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image""" """Applies line art processing to image"""
@ -247,7 +246,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Anime Processor", title="Lineart Anime Processor",
tags=["controlnet", "lineart", "anime"], tags=["controlnet", "lineart", "anime"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image""" """Applies line art anime processing to image"""
@ -270,7 +269,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
title="Openpose Processor", title="Openpose Processor",
tags=["controlnet", "openpose", "pose"], tags=["controlnet", "openpose", "pose"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class OpenposeImageProcessorInvocation(ImageProcessorInvocation): class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image""" """Applies Openpose processing to image"""
@ -295,7 +294,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
title="Midas Depth Processor", title="Midas Depth Processor",
tags=["controlnet", "midas"], tags=["controlnet", "midas"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image""" """Applies Midas depth processing to image"""
@ -322,7 +321,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Normal BAE Processor", title="Normal BAE Processor",
tags=["controlnet"], tags=["controlnet"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image""" """Applies NormalBae processing to image"""
@ -339,7 +338,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
@invocation( @invocation(
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.1.0" "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0"
) )
class MlsdImageProcessorInvocation(ImageProcessorInvocation): class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image""" """Applies MLSD processing to image"""
@ -362,7 +361,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
@invocation( @invocation(
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.1.0" "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0"
) )
class PidiImageProcessorInvocation(ImageProcessorInvocation): class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image""" """Applies PIDI processing to image"""
@ -389,7 +388,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
title="Content Shuffle Processor", title="Content Shuffle Processor",
tags=["controlnet", "contentshuffle"], tags=["controlnet", "contentshuffle"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image""" """Applies content shuffle processing to image"""
@ -419,7 +418,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
title="Zoe (Depth) Processor", title="Zoe (Depth) Processor",
tags=["controlnet", "zoe", "depth"], tags=["controlnet", "zoe", "depth"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image""" """Applies Zoe depth processing to image"""
@ -435,7 +434,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Mediapipe Face Processor", title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"], tags=["controlnet", "mediapipe", "face"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image""" """Applies mediapipe face processing to image"""
@ -458,7 +457,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
title="Leres (Depth) Processor", title="Leres (Depth) Processor",
tags=["controlnet", "leres", "depth"], tags=["controlnet", "leres", "depth"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class LeresImageProcessorInvocation(ImageProcessorInvocation): class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image""" """Applies leres processing to image"""
@ -487,7 +486,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
title="Tile Resample Processor", title="Tile Resample Processor",
tags=["controlnet", "tile"], tags=["controlnet", "tile"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class TileResamplerProcessorInvocation(ImageProcessorInvocation): class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor""" """Tile resampler processor"""
@ -527,7 +526,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
title="Segment Anything Processor", title="Segment Anything Processor",
tags=["controlnet", "segmentanything"], tags=["controlnet", "segmentanything"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image""" """Applies segment anything processing to image"""
@ -569,7 +568,7 @@ class SamDetectorReproducibleColors(SamDetector):
title="Color Map Processor", title="Color Map Processor",
tags=["controlnet"], tags=["controlnet"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.2.0",
) )
class ColorMapImageProcessorInvocation(ImageProcessorInvocation): class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a color map from the provided image""" """Generates a color map from the provided image"""

View File

@ -8,11 +8,11 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.1.0") @invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0")
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow): class CvInpaintInvocation(BaseInvocation, WithMetadata):
"""Simple inpaint using opencv.""" """Simple inpaint using opencv."""
image: ImageField = InputField(description="The image to inpaint") image: ImageField = InputField(description="The image to inpaint")
@ -41,7 +41,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(

View File

@ -17,7 +17,6 @@ from invokeai.app.invocations.baseinvocation import (
InvocationContext, InvocationContext,
OutputField, OutputField,
WithMetadata, WithMetadata,
WithWorkflow,
invocation, invocation,
invocation_output, invocation_output,
) )
@ -438,8 +437,8 @@ def get_faces_list(
return all_faces return all_faces
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.1.0") @invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0")
class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata): class FaceOffInvocation(BaseInvocation, WithMetadata):
"""Bound, extract, and mask a face from an image using MediaPipe detection""" """Bound, extract, and mask a face from an image using MediaPipe detection"""
image: ImageField = InputField(description="Image for face detection") image: ImageField = InputField(description="Image for face detection")
@ -508,7 +507,7 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow, workflow=context.workflow,
) )
mask_dto = context.services.images.create( mask_dto = context.services.images.create(
@ -532,8 +531,8 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
return output return output
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.1.0") @invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0")
class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata): class FaceMaskInvocation(BaseInvocation, WithMetadata):
"""Face mask creation using mediapipe face detection""" """Face mask creation using mediapipe face detection"""
image: ImageField = InputField(description="Image to face detect") image: ImageField = InputField(description="Image to face detect")
@ -627,7 +626,7 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow, workflow=context.workflow,
) )
mask_dto = context.services.images.create( mask_dto = context.services.images.create(
@ -650,9 +649,9 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation( @invocation(
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.1.0" "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0"
) )
class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata): class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools.""" """Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
image: ImageField = InputField(description="Image to face detect") image: ImageField = InputField(description="Image to face detect")
@ -716,7 +715,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(

View File

@ -13,7 +13,7 @@ from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, invocation
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0") @invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
@ -36,8 +36,14 @@ class ShowImageInvocation(BaseInvocation):
) )
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.1.0") @invocation(
class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow): "blank_image",
title="Blank Image",
tags=["image"],
category="image",
version="1.2.0",
)
class BlankImageInvocation(BaseInvocation, WithMetadata):
"""Creates a blank image and forwards it to the pipeline""" """Creates a blank image and forwards it to the pipeline"""
width: int = InputField(default=512, description="The width of the image") width: int = InputField(default=512, description="The width of the image")
@ -56,7 +62,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -66,8 +72,14 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
) )
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.1.0") @invocation(
class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata): "img_crop",
title="Crop Image",
tags=["image", "crop"],
category="image",
version="1.2.0",
)
class ImageCropInvocation(BaseInvocation, WithMetadata):
"""Crops an image to a specified box. The box can be outside of the image.""" """Crops an image to a specified box. The box can be outside of the image."""
image: ImageField = InputField(description="The image to crop") image: ImageField = InputField(description="The image to crop")
@ -90,7 +102,7 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -100,8 +112,14 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.1.0") @invocation(
class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata): "img_paste",
title="Paste Image",
tags=["image", "paste"],
category="image",
version="1.2.0",
)
class ImagePasteInvocation(BaseInvocation, WithMetadata):
"""Pastes an image into another image.""" """Pastes an image into another image."""
base_image: ImageField = InputField(description="The base image") base_image: ImageField = InputField(description="The base image")
@ -144,7 +162,7 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -154,8 +172,14 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.1.0") @invocation(
class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata): "tomask",
title="Mask from Alpha",
tags=["image", "mask"],
category="image",
version="1.2.0",
)
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
"""Extracts the alpha channel of an image as a mask.""" """Extracts the alpha channel of an image as a mask."""
image: ImageField = InputField(description="The image to create the mask from") image: ImageField = InputField(description="The image to create the mask from")
@ -176,7 +200,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -186,8 +210,14 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.1.0") @invocation(
class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata): "img_mul",
title="Multiply Images",
tags=["image", "multiply"],
category="image",
version="1.2.0",
)
class ImageMultiplyInvocation(BaseInvocation, WithMetadata):
"""Multiplies two images together using `PIL.ImageChops.multiply()`.""" """Multiplies two images together using `PIL.ImageChops.multiply()`."""
image1: ImageField = InputField(description="The first image to multiply") image1: ImageField = InputField(description="The first image to multiply")
@ -207,7 +237,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -220,8 +250,14 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"] IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.1.0") @invocation(
class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata): "img_chan",
title="Extract Image Channel",
tags=["image", "channel"],
category="image",
version="1.2.0",
)
class ImageChannelInvocation(BaseInvocation, WithMetadata):
"""Gets a channel from an image.""" """Gets a channel from an image."""
image: ImageField = InputField(description="The image to get the channel from") image: ImageField = InputField(description="The image to get the channel from")
@ -240,7 +276,7 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -253,8 +289,14 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.1.0") @invocation(
class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata): "img_conv",
title="Convert Image Mode",
tags=["image", "convert"],
category="image",
version="1.2.0",
)
class ImageConvertInvocation(BaseInvocation, WithMetadata):
"""Converts an image to a different mode.""" """Converts an image to a different mode."""
image: ImageField = InputField(description="The image to convert") image: ImageField = InputField(description="The image to convert")
@ -273,7 +315,7 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -283,8 +325,14 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.1.0") @invocation(
class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata): "img_blur",
title="Blur Image",
tags=["image", "blur"],
category="image",
version="1.2.0",
)
class ImageBlurInvocation(BaseInvocation, WithMetadata):
"""Blurs an image""" """Blurs an image"""
image: ImageField = InputField(description="The image to blur") image: ImageField = InputField(description="The image to blur")
@ -308,7 +356,7 @@ class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -338,8 +386,14 @@ PIL_RESAMPLING_MAP = {
} }
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.1.0") @invocation(
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow): "img_resize",
title="Resize Image",
tags=["image", "resize"],
category="image",
version="1.2.0",
)
class ImageResizeInvocation(BaseInvocation, WithMetadata):
"""Resizes an image to specific dimensions""" """Resizes an image to specific dimensions"""
image: ImageField = InputField(description="The image to resize") image: ImageField = InputField(description="The image to resize")
@ -365,7 +419,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -375,8 +429,14 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
) )
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.1.0") @invocation(
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow): "img_scale",
title="Scale Image",
tags=["image", "scale"],
category="image",
version="1.2.0",
)
class ImageScaleInvocation(BaseInvocation, WithMetadata):
"""Scales an image by a factor""" """Scales an image by a factor"""
image: ImageField = InputField(description="The image to scale") image: ImageField = InputField(description="The image to scale")
@ -407,7 +467,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -417,8 +477,14 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
) )
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.1.0") @invocation(
class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata): "img_lerp",
title="Lerp Image",
tags=["image", "lerp"],
category="image",
version="1.2.0",
)
class ImageLerpInvocation(BaseInvocation, WithMetadata):
"""Linear interpolation of all pixels of an image""" """Linear interpolation of all pixels of an image"""
image: ImageField = InputField(description="The image to lerp") image: ImageField = InputField(description="The image to lerp")
@ -441,7 +507,7 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -451,8 +517,14 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.1.0") @invocation(
class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata): "img_ilerp",
title="Inverse Lerp Image",
tags=["image", "ilerp"],
category="image",
version="1.2.0",
)
class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
"""Inverse linear interpolation of all pixels of an image""" """Inverse linear interpolation of all pixels of an image"""
image: ImageField = InputField(description="The image to lerp") image: ImageField = InputField(description="The image to lerp")
@ -475,7 +547,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -485,8 +557,14 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.1.0") @invocation(
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow): "img_nsfw",
title="Blur NSFW Image",
tags=["image", "nsfw"],
category="image",
version="1.2.0",
)
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
"""Add blur to NSFW-flagged images""" """Add blur to NSFW-flagged images"""
image: ImageField = InputField(description="The image to check") image: ImageField = InputField(description="The image to check")
@ -511,7 +589,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -532,9 +610,9 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
title="Add Invisible Watermark", title="Add Invisible Watermark",
tags=["image", "watermark"], tags=["image", "watermark"],
category="image", category="image",
version="1.1.0", version="1.2.0",
) )
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow): class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
"""Add an invisible watermark to an image""" """Add an invisible watermark to an image"""
image: ImageField = InputField(description="The image to check") image: ImageField = InputField(description="The image to check")
@ -551,7 +629,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -561,8 +639,14 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
) )
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.1.0") @invocation(
class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata): "mask_edge",
title="Mask Edge",
tags=["image", "mask", "inpaint"],
category="image",
version="1.2.0",
)
class MaskEdgeInvocation(BaseInvocation, WithMetadata):
"""Applies an edge mask to an image""" """Applies an edge mask to an image"""
image: ImageField = InputField(description="The image to apply the mask to") image: ImageField = InputField(description="The image to apply the mask to")
@ -597,7 +681,7 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -612,9 +696,9 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
title="Combine Masks", title="Combine Masks",
tags=["image", "mask", "multiply"], tags=["image", "mask", "multiply"],
category="image", category="image",
version="1.1.0", version="1.2.0",
) )
class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata): class MaskCombineInvocation(BaseInvocation, WithMetadata):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
mask1: ImageField = InputField(description="The first mask to combine") mask1: ImageField = InputField(description="The first mask to combine")
@ -634,7 +718,7 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -644,8 +728,14 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.1.0") @invocation(
class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata): "color_correct",
title="Color Correct",
tags=["image", "color"],
category="image",
version="1.2.0",
)
class ColorCorrectInvocation(BaseInvocation, WithMetadata):
""" """
Shifts the colors of a target image to match the reference image, optionally Shifts the colors of a target image to match the reference image, optionally
using a mask to only color-correct certain regions of the target image. using a mask to only color-correct certain regions of the target image.
@ -745,7 +835,7 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -755,8 +845,14 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.1.0") @invocation(
class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata): "img_hue_adjust",
title="Adjust Image Hue",
tags=["image", "hue"],
category="image",
version="1.2.0",
)
class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata):
"""Adjusts the Hue of an image.""" """Adjusts the Hue of an image."""
image: ImageField = InputField(description="The image to adjust") image: ImageField = InputField(description="The image to adjust")
@ -785,7 +881,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -858,9 +954,9 @@ CHANNEL_FORMATS = {
"value", "value",
], ],
category="image", category="image",
version="1.1.0", version="1.2.0",
) )
class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata): class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
"""Add or subtract a value from a specific color channel of an image.""" """Add or subtract a value from a specific color channel of an image."""
image: ImageField = InputField(description="The image to adjust") image: ImageField = InputField(description="The image to adjust")
@ -895,7 +991,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -929,9 +1025,9 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"value", "value",
], ],
category="image", category="image",
version="1.1.0", version="1.2.0",
) )
class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata): class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
"""Scale a specific color channel of an image.""" """Scale a specific color channel of an image."""
image: ImageField = InputField(description="The image to adjust") image: ImageField = InputField(description="The image to adjust")
@ -970,7 +1066,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
workflow=self.workflow, workflow=context.workflow,
metadata=self.metadata, metadata=self.metadata,
) )
@ -988,10 +1084,10 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
title="Save Image", title="Save Image",
tags=["primitives", "image"], tags=["primitives", "image"],
category="primitives", category="primitives",
version="1.1.0", version="1.2.0",
use_cache=False, use_cache=False,
) )
class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata): class SaveImageInvocation(BaseInvocation, WithMetadata):
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image.""" """Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
image: ImageField = InputField(description=FieldDescriptions.image) image: ImageField = InputField(description=FieldDescriptions.image)
@ -1009,7 +1105,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -1027,7 +1123,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
version="1.0.1", version="1.0.1",
use_cache=False, use_cache=False,
) )
class LinearUIOutputInvocation(BaseInvocation, WithWorkflow, WithMetadata): class LinearUIOutputInvocation(BaseInvocation, WithMetadata):
"""Handles Linear UI Image Outputting tasks.""" """Handles Linear UI Image Outputting tasks."""
image: ImageField = InputField(description=FieldDescriptions.image) image: ImageField = InputField(description=FieldDescriptions.image)

View File

@ -13,7 +13,7 @@ from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
@ -118,8 +118,8 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si return si
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0") @invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata): class InfillColorInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image with a solid color""" """Infills transparent areas of an image with a solid color"""
image: ImageField = InputField(description="The image to infill") image: ImageField = InputField(description="The image to infill")
@ -144,7 +144,7 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -154,8 +154,8 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.1") @invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata): class InfillTileInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image with tiles of the image""" """Infills transparent areas of an image with tiles of the image"""
image: ImageField = InputField(description="The image to infill") image: ImageField = InputField(description="The image to infill")
@ -181,7 +181,7 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -192,9 +192,9 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation( @invocation(
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0" "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0"
) )
class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata): class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image using the PatchMatch algorithm""" """Infills transparent areas of an image using the PatchMatch algorithm"""
image: ImageField = InputField(description="The image to infill") image: ImageField = InputField(description="The image to infill")
@ -235,7 +235,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -245,8 +245,8 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0") @invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata): class LaMaInfillInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image using the LaMa model""" """Infills transparent areas of an image using the LaMa model"""
image: ImageField = InputField(description="The image to infill") image: ImageField = InputField(description="The image to infill")
@ -264,7 +264,7 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(
@ -274,8 +274,8 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0") @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata): class CV2InfillInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image using OpenCV Inpainting""" """Infills transparent areas of an image using OpenCV Inpainting"""
image: ImageField = InputField(description="The image to infill") image: ImageField = InputField(description="The image to infill")
@ -293,7 +293,7 @@ class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(

View File

@ -64,7 +64,6 @@ from .baseinvocation import (
OutputField, OutputField,
UIType, UIType,
WithMetadata, WithMetadata,
WithWorkflow,
invocation, invocation,
invocation_output, invocation_output,
) )
@ -792,9 +791,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
title="Latents to Image", title="Latents to Image",
tags=["latents", "image", "vae", "l2i"], tags=["latents", "image", "vae", "l2i"],
category="latents", category="latents",
version="1.1.0", version="1.2.0",
) )
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow): class LatentsToImageInvocation(BaseInvocation, WithMetadata):
"""Generates an image from latents.""" """Generates an image from latents."""
latents: LatentsField = InputField( latents: LatentsField = InputField(
@ -876,7 +875,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(

View File

@ -31,7 +31,6 @@ from .baseinvocation import (
UIComponent, UIComponent,
UIType, UIType,
WithMetadata, WithMetadata,
WithWorkflow,
invocation, invocation,
invocation_output, invocation_output,
) )
@ -326,9 +325,9 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
title="ONNX Latents to Image", title="ONNX Latents to Image",
tags=["latents", "image", "vae", "onnx"], tags=["latents", "image", "vae", "onnx"],
category="image", category="image",
version="1.1.0", version="1.2.0",
) )
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow): class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
"""Generates an image from latents.""" """Generates an image from latents."""
latents: LatentsField = InputField( latents: LatentsField = InputField(
@ -378,7 +377,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(

View File

@ -14,7 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.devices import choose_torch_device
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
# TODO: Populate this from disk? # TODO: Populate this from disk?
# TODO: Use model manager to load? # TODO: Use model manager to load?
@ -29,8 +29,8 @@ if choose_torch_device() == torch.device("mps"):
from torch import mps from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.2.0") @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0")
class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata): class ESRGANInvocation(BaseInvocation, WithMetadata):
"""Upscales an image using RealESRGAN.""" """Upscales an image using RealESRGAN."""
image: ImageField = InputField(description="The input image") image: ImageField = InputField(description="The input image")
@ -118,7 +118,7 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata, metadata=self.metadata,
workflow=self.workflow, workflow=context.workflow,
) )
return ImageOutput( return ImageOutput(

View File

@ -4,7 +4,7 @@ from typing import Optional, cast
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .board_image_records_base import BoardImageRecordStorageBase from .board_image_records_base import BoardImageRecordStorageBase

View File

@ -3,7 +3,7 @@ import threading
from typing import Union, cast from typing import Union, cast
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
from .board_records_base import BoardRecordStorageBase from .board_records_base import BoardRecordStorageBase

View File

@ -4,7 +4,8 @@ from typing import Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageFileStorageBase(ABC): class ImageFileStorageBase(ABC):
@ -33,7 +34,7 @@ class ImageFileStorageBase(ABC):
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[MetadataField] = None, metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None, workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -43,3 +44,8 @@ class ImageFileStorageBase(ABC):
def delete(self, image_name: str) -> None: def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists).""" """Deletes an image and its thumbnail (if one exists)."""
pass pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets the workflow of an image."""
pass

View File

@ -7,8 +7,9 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from .image_files_base import ImageFileStorageBase from .image_files_base import ImageFileStorageBase
@ -56,7 +57,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[MetadataField] = None, metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None, workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
@ -64,12 +65,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_path = self.get_path(image_name) image_path = self.get_path(image_name)
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
info_dict = {}
if metadata is not None: if metadata is not None:
pnginfo.add_text("invokeai_metadata", metadata.model_dump_json()) metadata_json = metadata.model_dump_json()
info_dict["invokeai_metadata"] = metadata_json
pnginfo.add_text("invokeai_metadata", metadata_json)
if workflow is not None: if workflow is not None:
pnginfo.add_text("invokeai_workflow", workflow.model_dump_json()) workflow_json = workflow.model_dump_json()
info_dict["invokeai_workflow"] = workflow_json
pnginfo.add_text("invokeai_workflow", workflow_json)
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
image.save( image.save(
image_path, image_path,
"PNG", "PNG",
@ -121,6 +129,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
path = path if isinstance(path, Path) else Path(path) path = path if isinstance(path, Path) else Path(path)
return path.exists() return path.exists()
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
image = self.get(image_name)
workflow = image.info.get("invokeai_workflow", None)
if workflow is not None:
return WorkflowWithoutID.model_validate_json(workflow)
return None
def __validate_storage_folders(self) -> None: def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't""" """Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder] folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]

View File

@ -75,6 +75,7 @@ class ImageRecordStorageBase(ABC):
image_category: ImageCategory, image_category: ImageCategory,
width: int, width: int,
height: int, height: int,
has_workflow: bool,
is_intermediate: Optional[bool] = False, is_intermediate: Optional[bool] = False,
starred: Optional[bool] = False, starred: Optional[bool] = False,
session_id: Optional[str] = None, session_id: Optional[str] = None,

View File

@ -100,6 +100,7 @@ IMAGE_DTO_COLS = ", ".join(
"height", "height",
"session_id", "session_id",
"node_id", "node_id",
"has_workflow",
"is_intermediate", "is_intermediate",
"created_at", "created_at",
"updated_at", "updated_at",
@ -145,6 +146,7 @@ class ImageRecord(BaseModelExcludeNull):
"""The node ID that generated this image, if it is a generated image.""" """The node ID that generated this image, if it is a generated image."""
starred: bool = Field(description="Whether this image is starred.") starred: bool = Field(description="Whether this image is starred.")
"""Whether this image is starred.""" """Whether this image is starred."""
has_workflow: bool = Field(description="Whether this image has a workflow.")
class ImageRecordChanges(BaseModelExcludeNull, extra="allow"): class ImageRecordChanges(BaseModelExcludeNull, extra="allow"):
@ -188,6 +190,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False) is_intermediate = image_dict.get("is_intermediate", False)
starred = image_dict.get("starred", False) starred = image_dict.get("starred", False)
has_workflow = image_dict.get("has_workflow", False)
return ImageRecord( return ImageRecord(
image_name=image_name, image_name=image_name,
@ -202,4 +205,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at=deleted_at, deleted_at=deleted_at,
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
starred=starred, starred=starred,
has_workflow=has_workflow,
) )

View File

@ -1,11 +1,13 @@
import sqlite3 import sqlite3
import threading
from datetime import datetime from datetime import datetime
from typing import Optional, Union, cast from typing import Optional, Union, cast
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.migrations import v0, v1, v2
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration, MigrationSet
from .image_records_base import ImageRecordStorageBase from .image_records_base import ImageRecordStorageBase
from .image_records_common import ( from .image_records_common import (
@ -20,102 +22,27 @@ from .image_records_common import (
deserialize_image_record, deserialize_image_record,
) )
images_migrations = MigrationSet(
table_name="images",
migrations=[
Migration(version=0, migrate=v0),
Migration(version=1, migrate=v1),
Migration(version=2, migrate=v2),
],
)
class SqliteImageRecordStorage(ImageRecordStorageBase): class SqliteImageRecordStorage(ImageRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None: def __init__(self, db: SqliteDatabase) -> None:
super().__init__() super().__init__()
self._db = db
self._lock = db.lock self._lock = db.lock
self._conn = db.conn self._conn = db.conn
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._db.register_migration_set(images_migrations)
try: def start(self, invoker: Invoker) -> None:
self._lock.acquire() self._invoker = invoker
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `images` table."""
# Create the `images` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "starred" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
)
def get(self, image_name: str) -> ImageRecord: def get(self, image_name: str) -> ImageRecord:
try: try:
@ -408,6 +335,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_category: ImageCategory, image_category: ImageCategory,
width: int, width: int,
height: int, height: int,
has_workflow: bool,
is_intermediate: Optional[bool] = False, is_intermediate: Optional[bool] = False,
starred: Optional[bool] = False, starred: Optional[bool] = False,
session_id: Optional[str] = None, session_id: Optional[str] = None,
@ -429,9 +357,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id, session_id,
metadata, metadata,
is_intermediate, is_intermediate,
starred starred,
has_workflow
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?); VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""", """,
( (
image_name, image_name,
@ -444,6 +373,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata_json, metadata_json,
is_intermediate, is_intermediate,
starred, starred,
has_workflow,
), ),
) )
self._conn.commit() self._conn.commit()

View File

@ -0,0 +1,5 @@
from .v0 import v0
from .v1 import v1
from .v2 import v2
__all__ = [v0, v1, v2] # type: ignore

View File

@ -0,0 +1,64 @@
import sqlite3
def v0(cursor: sqlite3.Cursor) -> None:
"""
Migration for `images` table v0
https://github.com/invoke-ai/InvokeAI/pull/3443
Adds the `images` table, indicies and triggers for the image_records service.
"""
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
"""
)
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER
UPDATE ON images FOR EACH ROW BEGIN
UPDATE images
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
)

View File

@ -0,0 +1,25 @@
import sqlite3
def v1(cursor: sqlite3.Cursor) -> None:
"""
Migration for `images` table v1
https://github.com/invoke-ai/InvokeAI/pull/4246
Adds the `starred` column to the `images` table.
"""
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "starred" not in columns:
cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN starred BOOLEAN DEFAULT FALSE;
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
"""
)

View File

@ -0,0 +1,24 @@
import sqlite3
def v2(cursor: sqlite3.Cursor) -> None:
"""
Migration for `images` table v2
https://github.com/invoke-ai/InvokeAI/pull/5148
Adds the `has_workflow` column to the `images` table.
Workflows associated with images are now only stored in the image file itself. This column
indicates whether the image has a workflow embedded in it, so we don't need to read the image
file to find out.
"""
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "has_workflow" not in columns:
cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;
"""
)

View File

@ -3,7 +3,7 @@ from typing import Callable, Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.image_records.image_records_common import ( from invokeai.app.services.image_records.image_records_common import (
ImageCategory, ImageCategory,
ImageRecord, ImageRecord,
@ -12,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
) )
from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageServiceABC(ABC): class ImageServiceABC(ABC):
@ -51,7 +52,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None, board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False, is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None, metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None, workflow: Optional[WorkflowWithoutID] = None,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@ -85,6 +86,11 @@ class ImageServiceABC(ABC):
"""Gets an image's metadata.""" """Gets an image's metadata."""
pass pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets an image's workflow."""
pass
@abstractmethod @abstractmethod
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path.""" """Gets an image's path."""

View File

@ -24,11 +24,6 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
default=None, description="The id of the board the image belongs to, if one exists." default=None, description="The id of the board the image belongs to, if one exists."
) )
"""The id of the board the image belongs to, if one exists.""" """The id of the board the image belongs to, if one exists."""
workflow_id: Optional[str] = Field(
default=None,
description="The workflow that generated this image.",
)
"""The workflow that generated this image."""
def image_record_to_dto( def image_record_to_dto(
@ -36,7 +31,6 @@ def image_record_to_dto(
image_url: str, image_url: str,
thumbnail_url: str, thumbnail_url: str,
board_id: Optional[str], board_id: Optional[str],
workflow_id: Optional[str],
) -> ImageDTO: ) -> ImageDTO:
"""Converts an image record to an image DTO.""" """Converts an image record to an image DTO."""
return ImageDTO( return ImageDTO(
@ -44,5 +38,4 @@ def image_record_to_dto(
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
board_id=board_id, board_id=board_id,
workflow_id=workflow_id,
) )

View File

@ -2,9 +2,10 @@ from typing import Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from ..image_files.image_files_common import ( from ..image_files.image_files_common import (
ImageFileDeleteException, ImageFileDeleteException,
@ -42,7 +43,7 @@ class ImageService(ImageServiceABC):
board_id: Optional[str] = None, board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False, is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None, metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None, workflow: Optional[WorkflowWithoutID] = None,
) -> ImageDTO: ) -> ImageDTO:
if image_origin not in ResourceOrigin: if image_origin not in ResourceOrigin:
raise InvalidOriginException raise InvalidOriginException
@ -55,12 +56,6 @@ class ImageService(ImageServiceABC):
(width, height) = image.size (width, height) = image.size
try: try:
if workflow is not None:
created_workflow = self.__invoker.services.workflow_records.create(workflow)
workflow_id = created_workflow.model_dump()["id"]
else:
workflow_id = None
# TODO: Consider using a transaction here to ensure consistency between storage and database # TODO: Consider using a transaction here to ensure consistency between storage and database
self.__invoker.services.image_records.save( self.__invoker.services.image_records.save(
# Non-nullable fields # Non-nullable fields
@ -69,6 +64,7 @@ class ImageService(ImageServiceABC):
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
has_workflow=workflow is not None,
# Meta fields # Meta fields
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
# Nullable fields # Nullable fields
@ -78,8 +74,6 @@ class ImageService(ImageServiceABC):
) )
if board_id is not None: if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name) self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
if workflow_id is not None:
self.__invoker.services.workflow_image_records.create(workflow_id=workflow_id, image_name=image_name)
self.__invoker.services.image_files.save( self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow image_name=image_name, image=image, metadata=metadata, workflow=workflow
) )
@ -143,7 +137,6 @@ class ImageService(ImageServiceABC):
image_url=self.__invoker.services.urls.get_image_url(image_name), image_url=self.__invoker.services.urls.get_image_url(image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(image_name, True), thumbnail_url=self.__invoker.services.urls.get_image_url(image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(image_name), board_id=self.__invoker.services.board_image_records.get_board_for_image(image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name),
) )
return image_dto return image_dto
@ -164,18 +157,15 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image DTO") self.__invoker.services.logger.error("Problem getting image DTO")
raise e raise e
def get_workflow(self, image_name: str) -> Optional[WorkflowField]: def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
try: try:
workflow_id = self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name) return self.__invoker.services.image_files.get_workflow(image_name)
if workflow_id is None: except ImageFileNotFoundException:
return None self.__invoker.services.logger.error("Image file not found")
return self.__invoker.services.workflow_records.get(workflow_id) raise
except ImageRecordNotFoundException: except Exception:
self.__invoker.services.logger.error("Image record not found") self.__invoker.services.logger.error("Problem getting image workflow")
raise raise
except Exception as e:
self.__invoker.services.logger.error("Problem getting image DTO")
raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try: try:
@ -223,7 +213,6 @@ class ImageService(ImageServiceABC):
image_url=self.__invoker.services.urls.get_image_url(r.image_name), image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True), thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name), board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
) )
for r in results.items for r in results.items
] ]

View File

@ -108,6 +108,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item_id=queue_item.session_queue_item_id, queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id, queue_id=queue_item.session_queue_id,
queue_batch_id=queue_item.session_queue_batch_id, queue_batch_id=queue_item.session_queue_batch_id,
workflow=queue_item.workflow,
) )
) )
@ -178,6 +179,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
session_queue_item_id=queue_item.session_queue_item_id, session_queue_item_id=queue_item.session_queue_item_id,
session_queue_id=queue_item.session_queue_id, session_queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state, graph_execution_state=graph_execution_state,
workflow=queue_item.workflow,
invoke_all=True, invoke_all=True,
) )
except Exception as e: except Exception as e:

View File

@ -1,9 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time import time
from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class InvocationQueueItem(BaseModel): class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state") graph_execution_state_id: str = Field(description="The ID of the graph execution state")
@ -15,5 +18,6 @@ class InvocationQueueItem(BaseModel):
session_queue_batch_id: str = Field( session_queue_batch_id: str = Field(
description="The ID of the session batch from which this invocation queue item came" description="The ID of the session batch from which this invocation queue item came"
) )
workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item")
invoke_all: bool = Field(default=False) invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time) timestamp: float = Field(default_factory=time.time)

View File

@ -28,7 +28,6 @@ if TYPE_CHECKING:
from .session_queue.session_queue_base import SessionQueueBase from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState, LibraryGraph from .shared.graph import GraphExecutionState, LibraryGraph
from .urls.urls_base import UrlServiceBase from .urls.urls_base import UrlServiceBase
from .workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
@ -59,7 +58,6 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase" invocation_cache: "InvocationCacheBase"
names: "NameServiceBase" names: "NameServiceBase"
urls: "UrlServiceBase" urls: "UrlServiceBase"
workflow_image_records: "WorkflowImageRecordsStorageBase"
workflow_records: "WorkflowRecordsStorageBase" workflow_records: "WorkflowRecordsStorageBase"
def __init__( def __init__(
@ -87,7 +85,6 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase", invocation_cache: "InvocationCacheBase",
names: "NameServiceBase", names: "NameServiceBase",
urls: "UrlServiceBase", urls: "UrlServiceBase",
workflow_image_records: "WorkflowImageRecordsStorageBase",
workflow_records: "WorkflowRecordsStorageBase", workflow_records: "WorkflowRecordsStorageBase",
): ):
self.board_images = board_images self.board_images = board_images
@ -113,5 +110,4 @@ class InvocationServices:
self.invocation_cache = invocation_cache self.invocation_cache = invocation_cache
self.names = names self.names = names
self.urls = urls self.urls = urls
self.workflow_image_records = workflow_image_records
self.workflow_records = workflow_records self.workflow_records = workflow_records

View File

@ -2,6 +2,8 @@
from typing import Optional from typing import Optional
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from .invocation_queue.invocation_queue_common import InvocationQueueItem from .invocation_queue.invocation_queue_common import InvocationQueueItem
from .invocation_services import InvocationServices from .invocation_services import InvocationServices
from .shared.graph import Graph, GraphExecutionState from .shared.graph import Graph, GraphExecutionState
@ -22,6 +24,7 @@ class Invoker:
session_queue_item_id: int, session_queue_item_id: int,
session_queue_batch_id: str, session_queue_batch_id: str,
graph_execution_state: GraphExecutionState, graph_execution_state: GraphExecutionState,
workflow: Optional[WorkflowWithoutID] = None,
invoke_all: bool = False, invoke_all: bool = False,
) -> Optional[str]: ) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed. """Determines the next node to invoke and enqueues it, preparing if needed.
@ -43,6 +46,7 @@ class Invoker:
session_queue_batch_id=session_queue_batch_id, session_queue_batch_id=session_queue_batch_id,
graph_execution_state_id=graph_execution_state.id, graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id, invocation_id=invocation.id,
workflow=workflow,
invoke_all=invoke_all, invoke_all=invoke_all,
) )
) )

View File

@ -5,7 +5,7 @@ from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, TypeAdapter from pydantic import BaseModel, TypeAdapter
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .item_storage_base import ItemStorageABC from .item_storage_base import ItemStorageABC

View File

@ -52,7 +52,7 @@ from invokeai.backend.model_manager.config import (
ModelType, ModelType,
) )
from ..shared.sqlite import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import ( from .model_records_base import (
CONFIG_FILE_VERSION, CONFIG_FILE_VERSION,
DuplicateModelException, DuplicateModelException,

View File

@ -114,6 +114,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
session_queue_id=queue_item.queue_id, session_queue_id=queue_item.queue_id,
session_queue_item_id=queue_item.item_id, session_queue_item_id=queue_item.item_id,
graph_execution_state=queue_item.session, graph_execution_state=queue_item.session,
workflow=queue_item.workflow,
invoke_all=True, invoke_all=True,
) )
queue_item = None queue_item = None

View File

@ -0,0 +1,4 @@
from .v0 import v0
from .v1 import v1
__all__ = [v0, v1] # type: ignore

View File

@ -0,0 +1,106 @@
import sqlite3
def v0(cursor: sqlite3.Cursor) -> None:
"""
Migration for `session_queue` table v0
https://github.com/invoke-ai/InvokeAI/pull/4502
Creates the `session_queue` table, indicies and triggers for the session_queue service.
"""
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
"""
)
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
"""
)

View File

@ -0,0 +1,22 @@
import sqlite3
def v1(cursor: sqlite3.Cursor) -> None:
"""
Migration for `session_queue` table v1
https://github.com/invoke-ai/InvokeAI/pull/5148
Adds the `workflow` column to the `session_queue` table.
Workflows have been (correctly) made a property of a queue item, rather than individual nodes.
This requires they be included in the session queue.
"""
cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in cursor.fetchall()]
if "workflow" not in columns:
cursor.execute(
"""--sql
ALTER TABLE session_queue ADD COLUMN workflow TEXT;
"""
)

View File

@ -8,6 +8,10 @@ from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowWithoutID,
WorkflowWithoutIDValidator,
)
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
# region Errors # region Errors
@ -66,6 +70,9 @@ class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch") batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.") data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with") graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow to initialize the session with"
)
runs: int = Field( runs: int = Field(
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices" default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
) )
@ -164,6 +171,12 @@ def get_session(queue_item_dict: dict) -> GraphExecutionState:
return session return session
def get_workflow(queue_item_dict: dict) -> WorkflowWithoutID:
workflow_raw = queue_item_dict.get("workflow", "{}")
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw, strict=False)
return workflow
class SessionQueueItemWithoutGraph(BaseModel): class SessionQueueItemWithoutGraph(BaseModel):
"""Session queue item without the full graph. Used for serialization.""" """Session queue item without the full graph. Used for serialization."""
@ -213,12 +226,16 @@ class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
class SessionQueueItem(SessionQueueItemWithoutGraph): class SessionQueueItem(SessionQueueItemWithoutGraph):
session: GraphExecutionState = Field(description="The fully-populated session to be executed") session: GraphExecutionState = Field(description="The fully-populated session to be executed")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow associated with this queue item"
)
@classmethod @classmethod
def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem": def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
# must parse these manually # must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict) queue_item_dict["field_values"] = get_field_values(queue_item_dict)
queue_item_dict["session"] = get_session(queue_item_dict) queue_item_dict["session"] = get_session(queue_item_dict)
queue_item_dict["workflow"] = get_workflow(queue_item_dict)
return SessionQueueItem(**queue_item_dict) return SessionQueueItem(**queue_item_dict)
model_config = ConfigDict( model_config = ConfigDict(
@ -334,7 +351,7 @@ def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) ->
def create_session_nfv_tuples( def create_session_nfv_tuples(
batch: Batch, maximum: int batch: Batch, maximum: int
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue]], None, None]: ) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
""" """
Create all graph permutations from the given batch data and graph. Yields tuples Create all graph permutations from the given batch data and graph. Yields tuples
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
@ -365,7 +382,7 @@ def create_session_nfv_tuples(
return return
flat_node_field_values = list(chain.from_iterable(d)) flat_node_field_values = list(chain.from_iterable(d))
graph = populate_graph(batch.graph, flat_node_field_values) graph = populate_graph(batch.graph, flat_node_field_values)
yield (GraphExecutionState(graph=graph), flat_node_field_values) yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
count += 1 count += 1
@ -391,12 +408,14 @@ def calc_session_count(batch: Batch) -> int:
class SessionQueueValueToInsert(NamedTuple): class SessionQueueValueToInsert(NamedTuple):
"""A tuple of values to insert into the session_queue table""" """A tuple of values to insert into the session_queue table"""
# Careful with the ordering of this - it must match the insert statement
queue_id: str # queue_id queue_id: str # queue_id
session: str # session json session: str # session json
session_id: str # session_id session_id: str # session_id
batch_id: str # batch_id batch_id: str # batch_id
field_values: Optional[str] # field_values json field_values: Optional[str] # field_values json
priority: int # priority priority: int # priority
workflow: Optional[str] # workflow json
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert] ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@ -404,7 +423,7 @@ ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert: def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
values_to_insert: ValuesToInsert = [] values_to_insert: ValuesToInsert = []
for session, field_values in create_session_nfv_tuples(batch, max_new_queue_items): for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
# sessions must have unique id # sessions must have unique id
session.id = uuid_string() session.id = uuid_string()
values_to_insert.append( values_to_insert.append(
@ -416,6 +435,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
# must use pydantic_encoder bc field_values is a list of models # must use pydantic_encoder bc field_values is a list of models
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json) json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
priority, # priority priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
) )
) )
return values_to_insert return values_to_insert

View File

@ -1,5 +1,4 @@
import sqlite3 import sqlite3
import threading
from typing import Optional, Union, cast from typing import Optional, Union, cast
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
@ -7,6 +6,7 @@ from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.migrations import v0, v1
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
DEFAULT_QUEUE_ID, DEFAULT_QUEUE_ID,
@ -28,14 +28,26 @@ from invokeai.app.services.session_queue.session_queue_common import (
prepare_values_to_insert, prepare_values_to_insert,
) )
from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration, MigrationSet
session_queue_migrations = MigrationSet(
table_name="session_queue",
migrations=[
Migration(version=0, migrate=v0),
Migration(version=1, migrate=v1),
],
)
class SqliteSessionQueue(SessionQueueBase): class SqliteSessionQueue(SessionQueueBase):
__invoker: Invoker def __init__(self, db: SqliteDatabase) -> None:
__conn: sqlite3.Connection super().__init__()
__cursor: sqlite3.Cursor self.__db = db
__lock: threading.RLock self.__lock = db.lock
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
self.__db.register_migration_set(session_queue_migrations)
def start(self, invoker: Invoker) -> None: def start(self, invoker: Invoker) -> None:
self.__invoker = invoker self.__invoker = invoker
@ -44,13 +56,6 @@ class SqliteSessionQueue(SessionQueueBase):
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event) local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self.__lock = db.lock
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
self._create_tables()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in return event[1]["event"] in match_in
@ -97,114 +102,6 @@ class SqliteSessionQueue(SessionQueueBase):
except SessionQueueItemNotFoundError: except SessionQueueItemNotFoundError:
return return
def _create_tables(self) -> None:
"""Creates the session queue tables, indicies, and triggers"""
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
"""
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _set_in_progress_to_canceled(self) -> None: def _set_in_progress_to_canceled(self) -> None:
""" """
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue. Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
@ -280,8 +177,8 @@ class SqliteSessionQueue(SessionQueueBase):
self.__cursor.executemany( self.__cursor.executemany(
"""--sql """--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority) INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
""", """,
values_to_insert, values_to_insert,
) )

View File

@ -0,0 +1,61 @@
from typing import Annotated, Any, Callable
from pydantic import GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
from semver import Version
class _VersionPydanticAnnotation:
"""
Pydantic annotation for semver.Version.
Requires a field_serializer to serialize to a string.
Usage:
class MyModel(BaseModel):
version: SemVer = Field(..., description="The version of the model.")
@field_serializer("version")
def serialize_version(self, version: SemVer, _info):
return str(version)
MyModel(version=semver.Version.parse("1.2.3"))
MyModel.model_validate({"version":"1.2.3"})
"""
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: Callable[[Any], core_schema.CoreSchema],
) -> core_schema.CoreSchema:
def validate_from_str(value: str) -> Version:
return Version.parse(value)
from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(validate_from_str),
]
)
return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(Version),
from_str_schema,
]
),
serialization=core_schema.to_string_ser_schema(),
)
@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.str_schema())
SemVer = Annotated[Version, _VersionPydanticAnnotation]

View File

@ -0,0 +1 @@
sqlite_memory = ":memory:"

View File

@ -3,27 +3,22 @@ import threading
from logging import Logger from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
sqlite_memory = ":memory:" from invokeai.app.services.shared.sqlite.sqlite_migrator import MigrationSet, SQLiteMigrator
class SqliteDatabase: class SqliteDatabase:
conn: sqlite3.Connection
lock: threading.RLock
_logger: Logger
_config: InvokeAIAppConfig
def __init__(self, config: InvokeAIAppConfig, logger: Logger): def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger self._logger = logger
self._config = config self._config = config
if self._config.use_memory_db: if self._config.use_memory_db:
location = sqlite_memory location = sqlite_memory
logger.info("Using in-memory database") self._logger.info("Using in-memory database")
else: else:
db_path = self._config.db_path db_path = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True) db_path.parent.mkdir(parents=True, exist_ok=True)
location = str(db_path) location = db_path
self._logger.info(f"Using database at {location}") self._logger.info(f"Using database at {location}")
self.conn = sqlite3.connect(location, check_same_thread=False) self.conn = sqlite3.connect(location, check_same_thread=False)
@ -34,6 +29,7 @@ class SqliteDatabase:
self.conn.set_trace_callback(self._logger.debug) self.conn.set_trace_callback(self._logger.debug)
self.conn.execute("PRAGMA foreign_keys = ON;") self.conn.execute("PRAGMA foreign_keys = ON;")
self._migrator = SQLiteMigrator(db_path=location, lock=self.lock, logger=self._logger)
def clean(self) -> None: def clean(self) -> None:
try: try:
@ -41,8 +37,14 @@ class SqliteDatabase:
self.conn.execute("VACUUM;") self.conn.execute("VACUUM;")
self.conn.commit() self.conn.commit()
self._logger.info("Cleaned database") self._logger.info("Cleaned database")
except Exception as e: except sqlite3.Error as e:
self._logger.error(f"Error cleaning database: {e}") self._logger.error(f"Error cleaning database: {e}")
raise e raise
finally: finally:
self.lock.release() self.lock.release()
def register_migration_set(self, migration_set: MigrationSet) -> None:
self._migrator.register_migration_set(migration_set)
def run_migrations(self) -> None:
self._migrator.run_migrations()

View File

@ -0,0 +1,192 @@
import datetime
import shutil
import sqlite3
import threading
from logging import Logger
from pathlib import Path
from typing import Callable, Optional, TypeAlias
from .sqlite_common import sqlite_memory
MigrateCallback: TypeAlias = Callable[[sqlite3.Cursor], None]
class MigrationError(Exception):
pass
class Migration:
def __init__(
self,
version: int,
migrate: MigrateCallback,
) -> None:
self.version = version
self.migrate = migrate
class MigrationSet:
def __init__(self, table_name: str, migrations: list[Migration]) -> None:
self.table_name = table_name
self.migrations = migrations
class SQLiteMigrator:
"""
Handles SQLite database migrations.
Migrations are registered with the `register_migration_set` method. They are applied on
application startup with the `run_migrations` method.
A `MigrationSet` is a set of `Migration`s for a single table. Each `Migration` has a `version`
and `migrate` callback. The callback is provided with a `sqlite3.Cursor` and should perform the
any migration logic. Committing, rolling back transactions and errors are handled by the migrator.
Migrations are applied in order of version number. If the database does not have a version table
for a given table, it is assumed to be at version 0. The migrator creates and manages the version
tables.
If the database is a file, it will be backed up before migrations are applied and restored if
there are any errors.
"""
def __init__(self, db_path: Path | str, lock: threading.RLock, logger: Logger):
self._logger = logger
self._conn = sqlite3.connect(db_path, check_same_thread=False)
self._cursor = self._conn.cursor()
self._lock = lock
self._db_path = db_path
self._migration_sets: set[MigrationSet] = set()
def _get_version_table_name(self, table_name: str) -> str:
"""Returns the name of the version table for a given table."""
return f"{table_name}_version"
def _create_version_table(self, table_name: str) -> None:
"""
Creates a version table for a given table, if it does not exist.
Throws MigrationError if there is a problem.
"""
version_table_name = self._get_version_table_name(table_name)
with self._lock:
try:
self._cursor.execute(
f"""--sql
CREATE TABLE IF NOT EXISTS {version_table_name} (
version INTEGER PRIMARY KEY,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
self._conn.commit()
except sqlite3.Error as e:
msg = f'Problem creation "{version_table_name}" table: {e}'
self._logger.error(msg)
self._conn.rollback()
raise MigrationError(msg) from e
def _get_current_version(self, table_name: str) -> Optional[int]:
"""Gets the current version of a table, or None if it doesn't exist."""
version_table_name = self._get_version_table_name(table_name)
try:
self._cursor.execute(f"SELECT MAX(version) FROM {version_table_name};")
return self._cursor.fetchone()[0]
except sqlite3.OperationalError as e:
if "no such table" in str(e):
return None
raise
def _set_version(self, table_name: str, version: int) -> None:
"""Adds a version entry to the table's version table."""
version_table_name = self._get_version_table_name(table_name)
self._cursor.execute(f"INSERT INTO {version_table_name} (version) VALUES (?);", (version,))
def _run_migration(self, table_name: str, migration: Migration) -> None:
"""Runs a single migration."""
with self._lock:
try:
migration.migrate(self._cursor)
self._set_version(table_name=table_name, version=migration.version)
self._conn.commit()
except sqlite3.Error:
self._conn.rollback()
raise
def _run_migration_set(self, migration_set: MigrationSet) -> None:
"""Runs a set of migrations for a single table."""
with self._lock:
table_name = migration_set.table_name
migrations = migration_set.migrations
self._create_version_table(table_name=table_name)
for migration in migrations:
current_version = self._get_current_version(table_name)
if current_version is None or current_version < migration.version:
try:
self._logger.info(f'runing "{table_name}" migration {migration.version}')
self._run_migration(table_name=table_name, migration=migration)
except sqlite3.Error as e:
raise MigrationError(f'Problem runing "{table_name}" migration {migration.version}: {e}') from e
def _backup_db(self, db_path: Path) -> Path:
"""Backs up the databse, returning the path to the backup file."""
with self._lock:
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = db_path.parent / f"{db_path.stem}_{timestamp}.db"
self._logger.info(f"Backing up database to {backup_path}")
backup_conn = sqlite3.connect(backup_path)
with backup_conn:
self._conn.backup(backup_conn)
backup_conn.close()
return backup_path
def _restore_db(self, backup_path: Path) -> None:
"""Restores the database from a backup file, unless the database is a memory database."""
if self._db_path == sqlite_memory:
return
with self._lock:
self._logger.info(f"Restoring database from {backup_path}")
self._conn.close()
if not Path(backup_path).is_file():
raise FileNotFoundError(f"Backup file {backup_path} does not exist")
shutil.copy2(backup_path, self._db_path)
def _get_is_migration_needed(self, migration_set: MigrationSet) -> bool:
table_name = migration_set.table_name
migrations = migration_set.migrations
current_version = self._get_current_version(table_name)
if current_version is None or current_version < migrations[-1].version:
return True
return False
def run_migrations(self) -> None:
"""
Applies all registered migration sets.
If the database is a file, it will be backed up before migrations are applied and restored
if there are any errors.
"""
if not any(self._get_is_migration_needed(migration_set) for migration_set in self._migration_sets):
return
backup_path: Optional[Path] = None
with self._lock:
# Only make a backup if using a file database (not memory)
if isinstance(self._db_path, Path):
backup_path = self._backup_db(self._db_path)
for migration_set in self._migration_sets:
if self._get_is_migration_needed(migration_set):
try:
self._run_migration_set(migration_set)
except Exception as e:
msg = f'Problem runing "{migration_set.table_name}" migrations: {e}'
self._logger.error(msg)
if backup_path is not None:
self._logger.error(f" Restoring from {backup_path}")
self._restore_db(backup_path)
raise MigrationError(msg) from e
# TODO: delete backup file?
# if backup_path is not None:
# Path(backup_path).unlink()
def register_migration_set(self, migration_set: MigrationSet) -> None:
"""Registers a migration set to be migrated on application startup."""
self._migration_sets.add(migration_set)

View File

@ -1,23 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
class WorkflowImageRecordsStorageBase(ABC):
"""Abstract base class for the one-to-many workflow-image relationship record storage."""
@abstractmethod
def create(
self,
workflow_id: str,
image_name: str,
) -> None:
"""Creates a workflow-image record."""
pass
@abstractmethod
def get_workflow_for_image(
self,
image_name: str,
) -> Optional[str]:
"""Gets an image's workflow id, if it has one."""
pass

View File

@ -1,122 +0,0 @@
import sqlite3
import threading
from typing import Optional, cast
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
class SqliteWorkflowImageRecordsStorage(WorkflowImageRecordsStorageBase):
"""SQLite implementation of WorkflowImageRecordsStorageBase."""
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
# Create the `workflow_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflow_images (
workflow_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between workflows and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for workflow id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);
"""
)
# Add index for workflow id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
AFTER UPDATE
ON workflow_images FOR EACH ROW
BEGIN
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
END;
"""
)
def create(
self,
workflow_id: str,
image_name: str,
) -> None:
"""Creates a workflow-image record."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO workflow_images (workflow_id, image_name)
VALUES (?, ?);
""",
(workflow_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_workflow_for_image(
self,
image_name: str,
) -> Optional[str]:
"""Gets an image's workflow id, if it has one."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT workflow_id
FROM workflow_images
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@ -0,0 +1,4 @@
from invokeai.app.services.workflow_records.migrations.v0 import v0
from invokeai.app.services.workflow_records.migrations.v1 import v1
__all__ = [v0, v1] # type: ignore

View File

@ -0,0 +1,35 @@
import sqlite3
def v0(cursor: sqlite3.Cursor) -> None:
"""
Migration for `workflows` table v0
https://github.com/invoke-ai/InvokeAI/pull/4686
Creates the `workflows` table for the workflow_records service & a trigger for updated_at.
Note: `workflow_id` gets an implicit index. We don't need to make one for this column.
"""
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflows (
workflow TEXT NOT NULL,
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
);
"""
)
cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
AFTER UPDATE
ON workflows FOR EACH ROW
BEGIN
UPDATE workflows
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
)

View File

@ -0,0 +1,33 @@
import sqlite3
def v1(cursor: sqlite3.Cursor) -> None:
"""
Migration for `workflows` table v1
https://github.com/invoke-ai/InvokeAI/pull/5148
Drops the `workflow_images` table and empties the `workflows` table.
Prior to v3.5.0, all workflows were associated with images. They were stored in the image files
themselves, and in v3.4.0 we started storing them in the DB. This turned out to be a bad idea -
you end up with *many* image workflows, most of which are duplicates.
The purpose of workflows DB storage was to provide a workflow library. Library workflows are
different from image workflows. They are only saved when the user requests they be saved.
Moving forward, the storage for image workflows and library workflows will be separate. Image
workflows are store only in the image files. Library workflows are stored only in the DB.
To give ourselves a clean slate, we need to delete all existing workflows in the DB (all of which)
are image workflows. We also need to delete the workflow_images table, which is no longer needed.
"""
cursor.execute(
"""--sql
DROP TABLE IF EXISTS workflow_images;
"""
)
cursor.execute(
"""--sql
DELETE FROM workflows;
"""
)

View File

@ -1,17 +1,36 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from invokeai.app.invocations.baseinvocation import WorkflowField from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import (
Workflow,
WorkflowRecordDTO,
)
class WorkflowRecordsStorageBase(ABC): class WorkflowRecordsStorageBase(ABC):
"""Base class for workflow storage services.""" """Base class for workflow storage services."""
@abstractmethod @abstractmethod
def get(self, workflow_id: str) -> WorkflowField: def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Get workflow by id.""" """Get workflow by id."""
pass pass
@abstractmethod @abstractmethod
def create(self, workflow: WorkflowField) -> WorkflowField: def create(self, workflow: Workflow) -> WorkflowRecordDTO:
"""Creates a workflow.""" """Creates a workflow."""
pass pass
@abstractmethod
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
"""Updates a workflow."""
pass
@abstractmethod
def delete(self, workflow_id: str) -> None:
"""Deletes a workflow."""
pass
@abstractmethod
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordDTO]:
"""Gets many workflows."""
pass

View File

@ -1,2 +1,73 @@
import datetime
from typing import Any, Union
import semver
from pydantic import BaseModel, Field, JsonValue, TypeAdapter, field_validator
from invokeai.app.util.misc import uuid_string
__workflow_meta_version__ = semver.Version.parse("1.0.0")
class ExposedField(BaseModel):
nodeId: str
fieldName: str
class WorkflowMeta(BaseModel):
version: str = Field(description="The version of the workflow schema.")
@field_validator("version")
def validate_version(cls, version: str):
try:
semver.Version.parse(version)
return version
except Exception:
raise ValueError(f"Invalid workflow meta version: {version}")
def to_semver(self) -> semver.Version:
return semver.Version.parse(self.version)
class WorkflowWithoutID(BaseModel):
name: str = Field(description="The name of the workflow.")
author: str = Field(description="The author of the workflow.")
description: str = Field(description="The description of the workflow.")
version: str = Field(description="The version of the workflow.")
contact: str = Field(description="The contact of the workflow.")
tags: str = Field(description="The tags of the workflow.")
notes: str = Field(description="The notes of the workflow.")
exposedFields: list[ExposedField] = Field(description="The exposed fields of the workflow.")
meta: WorkflowMeta = Field(description="The meta of the workflow.")
# TODO: nodes and edges are very loosely typed
nodes: list[dict[str, JsonValue]] = Field(description="The nodes of the workflow.")
edges: list[dict[str, JsonValue]] = Field(description="The edges of the workflow.")
WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID)
class Workflow(WorkflowWithoutID):
id: str = Field(default_factory=uuid_string, description="The id of the workflow.")
WorkflowValidator = TypeAdapter(Workflow)
class WorkflowRecordDTO(BaseModel):
workflow_id: str = Field(description="The id of the workflow.")
workflow: Workflow = Field(description="The workflow.")
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the workflow.")
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the workflow.")
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRecordDTO":
data["workflow"] = WorkflowValidator.validate_json(data.get("workflow", ""))
return WorkflowRecordDTOValidator.validate_python(data)
WorkflowRecordDTOValidator = TypeAdapter(WorkflowRecordDTO)
class WorkflowNotFoundError(Exception): class WorkflowNotFoundError(Exception):
"""Raised when a workflow is not found""" """Raised when a workflow is not found"""

View File

@ -1,36 +1,42 @@
import sqlite3
import threading
from invokeai.app.invocations.baseinvocation import WorkflowField, WorkflowFieldValidator
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_migrator import Migration, MigrationSet
from invokeai.app.services.workflow_records.migrations import v0, v1
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowNotFoundError from invokeai.app.services.workflow_records.workflow_records_common import (
from invokeai.app.util.misc import uuid_string Workflow,
WorkflowNotFoundError,
WorkflowRecordDTO,
)
workflows_migrations = MigrationSet(
table_name="workflows",
migrations=[
Migration(version=0, migrate=v0),
Migration(version=1, migrate=v1),
],
)
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
_invoker: Invoker
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None: def __init__(self, db: SqliteDatabase) -> None:
super().__init__() super().__init__()
self._db = db
self._lock = db.lock self._lock = db.lock
self._conn = db.conn self._conn = db.conn
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._create_tables() self._db.register_migration_set(workflows_migrations)
def start(self, invoker: Invoker) -> None: def start(self, invoker: Invoker) -> None:
self._invoker = invoker self._invoker = invoker
def get(self, workflow_id: str) -> WorkflowField: def get(self, workflow_id: str) -> WorkflowRecordDTO:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT workflow SELECT workflow_id, workflow, created_at, updated_at
FROM workflows FROM workflows
WHERE workflow_id = ?; WHERE workflow_id = ?;
""", """,
@ -39,18 +45,15 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
row = self._cursor.fetchone() row = self._cursor.fetchone()
if row is None: if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found") raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowFieldValidator.validate_json(row[0]) return WorkflowRecordDTO.from_dict(dict(row))
except Exception: except Exception:
self._conn.rollback() self._conn.rollback()
raise raise
finally: finally:
self._lock.release() self._lock.release()
def create(self, workflow: WorkflowField) -> WorkflowField: def create(self, workflow: Workflow) -> WorkflowRecordDTO:
try: try:
# workflows do not have ids until they are saved
workflow_id = uuid_string()
workflow.root["id"] = workflow_id
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -65,38 +68,77 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
raise raise
finally: finally:
self._lock.release() self._lock.release()
return self.get(workflow_id) return self.get(workflow.id)
def _create_tables(self) -> None: def update(self, workflow: Workflow) -> WorkflowRecordDTO:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
CREATE TABLE IF NOT EXISTS workflows ( UPDATE workflows
workflow TEXT NOT NULL, SET workflow = ?
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index WHERE workflow_id = ?;
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), """,
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger (workflow.model_dump_json(), workflow.id),
);
"""
) )
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
AFTER UPDATE
ON workflows FOR EACH ROW
BEGIN
UPDATE workflows
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
)
self._conn.commit() self._conn.commit()
except Exception: except Exception:
self._conn.rollback() self._conn.rollback()
raise raise
finally: finally:
self._lock.release() self._lock.release()
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE from workflows
WHERE workflow_id = ?;
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(self, page: int, per_page: int) -> PaginatedResults[WorkflowRecordDTO]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT workflow_id, workflow, created_at, updated_at
FROM workflows
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(per_page, page * per_page),
)
rows = self._cursor.fetchall()
workflows = [WorkflowRecordDTO.from_dict(dict(row)) for row in rows]
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM workflows;
"""
)
total = self._cursor.fetchone()[0]
pages = int(total / per_page) + 1
return PaginatedResults(
items=workflows,
page=page,
per_page=per_page,
pages=pages,
total=total,
)
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()

View File

@ -11,7 +11,7 @@ from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
ModelRecordServiceSQL, ModelRecordServiceSQL,
) )
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,

View File

@ -13,7 +13,7 @@ from invokeai.app.services.model_records import (
ModelRecordServiceSQL, ModelRecordServiceSQL,
UnknownModelException, UnknownModelException,
) )
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
MainCheckpointConfig, MainCheckpointConfig,

View File

@ -24,7 +24,7 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@pytest.fixture @pytest.fixture

View File

@ -3,7 +3,7 @@ from pydantic import BaseModel, Field
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger