From fa169b55179cb8eb305050b0a5f05a1d6f4b7a6f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 4 Jul 2023 00:08:51 +1000 Subject: [PATCH] feat(nodes): add ImageCollection node in prep for batch processing --- invokeai/app/invocations/baseinvocation.py | 1 + invokeai/app/invocations/collections.py | 40 ++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 4ce3e839b6..1bf9353368 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -97,6 +97,7 @@ class UIConfig(TypedDict, total=False): "latents", "model", "control", + "image_collection", ], ] tags: List[str] diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 891f217317..33bde42d69 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -4,13 +4,16 @@ from typing import Literal import numpy as np from pydantic import Field, validator +from invokeai.app.models.image import ImageField from invokeai.app.util.misc import SEED_MAX, get_random_seed from .baseinvocation import ( BaseInvocation, + InvocationConfig, InvocationContext, BaseInvocationOutput, + UIConfig, ) @@ -22,6 +25,7 @@ class IntCollectionOutput(BaseInvocationOutput): # Outputs collection: list[int] = Field(default=[], description="The int collection") + class FloatCollectionOutput(BaseInvocationOutput): """A collection of floats""" @@ -31,6 +35,18 @@ class FloatCollectionOutput(BaseInvocationOutput): collection: list[float] = Field(default=[], description="The float collection") +class ImageCollectionOutput(BaseInvocationOutput): + """A collection of images""" + + type: Literal["image_collection"] = "image_collection" + + # Outputs + collection: list[ImageField] = Field(default=[], description="The output images") + + class Config: + schema_extra = {"required": ["type", "collection"]} + + class RangeInvocation(BaseInvocation): """Creates a range of numbers from start to stop with step""" @@ -92,3 +108,27 @@ class RandomRangeInvocation(BaseInvocation): return IntCollectionOutput( collection=list(rng.integers(low=self.low, high=self.high, size=self.size)) ) + + +class ImageCollectionInvocation(BaseInvocation): + """Load a collection of images and provide it as output.""" + + # fmt: off + type: Literal["image_collection"] = "image_collection" + + # Inputs + images: list[ImageField] = Field( + default=[], description="The image collection to load" + ) + # fmt: on + def invoke(self, context: InvocationContext) -> ImageCollectionOutput: + return ImageCollectionOutput(collection=self.images) + + class Config(InvocationConfig): + schema_extra = { + "ui": { + "type_hints": { + "images": "image_collection", + } + }, + }