feat: workflow saving and loading

This commit is contained in:
psychedelicious
2023-08-24 21:42:32 +10:00
parent 7f6fdf5d39
commit 7d1942e9f0
51 changed files with 1175 additions and 320 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
import json
from typing import (
TYPE_CHECKING,
AbstractSet,
@ -20,7 +21,7 @@ from typing import (
get_type_hints,
)
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
from pydantic.fields import Undefined
from pydantic.typing import NoArgAnyCallable
@ -141,9 +142,11 @@ class UIType(str, Enum):
# endregion
# region Misc
FilePath = "FilePath"
Enum = "enum"
Scheduler = "Scheduler"
WorkflowField = "WorkflowField"
IsIntermediate = "IsIntermediate"
MetadataField = "MetadataField"
# endregion
@ -507,8 +510,24 @@ class BaseInvocation(ABC, BaseModel):
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = InputField(
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
default=False, description="Whether or not this node is an intermediate node.", ui_type=UIType.IsIntermediate
)
workflow: Optional[str] = InputField(
default=None,
description="The workflow to save with the image",
ui_type=UIType.WorkflowField,
)
@validator("workflow", pre=True)
def validate_workflow_is_json(cls, v):
if v is None:
return None
try:
json.loads(v)
except json.decoder.JSONDecodeError:
raise ValueError("Workflow must be valid JSON")
return v
UIConfig: ClassVar[Type[UIConfigBase]]

View File

@ -151,11 +151,6 @@ class ImageProcessorInvocation(BaseInvocation):
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
# FIXME: what happened to image metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# currently can't see processed image in node UI without a showImage node,
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
image_dto = context.services.images.create(
@ -165,6 +160,7 @@ class ImageProcessorInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
"""Builds an ImageOutput and its ImageField"""

View File

@ -45,6 +45,7 @@ class CvInpaintInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -65,6 +65,7 @@ class BlankImageInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -102,6 +103,7 @@ class ImageCropInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -154,6 +156,7 @@ class ImagePasteInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -189,6 +192,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -223,6 +227,7 @@ class ImageMultiplyInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -259,6 +264,7 @@ class ImageChannelInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -295,6 +301,7 @@ class ImageConvertInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -333,6 +340,7 @@ class ImageBlurInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -393,6 +401,7 @@ class ImageResizeInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -438,6 +447,7 @@ class ImageScaleInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -475,6 +485,7 @@ class ImageLerpInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -512,6 +523,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -555,6 +567,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
@ -596,6 +609,7 @@ class ImageWatermarkInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
@ -644,6 +658,7 @@ class MaskEdgeInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -677,6 +692,7 @@ class MaskCombineInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -785,6 +801,7 @@ class ColorCorrectInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -827,6 +844,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
)
return ImageOutput(
@ -877,6 +895,7 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
)
return ImageOutput(
@ -925,6 +944,7 @@ class ImageSaturationAdjustmentInvocation(BaseInvocation):
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -145,6 +145,7 @@ class InfillColorInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -184,6 +185,7 @@ class InfillTileInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -218,6 +220,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -545,6 +545,7 @@ class LatentsToImageInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -376,6 +376,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -7,7 +7,7 @@ from pydantic import validator
from invokeai.app.invocations.primitives import StringCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, tags, title
@title("Dynamic Prompt")
@ -41,7 +41,7 @@ class PromptsFromFileInvocation(BaseInvocation):
type: Literal["prompt_from_file"] = "prompt_from_file"
# Inputs
file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath)
file_path: str = InputField(description="Path to prompt text file")
pre_prompt: Optional[str] = InputField(
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
)

View File

@ -110,6 +110,7 @@ class ESRGANInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -60,7 +60,7 @@ class ImageFileStorageBase(ABC):
image: PILImageType,
image_name: str,
metadata: Optional[dict] = None,
graph: Optional[dict] = None,
workflow: Optional[str] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -110,7 +110,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image: PILImageType,
image_name: str,
metadata: Optional[dict] = None,
graph: Optional[dict] = None,
workflow: Optional[str] = None,
thumbnail_size: int = 256,
) -> None:
try:
@ -121,8 +121,8 @@ class DiskImageFileStorage(ImageFileStorageBase):
if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
if graph is not None:
pnginfo.add_text("invokeai_graph", json.dumps(graph))
if workflow is not None:
pnginfo.add_text("invokeai_workflow", workflow)
image.save(image_path, "PNG", pnginfo=pnginfo)
thumbnail_name = get_thumbnail_name(image_name)

View File

@ -54,6 +54,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
is_intermediate: bool = False,
metadata: Optional[dict] = None,
workflow: Optional[str] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -177,6 +178,7 @@ class ImageService(ImageServiceABC):
board_id: Optional[str] = None,
is_intermediate: bool = False,
metadata: Optional[dict] = None,
workflow: Optional[str] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
@ -186,16 +188,16 @@ class ImageService(ImageServiceABC):
image_name = self._services.names.create_image_name()
graph = None
if session_id is not None:
session_raw = self._services.graph_execution_manager.get_raw(session_id)
if session_raw is not None:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
# TODO: Do we want to store the graph in the image at all? I don't think so...
# graph = None
# if session_id is not None:
# session_raw = self._services.graph_execution_manager.get_raw(session_id)
# if session_raw is not None:
# try:
# graph = get_metadata_graph_from_raw_session(session_raw)
# except Exception as e:
# self._services.logger.warn(f"Failed to parse session graph: {e}")
# graph = None
(width, height) = image.size
@ -217,7 +219,7 @@ class ImageService(ImageServiceABC):
)
if board_id is not None:
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
image_dto = self.get_dto(image_name)
return image_dto