mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(nodes): fix uploading image metadata retention
was causing failure to save images
This commit is contained in:
parent
94d0c18cbd
commit
ae05d34584
@ -372,8 +372,9 @@ class UIConfigBase(BaseModel):
|
||||
decorators, though you may add this class to a node definition to specify the title and tags.
|
||||
"""
|
||||
|
||||
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
|
||||
title: Optional[str] = Field(default=None, description="The display name of the node")
|
||||
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
|
||||
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||
category: Optional[str] = Field(default=None, description="The node's category")
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
@ -469,6 +470,8 @@ class BaseInvocation(ABC, BaseModel):
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig and hasattr(uiconfig, "tags"):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig and hasattr(uiconfig, "category"):
|
||||
schema["category"] = uiconfig.category
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type", "id"])
|
||||
@ -558,3 +561,39 @@ def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def category(category: str) -> Callable[[Type[T]], Type[T]]:
|
||||
"""Adds a category to the invocation. This is used to group invocations in the UI."""
|
||||
|
||||
def wrapper(cls: Type[T]) -> Type[T]:
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||
cls.UIConfig.category = category
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def node(
|
||||
title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
|
||||
) -> Callable[[Type[T]], Type[T]]:
|
||||
"""
|
||||
Adds metadata to the invocation as a decorator.
|
||||
|
||||
:param Optional[str] title: Adds a title to the node. Use if the auto-generated title isn't quite right. Defaults to None.
|
||||
:param Optional[list[str]] tags: Adds tags to the node. Nodes may be searched for by their tags. Defaults to None.
|
||||
:param Optional[str] category: Adds a category to the node. Used to group the nodes in the UI. Defaults to None.
|
||||
"""
|
||||
|
||||
def wrapper(cls: Type[T]) -> Type[T]:
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||
cls.UIConfig.title = title
|
||||
cls.UIConfig.tags = tags
|
||||
cls.UIConfig.category = category
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
@ -8,11 +8,10 @@ from pydantic import validator
|
||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, node
|
||||
|
||||
|
||||
@title("Integer Range")
|
||||
@tags("collection", "integer", "range")
|
||||
@node(title="Integer Range", tags=["collection", "integer", "range"], category="collections")
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
@ -33,8 +32,7 @@ class RangeInvocation(BaseInvocation):
|
||||
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||
|
||||
|
||||
@title("Integer Range of Size")
|
||||
@tags("range", "integer", "size", "collection")
|
||||
@node(title="Integer Range of Size", tags=["collection", "integer", "size", "range"], category="collections")
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
"""Creates a range from start to start + size with step"""
|
||||
|
||||
@ -49,8 +47,7 @@ class RangeOfSizeInvocation(BaseInvocation):
|
||||
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
||||
|
||||
|
||||
@title("Random Range")
|
||||
@tags("range", "integer", "random", "collection")
|
||||
@node(title="Random Range", tags=["range", "integer", "random", "collection"], category="collections")
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
||||
|
@ -26,6 +26,7 @@ from .baseinvocation import (
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
category,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
@ -44,8 +45,9 @@ class ConditioningFieldData:
|
||||
# PerpNeg = "perp_neg"
|
||||
|
||||
|
||||
@title("Compel Prompt")
|
||||
@title("Prompt")
|
||||
@tags("prompt", "compel")
|
||||
@category("conditioning")
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
@ -265,8 +267,9 @@ class SDXLPromptInvocationBase:
|
||||
return c, c_pooled, ec
|
||||
|
||||
|
||||
@title("SDXL Compel Prompt")
|
||||
@title("SDXL Prompt")
|
||||
@tags("sdxl", "compel", "prompt")
|
||||
@category("conditioning")
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
@ -324,8 +327,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
)
|
||||
|
||||
|
||||
@title("SDXL Refiner Compel Prompt")
|
||||
@title("SDXL Refiner Prompt")
|
||||
@tags("sdxl", "compel", "prompt")
|
||||
@category("conditioning")
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
@ -381,6 +385,7 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
@title("CLIP Skip")
|
||||
@tags("clipskip", "clip", "skip")
|
||||
@category("conditioning")
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
|
@ -40,6 +40,8 @@ from .baseinvocation import (
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
category,
|
||||
node,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
@ -96,8 +98,7 @@ class ControlOutput(BaseInvocationOutput):
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@title("ControlNet")
|
||||
@tags("controlnet")
|
||||
@node(title="ControlNet", tags=["controlnet"], category="controlnet")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
@ -177,6 +178,7 @@ class ImageProcessorInvocation(BaseInvocation):
|
||||
|
||||
@title("Canny Processor")
|
||||
@tags("controlnet", "canny")
|
||||
@category("controlnet")
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
|
||||
@ -198,6 +200,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("HED (softedge) Processor")
|
||||
@tags("controlnet", "hed", "softedge")
|
||||
@category("controlnet")
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
|
||||
@ -225,6 +228,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("Lineart Processor")
|
||||
@tags("controlnet", "lineart")
|
||||
@category("controlnet")
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
|
||||
@ -245,6 +249,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("Lineart Anime Processor")
|
||||
@tags("controlnet", "lineart", "anime")
|
||||
@category("controlnet")
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
|
||||
@ -266,6 +271,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("Openpose Processor")
|
||||
@tags("controlnet", "openpose", "pose")
|
||||
@category("controlnet")
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Openpose processing to image"""
|
||||
|
||||
@ -289,6 +295,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("Midas (Depth) Processor")
|
||||
@tags("controlnet", "midas", "depth")
|
||||
@category("controlnet")
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
|
||||
@ -314,6 +321,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("Normal BAE Processor")
|
||||
@tags("controlnet", "normal", "bae")
|
||||
@category("controlnet")
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
|
||||
@ -333,6 +341,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("MLSD Processor")
|
||||
@tags("controlnet", "mlsd")
|
||||
@category("controlnet")
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
|
||||
@ -358,6 +367,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("PIDI Processor")
|
||||
@tags("controlnet", "pidi")
|
||||
@category("controlnet")
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
|
||||
@ -383,6 +393,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("Content Shuffle Processor")
|
||||
@tags("controlnet", "contentshuffle")
|
||||
@category("controlnet")
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
|
||||
@ -411,6 +422,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
@title("Zoe (Depth) Processor")
|
||||
@tags("controlnet", "zoe", "depth")
|
||||
@category("controlnet")
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
|
||||
@ -424,6 +436,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("Mediapipe Face Processor")
|
||||
@tags("controlnet", "mediapipe", "face")
|
||||
@category("controlnet")
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
|
||||
@ -445,6 +458,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("Leres (Depth) Processor")
|
||||
@tags("controlnet", "leres", "depth")
|
||||
@category("controlnet")
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
|
||||
@ -472,6 +486,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("Tile Resample Processor")
|
||||
@tags("controlnet", "tile")
|
||||
@category("controlnet")
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
|
||||
@ -510,6 +525,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
@title("Segment Anything Processor")
|
||||
@tags("controlnet", "segmentanything")
|
||||
@category("controlnet")
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
|
||||
|
@ -8,11 +8,12 @@ from PIL import Image, ImageOps
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, category, tags, title
|
||||
|
||||
|
||||
@title("OpenCV Inpaint")
|
||||
@tags("opencv", "inpaint")
|
||||
@category("inpaint")
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
|
@ -119,13 +119,20 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
|
||||
if workflow is not None:
|
||||
pnginfo.add_text("invokeai_workflow", workflow)
|
||||
# For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
|
||||
for item_name, item in image.info.items():
|
||||
pnginfo.add_text(item_name, item)
|
||||
if metadata is not None or workflow is not None:
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
|
||||
if workflow is not None:
|
||||
pnginfo.add_text("invokeai_workflow", workflow)
|
||||
else:
|
||||
# For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
|
||||
# TODO: retain non-invokeai metadata on save...
|
||||
original_metadata = image.info.get("invokeai_metadata", None)
|
||||
if original_metadata is not None:
|
||||
pnginfo.add_text("invokeai_metadata", original_metadata)
|
||||
original_workflow = image.info.get("invokeai_workflow", None)
|
||||
if original_workflow is not None:
|
||||
pnginfo.add_text("invokeai_workflow", original_workflow)
|
||||
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user