fix(nodes): fix uploading image metadata retention

was causing failure to save images
This commit is contained in:
psychedelicious 2023-08-30 14:52:50 +10:00
parent 94d0c18cbd
commit ae05d34584
6 changed files with 87 additions and 22 deletions

View File

@ -372,8 +372,9 @@ class UIConfigBase(BaseModel):
decorators, though you may add this class to a node definition to specify the title and tags.
"""
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
title: Optional[str] = Field(default=None, description="The display name of the node")
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
title: Optional[str] = Field(default=None, description="The node's display name")
category: Optional[str] = Field(default=None, description="The node's category")
class InvocationContext:
@ -469,6 +470,8 @@ class BaseInvocation(ABC, BaseModel):
schema["title"] = uiconfig.title
if uiconfig and hasattr(uiconfig, "tags"):
schema["tags"] = uiconfig.tags
if uiconfig and hasattr(uiconfig, "category"):
schema["category"] = uiconfig.category
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = list()
schema["required"].extend(["type", "id"])
@ -558,3 +561,39 @@ def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
return cls
return wrapper
def category(category: str) -> Callable[[Type[T]], Type[T]]:
"""Adds a category to the invocation. This is used to group invocations in the UI."""
def wrapper(cls: Type[T]) -> Type[T]:
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig.category = category
return cls
return wrapper
def node(
title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
) -> Callable[[Type[T]], Type[T]]:
"""
Adds metadata to the invocation as a decorator.
:param Optional[str] title: Adds a title to the node. Use if the auto-generated title isn't quite right. Defaults to None.
:param Optional[list[str]] tags: Adds tags to the node. Nodes may be searched for by their tags. Defaults to None.
:param Optional[str] category: Adds a category to the node. Used to group the nodes in the UI. Defaults to None.
"""
def wrapper(cls: Type[T]) -> Type[T]:
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
return cls
return wrapper

View File

@ -8,11 +8,10 @@ from pydantic import validator
from invokeai.app.invocations.primitives import IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .baseinvocation import BaseInvocation, InputField, InvocationContext, node
@title("Integer Range")
@tags("collection", "integer", "range")
@node(title="Integer Range", tags=["collection", "integer", "range"], category="collections")
class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""
@ -33,8 +32,7 @@ class RangeInvocation(BaseInvocation):
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
@title("Integer Range of Size")
@tags("range", "integer", "size", "collection")
@node(title="Integer Range of Size", tags=["collection", "integer", "size", "range"], category="collections")
class RangeOfSizeInvocation(BaseInvocation):
"""Creates a range from start to start + size with step"""
@ -49,8 +47,7 @@ class RangeOfSizeInvocation(BaseInvocation):
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
@title("Random Range")
@tags("range", "integer", "random", "collection")
@node(title="Random Range", tags=["range", "integer", "random", "collection"], category="collections")
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""

View File

@ -26,6 +26,7 @@ from .baseinvocation import (
InvocationContext,
OutputField,
UIComponent,
category,
tags,
title,
)
@ -44,8 +45,9 @@ class ConditioningFieldData:
# PerpNeg = "perp_neg"
@title("Compel Prompt")
@title("Prompt")
@tags("prompt", "compel")
@category("conditioning")
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
@ -265,8 +267,9 @@ class SDXLPromptInvocationBase:
return c, c_pooled, ec
@title("SDXL Compel Prompt")
@title("SDXL Prompt")
@tags("sdxl", "compel", "prompt")
@category("conditioning")
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@ -324,8 +327,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
)
@title("SDXL Refiner Compel Prompt")
@title("SDXL Refiner Prompt")
@tags("sdxl", "compel", "prompt")
@category("conditioning")
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@ -381,6 +385,7 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
@title("CLIP Skip")
@tags("clipskip", "clip", "skip")
@category("conditioning")
class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""

View File

@ -40,6 +40,8 @@ from .baseinvocation import (
InvocationContext,
OutputField,
UIType,
category,
node,
tags,
title,
)
@ -96,8 +98,7 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)
@title("ControlNet")
@tags("controlnet")
@node(title="ControlNet", tags=["controlnet"], category="controlnet")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
@ -177,6 +178,7 @@ class ImageProcessorInvocation(BaseInvocation):
@title("Canny Processor")
@tags("controlnet", "canny")
@category("controlnet")
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
@ -198,6 +200,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
@title("HED (softedge) Processor")
@tags("controlnet", "hed", "softedge")
@category("controlnet")
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
@ -225,6 +228,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
@title("Lineart Processor")
@tags("controlnet", "lineart")
@category("controlnet")
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
@ -245,6 +249,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
@title("Lineart Anime Processor")
@tags("controlnet", "lineart", "anime")
@category("controlnet")
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
@ -266,6 +271,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
@title("Openpose Processor")
@tags("controlnet", "openpose", "pose")
@category("controlnet")
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image"""
@ -289,6 +295,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
@title("Midas (Depth) Processor")
@tags("controlnet", "midas", "depth")
@category("controlnet")
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
@ -314,6 +321,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
@title("Normal BAE Processor")
@tags("controlnet", "normal", "bae")
@category("controlnet")
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
@ -333,6 +341,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
@title("MLSD Processor")
@tags("controlnet", "mlsd")
@category("controlnet")
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
@ -358,6 +367,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
@title("PIDI Processor")
@tags("controlnet", "pidi")
@category("controlnet")
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
@ -383,6 +393,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
@title("Content Shuffle Processor")
@tags("controlnet", "contentshuffle")
@category("controlnet")
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
@ -411,6 +422,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
@title("Zoe (Depth) Processor")
@tags("controlnet", "zoe", "depth")
@category("controlnet")
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
@ -424,6 +436,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
@title("Mediapipe Face Processor")
@tags("controlnet", "mediapipe", "face")
@category("controlnet")
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
@ -445,6 +458,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
@title("Leres (Depth) Processor")
@tags("controlnet", "leres", "depth")
@category("controlnet")
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
@ -472,6 +486,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
@title("Tile Resample Processor")
@tags("controlnet", "tile")
@category("controlnet")
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor"""
@ -510,6 +525,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
@title("Segment Anything Processor")
@tags("controlnet", "segmentanything")
@category("controlnet")
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""

View File

@ -8,11 +8,12 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .baseinvocation import BaseInvocation, InputField, InvocationContext, category, tags, title
@title("OpenCV Inpaint")
@tags("opencv", "inpaint")
@category("inpaint")
class CvInpaintInvocation(BaseInvocation):
"""Simple inpaint using opencv."""

View File

@ -119,13 +119,20 @@ class DiskImageFileStorage(ImageFileStorageBase):
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
if workflow is not None:
pnginfo.add_text("invokeai_workflow", workflow)
# For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
for item_name, item in image.info.items():
pnginfo.add_text(item_name, item)
if metadata is not None or workflow is not None:
if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
if workflow is not None:
pnginfo.add_text("invokeai_workflow", workflow)
else:
# For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
# TODO: retain non-invokeai metadata on save...
original_metadata = image.info.get("invokeai_metadata", None)
if original_metadata is not None:
pnginfo.add_text("invokeai_metadata", original_metadata)
original_workflow = image.info.get("invokeai_workflow", None)
if original_workflow is not None:
pnginfo.add_text("invokeai_workflow", original_workflow)
image.save(image_path, "PNG", pnginfo=pnginfo)