InvokeAI/invokeai/app/invocations/collections.py
psychedelicious c48fd9c083 feat(nodes): refactor parameter/primitive nodes
Refine concept of "parameter" nodes to "primitives":
- integer
- float
- string
- boolean
- image
- latents
- conditioning
- color

Each primitive has:
- A field definition, if it is not already a Python primitive value. The field is how this primitive value is passed between nodes. Collections are lists of the field in node definitions. ex: `ImageField` & `list[ImageField]`
- A single output class. ex: `ImageOutput`
- A collection output class. ex: `ImageCollectionOutput`
- A node, which functions to load or pass on the primitive value. ex: `ImageInvocation` (in this case, `ImageInvocation` replaces `LoadImage`)

Plus a number of related changes:
- Reorganize these into `primitives.py`
- Update all nodes and logic to use primitives
- Consolidate "prompt" outputs into "string" & "mask" into "image" (there's no reason for these to be different, they function identically)
- Update default graphs & tests
- Regen frontend types & minor frontend tidy related to changes
2023-08-16 09:54:38 +10:00

89 lines
3.3 KiB
Python

# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import Literal
import numpy as np
from pydantic import validator
from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UITypeHint, tags, title
@title("Integer Range")
@tags("collection", "integer", "range")
class RangeInvocation(BaseInvocation):
    """Creates a range of numbers from start to stop with step"""

    # Discriminator literal identifying this node type.
    type: Literal["range"] = "range"

    # Inputs
    start: int = InputField(default=0, description="The start of the range")
    stop: int = InputField(default=10, description="The stop of the range")
    step: int = InputField(default=1, description="The step of the range")

    @validator("stop")
    def stop_gt_start(cls, value, values):
        """Ensure `stop` is strictly greater than `start`.

        The comparison is skipped when `start` is not present in `values`.

        Raises:
            ValueError: if `stop` is less than or equal to `start`.
        """
        if "start" not in values:
            return value
        if value <= values["start"]:
            raise ValueError("stop must be greater than start")
        return value

    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
        """Return the integers of `range(start, stop, step)` as a collection."""
        numbers = range(self.start, self.stop, self.step)
        return IntegerCollectionOutput(collection=list(numbers))
@title("Integer Range of Size")
@tags("range", "integer", "size", "collection")
class RangeOfSizeInvocation(BaseInvocation):
    """Creates a range from start to start + (size * step) incremented by step"""

    # Discriminator literal identifying this node type.
    type: Literal["range_of_size"] = "range_of_size"

    # Inputs
    start: int = InputField(default=0, description="The start of the range")
    size: int = InputField(default=1, description="The number of values")
    step: int = InputField(default=1, description="The step of the range")

    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
        """Return `size` integers beginning at `start`, spaced `step` apart.

        A `step` of 0 is rejected by `range` with a ValueError.
        """
        # The exclusive stop has to scale with `step`; using `start + size`
        # only yields `size` values when step == 1 (and yields nothing at all
        # for a negative step).
        stop = self.start + self.size * self.step
        return IntegerCollectionOutput(collection=list(range(self.start, stop, self.step)))
@title("Random Range")
@tags("range", "integer", "random", "collection")
class RandomRangeInvocation(BaseInvocation):
    """Creates a collection of random numbers"""

    # Discriminator literal identifying this node type.
    type: Literal["random_range"] = "random_range"

    # Inputs
    low: int = InputField(default=0, description="The inclusive low value")
    high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
    size: int = InputField(default=1, description="The number of values to generate")
    # When no seed is supplied, `get_random_seed` provides one per instance.
    seed: int = InputField(
        ge=0,
        le=SEED_MAX,
        description="The seed for the RNG (omit for random)",
        default_factory=get_random_seed,
    )

    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
        """Draw `size` integers from [low, high) using a generator seeded with `seed`."""
        generator = np.random.default_rng(self.seed)
        samples = generator.integers(low=self.low, high=self.high, size=self.size)
        return IntegerCollectionOutput(collection=list(samples))
@title("Image Collection")
@tags("image", "collection")
class ImageCollectionInvocation(BaseInvocation):
    """Load a collection of images and provide it as output."""

    # Discriminator literal identifying this node type.
    type: Literal["image_collection"] = "image_collection"

    # Inputs
    images: list[ImageField] = InputField(
        default=[],
        description="The image collection to load",
        ui_type_hint=UITypeHint.ImageCollection,
    )

    def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
        """Pass the configured `images` list through unchanged as the output collection."""
        loaded_images = self.images
        return ImageCollectionOutput(collection=loaded_images)