fix(nodes): fix uploading image metadata retention

The previous metadata retention behaviour was causing image saves to fail.
Author: psychedelicious
Date:   2023-08-30 14:52:50 +10:00
Commit: ae05d34584
Parent: 94d0c18cbd
6 changed files with 87 additions and 22 deletions
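For context, the old save path copied every entry of `image.info` into the outgoing `PngInfo` (see the last file below). A minimal sketch of the presumed failure mode, assuming a PNG whose `info` dict carries non-string values; the `dpi` entry here is set manually for illustration:

from PIL import Image, PngImagePlugin

image = Image.new("RGB", (8, 8))
image.info["dpi"] = (72, 72)  # PIL commonly stores non-string values in image.info

pnginfo = PngImagePlugin.PngInfo()
for item_name, item in image.info.items():  # old behaviour: copy every info entry verbatim
    pnginfo.add_text(item_name, item)  # raises for non-string values such as this tuple

image.save("out.png", "PNG", pnginfo=pnginfo)  # never reached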


@@ -372,8 +372,9 @@ class UIConfigBase(BaseModel):
     decorators, though you may add this class to a node definition to specify the title and tags.
     """

-    tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
-    title: Optional[str] = Field(default=None, description="The display name of the node")
+    tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
+    title: Optional[str] = Field(default=None, description="The node's display name")
+    category: Optional[str] = Field(default=None, description="The node's category")


 class InvocationContext:
@@ -469,6 +470,8 @@ class BaseInvocation(ABC, BaseModel):
                 schema["title"] = uiconfig.title
             if uiconfig and hasattr(uiconfig, "tags"):
                 schema["tags"] = uiconfig.tags
+            if uiconfig and hasattr(uiconfig, "category"):
+                schema["category"] = uiconfig.category
             if "required" not in schema or not isinstance(schema["required"], list):
                 schema["required"] = list()
             schema["required"].extend(["type", "id"])
@@ -558,3 +561,39 @@ def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
         return cls

     return wrapper
+
+
+def category(category: str) -> Callable[[Type[T]], Type[T]]:
+    """Adds a category to the invocation. This is used to group invocations in the UI."""
+
+    def wrapper(cls: Type[T]) -> Type[T]:
+        uiconf_name = cls.__qualname__ + ".UIConfig"
+        if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
+            cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
+        cls.UIConfig.category = category
+        return cls
+
+    return wrapper
+
+
+def node(
+    title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
+) -> Callable[[Type[T]], Type[T]]:
+    """
+    Adds metadata to the invocation as a decorator.
+
+    :param Optional[str] title: Adds a title to the node. Use if the auto-generated title isn't quite right. Defaults to None.
+    :param Optional[list[str]] tags: Adds tags to the node. Nodes may be searched for by their tags. Defaults to None.
+    :param Optional[str] category: Adds a category to the node. Used to group the nodes in the UI. Defaults to None.
+    """
+
+    def wrapper(cls: Type[T]) -> Type[T]:
+        uiconf_name = cls.__qualname__ + ".UIConfig"
+        if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
+            cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
+        cls.UIConfig.title = title
+        cls.UIConfig.tags = tags
+        cls.UIConfig.category = category
+        return cls
+
+    return wrapper


@@ -8,11 +8,10 @@ from pydantic import validator
 from invokeai.app.invocations.primitives import IntegerCollectionOutput
 from invokeai.app.util.misc import SEED_MAX, get_random_seed

-from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, node


-@title("Integer Range")
-@tags("collection", "integer", "range")
+@node(title="Integer Range", tags=["collection", "integer", "range"], category="collections")
 class RangeInvocation(BaseInvocation):
     """Creates a range of numbers from start to stop with step"""
@@ -33,8 +32,7 @@ class RangeInvocation(BaseInvocation):
         return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))


-@title("Integer Range of Size")
-@tags("range", "integer", "size", "collection")
+@node(title="Integer Range of Size", tags=["collection", "integer", "size", "range"], category="collections")
 class RangeOfSizeInvocation(BaseInvocation):
     """Creates a range from start to start + size with step"""
@@ -49,8 +47,7 @@ class RangeOfSizeInvocation(BaseInvocation):
         return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))


-@title("Random Range")
-@tags("range", "integer", "random", "collection")
+@node(title="Random Range", tags=["range", "integer", "random", "collection"], category="collections")
 class RandomRangeInvocation(BaseInvocation):
     """Creates a collection of random numbers"""


@@ -26,6 +26,7 @@ from .baseinvocation import (
     InvocationContext,
     OutputField,
     UIComponent,
+    category,
     tags,
     title,
 )
@@ -44,8 +45,9 @@ class ConditioningFieldData:
     # PerpNeg = "perp_neg"


-@title("Compel Prompt")
+@title("Prompt")
 @tags("prompt", "compel")
+@category("conditioning")
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
@@ -265,8 +267,9 @@ class SDXLPromptInvocationBase:
         return c, c_pooled, ec


-@title("SDXL Compel Prompt")
+@title("SDXL Prompt")
 @tags("sdxl", "compel", "prompt")
+@category("conditioning")
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -324,8 +327,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     )


-@title("SDXL Refiner Compel Prompt")
+@title("SDXL Refiner Prompt")
 @tags("sdxl", "compel", "prompt")
+@category("conditioning")
 class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -381,6 +385,7 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
 @title("CLIP Skip")
 @tags("clipskip", "clip", "skip")
+@category("conditioning")
 class ClipSkipInvocation(BaseInvocation):
     """Skip layers in clip text_encoder model."""


@@ -40,6 +40,8 @@ from .baseinvocation import (
     InvocationContext,
     OutputField,
     UIType,
+    category,
+    node,
     tags,
     title,
 )
@@ -96,8 +98,7 @@ class ControlOutput(BaseInvocationOutput):
     control: ControlField = OutputField(description=FieldDescriptions.control)


-@title("ControlNet")
-@tags("controlnet")
+@node(title="ControlNet", tags=["controlnet"], category="controlnet")
 class ControlNetInvocation(BaseInvocation):
     """Collects ControlNet info to pass to other nodes"""
@@ -177,6 +178,7 @@ class ImageProcessorInvocation(BaseInvocation):
 @title("Canny Processor")
 @tags("controlnet", "canny")
+@category("controlnet")
 class CannyImageProcessorInvocation(ImageProcessorInvocation):
     """Canny edge detection for ControlNet"""
@@ -198,6 +200,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
 @title("HED (softedge) Processor")
 @tags("controlnet", "hed", "softedge")
+@category("controlnet")
 class HedImageProcessorInvocation(ImageProcessorInvocation):
     """Applies HED edge detection to image"""
@@ -225,6 +228,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
 @title("Lineart Processor")
 @tags("controlnet", "lineart")
+@category("controlnet")
 class LineartImageProcessorInvocation(ImageProcessorInvocation):
     """Applies line art processing to image"""
@@ -245,6 +249,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
 @title("Lineart Anime Processor")
 @tags("controlnet", "lineart", "anime")
+@category("controlnet")
 class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies line art anime processing to image"""
@@ -266,6 +271,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
 @title("Openpose Processor")
 @tags("controlnet", "openpose", "pose")
+@category("controlnet")
 class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Openpose processing to image"""
@@ -289,6 +295,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
 @title("Midas (Depth) Processor")
 @tags("controlnet", "midas", "depth")
+@category("controlnet")
 class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Midas depth processing to image"""
@@ -314,6 +321,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
 @title("Normal BAE Processor")
 @tags("controlnet", "normal", "bae")
+@category("controlnet")
 class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies NormalBae processing to image"""
@@ -333,6 +341,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
 @title("MLSD Processor")
 @tags("controlnet", "mlsd")
+@category("controlnet")
 class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     """Applies MLSD processing to image"""
@@ -358,6 +367,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
 @title("PIDI Processor")
 @tags("controlnet", "pidi")
+@category("controlnet")
 class PidiImageProcessorInvocation(ImageProcessorInvocation):
     """Applies PIDI processing to image"""
@@ -383,6 +393,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
 @title("Content Shuffle Processor")
 @tags("controlnet", "contentshuffle")
+@category("controlnet")
 class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     """Applies content shuffle processing to image"""
@@ -411,6 +422,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
 # should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
 @title("Zoe (Depth) Processor")
 @tags("controlnet", "zoe", "depth")
+@category("controlnet")
 class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""
@@ -424,6 +436,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
 @title("Mediapipe Face Processor")
 @tags("controlnet", "mediapipe", "face")
+@category("controlnet")
 class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     """Applies mediapipe face processing to image"""
@@ -445,6 +458,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
 @title("Leres (Depth) Processor")
 @tags("controlnet", "leres", "depth")
+@category("controlnet")
 class LeresImageProcessorInvocation(ImageProcessorInvocation):
     """Applies leres processing to image"""
@@ -472,6 +486,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
 @title("Tile Resample Processor")
 @tags("controlnet", "tile")
+@category("controlnet")
 class TileResamplerProcessorInvocation(ImageProcessorInvocation):
     """Tile resampler processor"""
@@ -510,6 +525,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
 @title("Segment Anything Processor")
 @tags("controlnet", "segmentanything")
+@category("controlnet")
 class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     """Applies segment anything processing to image"""


@@ -8,11 +8,12 @@ from PIL import Image, ImageOps
 from invokeai.app.invocations.primitives import ImageField, ImageOutput
 from invokeai.app.models.image import ImageCategory, ResourceOrigin

-from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, category, tags, title


 @title("OpenCV Inpaint")
 @tags("opencv", "inpaint")
+@category("inpaint")
 class CvInpaintInvocation(BaseInvocation):
     """Simple inpaint using opencv."""


@@ -119,13 +119,20 @@ class DiskImageFileStorage(ImageFileStorageBase):
             pnginfo = PngImagePlugin.PngInfo()

-            if metadata is not None:
-                pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
-            if workflow is not None:
-                pnginfo.add_text("invokeai_workflow", workflow)
-            # For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
-            for item_name, item in image.info.items():
-                pnginfo.add_text(item_name, item)
+            if metadata is not None or workflow is not None:
+                if metadata is not None:
+                    pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
+                if workflow is not None:
+                    pnginfo.add_text("invokeai_workflow", workflow)
+            else:
+                # For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
+                # TODO: retain non-invokeai metadata on save...
+                original_metadata = image.info.get("invokeai_metadata", None)
+                if original_metadata is not None:
+                    pnginfo.add_text("invokeai_metadata", original_metadata)
+                original_workflow = image.info.get("invokeai_workflow", None)
+                if original_workflow is not None:
+                    pnginfo.add_text("invokeai_workflow", original_workflow)

             image.save(image_path, "PNG", pnginfo=pnginfo)
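Condensed, the new save path only ever writes the string-valued invokeai keys back into the PNG instead of blindly copying `image.info`; a standalone sketch of the same branching (the helper name is illustrative, not part of the codebase):

import json
from typing import Optional

from PIL import Image, PngImagePlugin


def build_pnginfo(
    image: Image.Image, metadata: Optional[dict] = None, workflow: Optional[str] = None
) -> PngImagePlugin.PngInfo:
    pnginfo = PngImagePlugin.PngInfo()
    if metadata is not None or workflow is not None:
        # Generated images: write whatever metadata/workflow the caller provided
        if metadata is not None:
            pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
        if workflow is not None:
            pnginfo.add_text("invokeai_workflow", workflow)
    else:
        # Uploaded images: re-attach only the invokeai keys that PIL would otherwise strip
        original_metadata = image.info.get("invokeai_metadata", None)
        if original_metadata is not None:
            pnginfo.add_text("invokeai_metadata", original_metadata)
        original_workflow = image.info.get("invokeai_workflow", None)
        if original_workflow is not None:
            pnginfo.add_text("invokeai_workflow", original_workflow)
    return pnginfo


# e.g. image.save(image_path, "PNG", pnginfo=build_pnginfo(image, metadata, workflow))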