feat(nodes): refactor parameter/primitive nodes

Refine concept of "parameter" nodes to "primitives":
- integer
- float
- string
- boolean
- image
- latents
- conditioning
- color

Each primitive has:
- A field definition, if it is not already python primitive value. The field is how this primitive value is passed between nodes. Collections are lists of the field in node definitions. ex: `ImageField` & `list[ImageField]`
- A single output class. ex: `ImageOutput`
- A collection output class. ex: `ImageCollectionOutput`
- A node, which functions to load or pass on the primitive value. ex: `ImageInvocation` (in this case, `ImageInvocation` replaces `LoadImage`)

Plus a number of related changes:
- Reorganize these into `primitives.py`
- Update all nodes and logic to use primitives
- Consolidate "prompt" outputs into "string" & "mask" into "image" (there's no reason for these to be different, the function identically)
- Update default graphs & tests
- Regen frontend types & minor frontend tidy related to changes
This commit is contained in:
psychedelicious 2023-08-14 19:41:29 +10:00
parent f49fc7fb55
commit c48fd9c083
24 changed files with 887 additions and 666 deletions

View File

@ -98,36 +98,49 @@ class UITypeHint(str, Enum):
on adding a new field type, which involves client-side changes.
"""
# region Primitives
Integer = "integer"
Float = "float"
Boolean = "boolean"
String = "string"
Enum = "enum"
Array = "array"
ImageField = "ImageField"
LatentsField = "LatentsField"
ConditioningField = "ConditioningField"
ControlField = "ControlField"
MainModelField = "MainModelField"
SDXLMainModelField = "SDXLMainModelField"
SDXLRefinerModelField = "SDXLRefinerModelField"
ONNXModelField = "ONNXModelField"
VaeModelField = "VaeModelField"
LoRAModelField = "LoRAModelField"
ControlNetModelField = "ControlNetModelField"
UNetField = "UNetField"
VaeField = "VaeField"
ClipField = "ClipField"
ColorField = "ColorField"
Image = "ImageField"
Latents = "LatentsField"
Conditioning = "ConditioningField"
Control = "ControlField"
Color = "ColorField"
ImageCollection = "ImageCollection"
ConditioningCollection = "ConditioningCollection"
ColorCollection = "ColorCollection"
LatentsCollection = "LatentsCollection"
IntegerCollection = "IntegerCollection"
FloatCollection = "FloatCollection"
StringCollection = "StringCollection"
BooleanCollection = "BooleanCollection"
# endregion
# region Models
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VaeModel = "VaeModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
UNet = "UNetField"
Vae = "VaeField"
CLIP = "ClipField"
# endregion
# region Iterate/Collect
Collection = "Collection"
CollectionItem = "CollectionItem"
Seed = "Seed"
# endregion
# region Misc
FilePath = "FilePath"
Enum = "enum"
# endregion
class UIComponent(str, Enum):

View File

@ -5,63 +5,10 @@ from typing import Literal
import numpy as np
from pydantic import validator
from invokeai.app.models.image import ImageField
from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
UITypeHint,
tags,
title,
)
class IntCollectionOutput(BaseInvocationOutput):
"""A collection of integers"""
type: Literal["int_collection_output"] = "int_collection_output"
# Outputs
collection: list[int] = OutputField(
default=[], description="The int collection", ui_type_hint=UITypeHint.IntegerCollection
)
class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats"""
type: Literal["float_collection_output"] = "float_collection_output"
# Outputs
collection: list[float] = OutputField(
default=[], description="The float collection", ui_type_hint=UITypeHint.FloatCollection
)
class StringCollectionOutput(BaseInvocationOutput):
"""A collection of strings"""
type: Literal["string_collection_output"] = "string_collection_output"
# Outputs
collection: list[str] = OutputField(
default=[], description="The output strings", ui_type_hint=UITypeHint.StringCollection
)
class ImageCollectionOutput(BaseInvocationOutput):
"""A collection of images"""
type: Literal["image_collection_output"] = "image_collection_output"
# Outputs
collection: list[ImageField] = OutputField(
default=[], description="The output images", ui_type_hint=UITypeHint.ImageCollection
)
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UITypeHint, tags, title
@title("Integer Range")
@ -82,8 +29,8 @@ class RangeInvocation(BaseInvocation):
raise ValueError("stop must be greater than start")
return v
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
@title("Integer Range of Size")
@ -98,8 +45,8 @@ class RangeOfSizeInvocation(BaseInvocation):
size: int = InputField(default=1, description="The number of values")
step: int = InputField(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
@title("Random Range")
@ -120,9 +67,9 @@ class RandomRangeInvocation(BaseInvocation):
default_factory=get_random_seed,
)
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
rng = np.random.default_rng(self.seed)
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
@title("Image Collection")

View File

@ -5,7 +5,7 @@ from typing import List, Literal, Union
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from pydantic import BaseModel, Field
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
BasicConditioningInfo,
@ -32,13 +32,6 @@ from .baseinvocation import (
from .model import ClipField
class ConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning data")
class Config:
schema_extra = {"required": ["conditioning_name"]}
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo]
@ -51,16 +44,6 @@ class ConditioningFieldData:
# PerpNeg = "perp_neg"
class CompelOutput(BaseInvocationOutput):
"""Compel parser output"""
# fmt: off
type: Literal["compel_output"] = "compel_output"
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
# fmt: on
@title("Compel Prompt")
@tags("prompt", "compel")
class CompelInvocation(BaseInvocation):
@ -80,7 +63,7 @@ class CompelInvocation(BaseInvocation):
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
context=context,
@ -163,7 +146,7 @@ class CompelInvocation(BaseInvocation):
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
@ -303,7 +286,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
)
@ -336,7 +319,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
@ -361,7 +344,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
# TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
@ -384,7 +367,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),

View File

@ -26,8 +26,11 @@ from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -40,7 +43,6 @@ from .baseinvocation import (
tags,
title,
)
from ..models.image import ImageOutput
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]

View File

@ -5,10 +5,10 @@ from typing import Literal
import cv2 as cv
import numpy
from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .image import ImageOutput
@title("OpenCV Inpaint")

View File

@ -8,34 +8,14 @@ import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, ResourceOrigin
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
@title("Load Image")
@tags("image")
class LoadImageInvocation(BaseInvocation):
"""Load an image and provide it as output."""
# Metadata
type: Literal["load_image"] = "load_image"
# Inputs
image: ImageField = InputField(description="The image to load")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
return ImageOutput(
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@title("Show Image")
@tags("image")
class ShowImageInvocation(BaseInvocation):
@ -162,7 +142,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
image: ImageField = InputField(description="The image to create the mask from")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> MaskOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
image_mask = image.split()[-1]
@ -178,8 +158,8 @@ class MaskFromAlphaInvocation(BaseInvocation):
is_intermediate=self.is_intermediate,
)
return MaskOutput(
mask=ImageField(image_name=image_dto.image_name),
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -608,7 +588,7 @@ class MaskEdgeInvocation(BaseInvocation):
description="Second threshold for the hysteresis procedure in Canny edge detection"
)
def invoke(self, context: InvocationContext) -> MaskOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.services.images.get_pil_image(self.image.image_name)
npimg = numpy.asarray(mask, dtype=numpy.uint8)
@ -633,8 +613,8 @@ class MaskEdgeInvocation(BaseInvocation):
is_intermediate=self.is_intermediate,
)
return MaskOutput(
mask=ImageField(image_name=image_dto.image_name),
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -5,13 +5,13 @@ from typing import Literal, Optional, get_args
import numpy as np
import math
from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UITypeHint, title, tags
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
def infill_methods() -> list[str]:

View File

@ -19,6 +19,13 @@ from pydantic import BaseModel, Field, validator
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import (
ImageField,
ImageOutput,
LatentsField,
LatentsOutput,
build_latents_output,
)
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
@ -35,7 +42,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -50,43 +57,11 @@ from .baseinvocation import (
)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
DEFAULT_PRECISION = choose_precision(choose_torch_device())
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: str = Field(description="The name of the latents")
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
class Config:
schema_extra = {"required": ["latents_name"]}
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
type: Literal["latents_output"] = "latents_output"
# Inputs
latents: LatentsField = OutputField(
description=FieldDescriptions.latents,
)
width: int = OutputField(description=FieldDescriptions.width)
height: int = OutputField(description=FieldDescriptions.height)
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
return LatentsOutput(
latents=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]

View File

@ -4,30 +4,9 @@ from typing import Literal
import numpy as np
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
InvocationContext,
OutputField,
tags,
title,
)
from invokeai.app.invocations.primitives import IntegerOutput
class IntOutput(BaseInvocationOutput):
"""An integer output"""
type: Literal["int_output"] = "int_output"
a: int = OutputField(default=None, description="The output integer")
class FloatOutput(BaseInvocationOutput):
"""A float output"""
type: Literal["float_output"] = "float_output"
a: float = OutputField(default=None, description="The output float")
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
@title("Add Integers")
@ -41,8 +20,8 @@ class AddInvocation(BaseInvocation):
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a + self.b)
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=self.a + self.b)
@title("Subtract Integers")
@ -56,8 +35,8 @@ class SubtractInvocation(BaseInvocation):
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a - self.b)
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=self.a - self.b)
@title("Multiply Integers")
@ -71,8 +50,8 @@ class MultiplyInvocation(BaseInvocation):
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a * self.b)
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=self.a * self.b)
@title("Divide Integers")
@ -86,8 +65,8 @@ class DivideInvocation(BaseInvocation):
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=int(self.a / self.b))
@title("Random Integer")
@ -101,5 +80,5 @@ class RandomIntInvocation(BaseInvocation):
low: int = InputField(default=0, description="The inclusive low value")
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=np.random.randint(self.low, self.high))
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=np.random.randint(self.low, self.high))

View File

@ -7,6 +7,7 @@ from invokeai.app.invocations.baseinvocation import (
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
tags,
title,
)
@ -94,7 +95,7 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
metadata: CoreMetadata = Field(description="The core metadata for the image")
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
@title("Metadata Accumulator")

View File

@ -365,7 +365,7 @@ class VaeLoaderInvocation(BaseInvocation):
# Inputs
vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type_hint=UITypeHint.VaeModelField, title="VAE"
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type_hint=UITypeHint.VaeModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:

View File

@ -14,13 +14,14 @@ from pydantic import BaseModel, Field, validator
from tqdm import tqdm
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType
from ...backend.model_management import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util import choose_torch_device
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -34,9 +35,7 @@ from .baseinvocation import (
tags,
title,
)
from .compel import CompelOutput, ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
from .model import ClipField, ModelInfo, UNetField, VaeField
@ -66,7 +65,7 @@ class ONNXPromptInvocation(BaseInvocation):
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
def invoke(self, context: InvocationContext) -> CompelOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
)
@ -135,7 +134,7 @@ class ONNXPromptInvocation(BaseInvocation):
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (prompt_embeds, None))
return CompelOutput(
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
@ -181,7 +180,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
default=None,
description=FieldDescriptions.control,
ui_type_hint=UITypeHint.ControlField,
ui_type_hint=UITypeHint.Control,
)
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
@ -417,7 +416,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
# Inputs
model: OnnxModelField = InputField(
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type_hint=UITypeHint.ONNXModelField
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type_hint=UITypeHint.ONNXModel
)
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:

View File

@ -42,9 +42,10 @@ from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from pydantic import BaseModel, Field
from invokeai.app.invocations.primitives import FloatCollectionOutput
from ...backend.util.logging import InvokeAILogger
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .collections import FloatCollectionOutput
@title("Float Range")

View File

@ -1,81 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from invokeai.app.invocations.prompt import PromptOutput
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
tags,
title,
)
from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs
@title("Integer Parameter")
@tags("integer")
class ParamIntInvocation(BaseInvocation):
"""An integer parameter"""
type: Literal["param_int"] = "param_int"
# Inputs
a: int = InputField(default=0, description="The integer value")
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)
@title("Float Parameter")
@tags("float")
class ParamFloatInvocation(BaseInvocation):
"""A float parameter"""
type: Literal["param_float"] = "param_float"
# Inputs
param: float = InputField(default=0.0, description="The float value")
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(a=self.param)
class StringOutput(BaseInvocationOutput):
"""A string output"""
type: Literal["string_output"] = "string_output"
text: str = OutputField(description="The output string")
@title("String Parameter")
@tags("string")
class ParamStringInvocation(BaseInvocation):
"""A string parameter"""
type: Literal["param_string"] = "param_string"
# Inputs
text: str = InputField(default="", description="The string value")
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text)
@title("Prompt Parameter")
@tags("prompt")
class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter"""
type: Literal["param_prompt"] = "param_prompt"
# Inputs
prompt: str = InputField(default="", description="The prompt value")
def invoke(self, context: InvocationContext) -> PromptOutput:
return PromptOutput(prompt=self.prompt)

View File

@ -0,0 +1,381 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, Tuple
from pydantic import BaseModel, Field
import torch
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIComponent,
UITypeHint,
tags,
title,
)
"""
Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
- primitive nodes
- primitive outputs
- primitive collection outputs
"""
# region Boolean
class BooleanOutput(BaseInvocationOutput):
"""Base class for nodes that output a single boolean"""
type: Literal["boolean_output"] = "boolean_output"
a: bool = OutputField(description="The output boolean")
class BooleanCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of booleans"""
type: Literal["boolean_collection_output"] = "boolean_collection_output"
# Outputs
collection: list[bool] = OutputField(
default_factory=list, description="The output boolean collection", ui_type_hint=UITypeHint.BooleanCollection
)
@title("Boolean Primitive")
@tags("boolean")
class BooleanInvocation(BaseInvocation):
"""A boolean primitive value"""
type: Literal["boolean"] = "boolean"
# Inputs
a: bool = InputField(default=False, description="The boolean value")
def invoke(self, context: InvocationContext) -> BooleanOutput:
return BooleanOutput(a=self.a)
# endregion
# region Integer
class IntegerOutput(BaseInvocationOutput):
"""Base class for nodes that output a single integer"""
type: Literal["integer_output"] = "integer_output"
a: int = OutputField(description="The output integer")
class IntegerCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of integers"""
type: Literal["integer_collection_output"] = "integer_collection_output"
# Outputs
collection: list[int] = OutputField(
default_factory=list, description="The int collection", ui_type_hint=UITypeHint.IntegerCollection
)
@title("Integer Primitive")
@tags("integer")
class IntegerInvocation(BaseInvocation):
"""An integer primitive value"""
type: Literal["integer"] = "integer"
# Inputs
a: int = InputField(default=0, description="The integer value")
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(a=self.a)
# endregion
# region Float
class FloatOutput(BaseInvocationOutput):
"""Base class for nodes that output a single float"""
type: Literal["float_output"] = "float_output"
a: float = OutputField(description="The output float")
class FloatCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of floats"""
type: Literal["float_collection_output"] = "float_collection_output"
# Outputs
collection: list[float] = OutputField(
default_factory=list, description="The float collection", ui_type_hint=UITypeHint.FloatCollection
)
@title("Float Primitive")
@tags("float")
class FloatInvocation(BaseInvocation):
"""A float primitive value"""
type: Literal["float"] = "float"
# Inputs
param: float = InputField(default=0.0, description="The float value")
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(a=self.param)
# endregion
# region String
class StringOutput(BaseInvocationOutput):
"""Base class for nodes that output a single string"""
type: Literal["string_output"] = "string_output"
text: str = OutputField(description="The output string")
class StringCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of strings"""
type: Literal["string_collection_output"] = "string_collection_output"
# Outputs
collection: list[str] = OutputField(
default_factory=list, description="The output strings", ui_type_hint=UITypeHint.StringCollection
)
@title("String Primitive")
@tags("string")
class StringInvocation(BaseInvocation):
"""A string primitive value"""
type: Literal["string"] = "string"
# Inputs
text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text)
# endregion
# region Image
class ImageField(BaseModel):
"""An image primitive field"""
image_name: str = Field(description="The name of the image")
class ImageOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
type: Literal["image_output"] = "image_output"
image: ImageField = OutputField(description="The output image")
width: int = OutputField(description="The width of the image in pixels")
height: int = OutputField(description="The height of the image in pixels")
class ImageCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of images"""
type: Literal["image_collection_output"] = "image_collection_output"
# Outputs
collection: list[ImageField] = OutputField(
default_factory=list, description="The output images", ui_type_hint=UITypeHint.ImageCollection
)
@title("Image Primitive")
@tags("image")
class ImageInvocation(BaseInvocation):
"""An image primitive value"""
# Metadata
type: Literal["image"] = "image"
# Inputs
image: ImageField = InputField(description="The image to load")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
return ImageOutput(
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
# endregion
# region Latents
class LatentsField(BaseModel):
"""A latents tensor primitive field"""
latents_name: str = Field(description="The name of the latents")
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
class LatentsOutput(BaseInvocationOutput):
"""Base class for nodes that output a single latents tensor"""
type: Literal["latents_output"] = "latents_output"
latents: LatentsField = OutputField(
description=FieldDescriptions.latents,
)
width: int = OutputField(description=FieldDescriptions.width)
height: int = OutputField(description=FieldDescriptions.height)
class LatentsCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of latents tensors"""
type: Literal["latents_collection_output"] = "latents_collection_output"
latents: list[LatentsField] = OutputField(
default_factory=list,
description=FieldDescriptions.latents,
ui_type_hint=UITypeHint.LatentsCollection,
)
@title("Latents Primitive")
@tags("latents")
class LatentsInvocation(BaseInvocation):
"""A latents tensor primitive value"""
type: Literal["latents"] = "latents"
# Inputs
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
return build_latents_output(self.latents.latents_name, latents)
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
return LatentsOutput(
latents=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
# endregion
# region Color
class ColorField(BaseModel):
"""A color primitive field"""
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ColorOutput(BaseInvocationOutput):
"""Base class for nodes that output a single color"""
type: Literal["color_output"] = "color_output"
color: ColorField = OutputField(description="The output color")
class ColorCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of colors"""
type: Literal["color_collection_output"] = "color_collection_output"
# Outputs
collection: list[ColorField] = OutputField(
default_factory=list, description="The output colors", ui_type_hint=UITypeHint.ColorCollection
)
@title("Color Primitive")
@tags("color")
class ColorInvocation(BaseInvocation):
"""A color primitive value"""
type: Literal["color"] = "color"
# Inputs
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
def invoke(self, context: InvocationContext) -> ColorOutput:
return ColorOutput(color=self.color)
# endregion
# region Conditioning
class ConditioningField(BaseModel):
"""A conditioning tensor primitive field"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""
type: Literal["conditioning_output"] = "conditioning_output"
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
class ConditioningCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of conditioning tensors"""
type: Literal["conditioning_collection_output"] = "conditioning_collection_output"
# Outputs
collection: list[ConditioningField] = OutputField(
default_factory=list,
description="The output conditioning tensors",
ui_type_hint=UITypeHint.ConditioningCollection,
)
@title("Conditioning Primitive")
@tags("conditioning")
class ConditioningInvocation(BaseInvocation):
"""A conditioning tensor primitive value"""
type: Literal["conditioning"] = "conditioning"
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
def invoke(self, context: InvocationContext) -> ConditioningOutput:
return ConditioningOutput(conditioning=self.conditioning)
# endregion

View File

@ -1,40 +1,13 @@
from os.path import exists
from typing import Literal, Optional
from typing import Literal, Optional, Union
import numpy as np
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from pydantic import validator
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
UIComponent,
UITypeHint,
title,
tags,
)
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
from invokeai.app.invocations.primitives import StringCollectionOutput
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
type: Literal["prompt"] = "prompt"
prompt: str = OutputField(description="The output prompt")
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompt_collection: list[str] = OutputField(
description="The output prompt collection", ui_type_hint=UITypeHint.StringCollection
)
count: int = OutputField(description="The size of the prompt collection")
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UITypeHint, tags, title
@title("Dynamic Prompt")
@ -49,7 +22,7 @@ class DynamicPromptInvocation(BaseInvocation):
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
@ -57,7 +30,7 @@ class DynamicPromptInvocation(BaseInvocation):
generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
return StringCollectionOutput(collection=prompts)
@title("Prompts from File")
@ -70,10 +43,10 @@ class PromptsFromFileInvocation(BaseInvocation):
# Inputs
file_path: str = InputField(description="Path to prompt text file", ui_type_hint=UITypeHint.FilePath)
pre_prompt: Optional[str] = InputField(
description="String to prepend to each prompt", ui_component=UIComponent.Textarea
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
)
post_prompt: Optional[str] = InputField(
description="String to append to each prompt", ui_component=UIComponent.Textarea
default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
)
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
@ -84,7 +57,14 @@ class PromptsFromFileInvocation(BaseInvocation):
raise ValueError(FileNotFoundError)
return v
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
def promptsFromFile(
self,
file_path: str,
pre_prompt: Union[str, None],
post_prompt: Union[str, None],
start_line: int,
max_prompts: int,
):
prompts = []
start_line -= 1
end_line = start_line + max_prompts
@ -98,8 +78,8 @@ class PromptsFromFileInvocation(BaseInvocation):
break
return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
prompts = self.promptsFromFile(
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
return StringCollectionOutput(collection=prompts)

View File

@ -46,7 +46,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
# Inputs
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type_hint=UITypeHint.SDXLMainModelField
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type_hint=UITypeHint.SDXLMainModel
)
# TODO: precision?
@ -133,7 +133,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model,
input=Input.Direct,
ui_type_hint=UITypeHint.SDXLRefinerModelField,
ui_type_hint=UITypeHint.SDXLRefinerModel,
)
# TODO: precision?

View File

@ -7,11 +7,11 @@ import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
from .image import ImageOutput
# TODO: Populate this from disk?
# TODO: Use model manager to load?

View File

@ -1,30 +1,8 @@
from enum import Enum
from typing import Optional, Tuple, Literal
from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
from ..invocations.baseinvocation import (
BaseInvocationOutput,
)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: str = Field(description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
@ -35,39 +13,6 @@ class ProgressImage(BaseModel):
dataURL: str = Field(description="The image data as a b64 data URL")
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
"""The origin of a resource (eg image).

View File

@ -2,7 +2,7 @@ from ..invocations.latent import LatentsToImageInvocation, DenoiseLatentsInvocat
from ..invocations.image import ImageNSFWBlurInvocation
from ..invocations.noise import NoiseInvocation
from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation
from ..invocations.primitives import IntegerInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
from .item_storage import ItemStorageABC
@ -17,9 +17,9 @@ def create_text_to_image() -> LibraryGraph:
description="Converts text to an image",
graph=Graph(
nodes={
"width": ParamIntInvocation(id="width", a=512),
"height": ParamIntInvocation(id="height", a=512),
"seed": ParamIntInvocation(id="seed", a=-1),
"width": IntegerInvocation(id="width", a=512),
"height": IntegerInvocation(id="height", a=512),
"seed": IntegerInvocation(id="seed", a=-1),
"3": NoiseInvocation(id="3"),
"4": CompelInvocation(id="4"),
"5": CompelInvocation(id="5"),

View File

@ -3,8 +3,6 @@ import {
GraphExecutionState,
GraphInvocationOutput,
ImageOutput,
MaskOutput,
PromptOutput,
IterateInvocationOutput,
CollectInvocationOutput,
ImageField,
@ -48,14 +46,6 @@ export const isLatentsOutput = (
output: GraphExecutionState['results'][string]
): output is LatentsOutput => output?.type === 'latents_output';
export const isMaskOutput = (
output: GraphExecutionState['results'][string]
): output is MaskOutput => output?.type === 'mask';
export const isPromptOutput = (
output: GraphExecutionState['results'][string]
): output is PromptOutput => output?.type === 'prompt';
export const isGraphOutput = (
output: GraphExecutionState['results'][string]
): output is GraphInvocationOutput => output?.type === 'graph_output';

View File

@ -548,6 +548,69 @@ export type components = {
*/
file: Blob;
};
/**
* BooleanCollectionOutput
* @description Base class for nodes that output a collection of booleans
*/
BooleanCollectionOutput: {
/**
* Type
* @default boolean_collection_output
* @enum {string}
*/
type?: "boolean_collection_output";
/**
* Collection
* @description The output boolean collection
*/
collection?: (boolean)[];
};
/**
* Boolean Primitive
* @description A boolean primitive value
*/
BooleanInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default bool
* @enum {string}
*/
type: "bool";
/**
* A
* @description The boolean value
* @default false
*/
a?: boolean;
};
/**
* BooleanOutput
* @description Base class for nodes that output a single boolean
*/
BooleanOutput: {
/**
* Type
* @default boolean_output
* @enum {string}
*/
type?: "boolean_output";
/**
* A
* @description The output boolean
*/
a: boolean;
};
/**
* Canny Processor
* @description Canny edge detection for ControlNet
@ -712,6 +775,23 @@ export type components = {
*/
collection: (unknown)[];
};
/**
* ColorCollectionOutput
* @description Base class for nodes that output a collection of colors
*/
ColorCollectionOutput: {
/**
* Type
* @default color_collection_output
* @enum {string}
*/
type?: "color_collection_output";
/**
* Collection
* @description The output colors
*/
collection?: (components["schemas"]["ColorField"])[];
};
/**
* Color Correct
* @description Shifts the colors of a target image to match the reference image, optionally
@ -757,7 +837,10 @@ export type components = {
*/
mask_blur_radius?: number;
};
/** ColorField */
/**
* ColorField
* @description A color primitive field
*/
ColorField: {
/**
* R
@ -780,6 +863,57 @@ export type components = {
*/
a: number;
};
/**
* Color Primitive
* @description A color primitive value
*/
ColorInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default color
* @enum {string}
*/
type: "color";
/**
* Color
* @description The color value
* @default {
* "r": 0,
* "g": 0,
* "b": 0,
* "a": 255
* }
*/
color?: components["schemas"]["ColorField"];
};
/**
* ColorOutput
* @description Base class for nodes that output a single color
*/
ColorOutput: {
/**
* Type
* @default color_output
* @enum {string}
*/
type?: "color_output";
/**
* Color
* @description The output color
*/
color: components["schemas"]["ColorField"];
};
/**
* Compel Prompt
* @description Parse prompt using compel package to conditioning.
@ -815,30 +949,78 @@ export type components = {
clip?: components["schemas"]["ClipField"];
};
/**
* CompelOutput
* @description Compel parser output
* ConditioningCollectionOutput
* @description Base class for nodes that output a collection of conditioning tensors
*/
CompelOutput: {
ConditioningCollectionOutput: {
/**
* Type
* @default compel_output
* @default conditioning_collection_output
* @enum {string}
*/
type?: "compel_output";
type?: "conditioning_collection_output";
/**
* Collection
* @description The output conditioning tensors
*/
collection?: (components["schemas"]["ConditioningField"])[];
};
/**
* ConditioningField
* @description A conditioning tensor primitive field
*/
ConditioningField: {
/**
* Conditioning Name
* @description The name of conditioning tensor
*/
conditioning_name: string;
};
/**
* Conditioning
* @description A conditioning tensor primitive value
*/
ConditioningInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default conditioning
* @enum {string}
*/
type: "conditioning";
/**
* Conditioning
* @description Conditioning tensor
*/
conditioning?: components["schemas"]["ConditioningField"];
};
/**
* ConditioningOutput
* @description Base class for nodes that output a single conditioning tensor
*/
ConditioningOutput: {
/**
* Type
* @default conditioning_output
* @enum {string}
*/
type?: "conditioning_output";
/**
* Conditioning
* @description Conditioning tensor
*/
conditioning: components["schemas"]["ConditioningField"];
};
/** ConditioningField */
ConditioningField: {
/**
* Conditioning Name
* @description The name of conditioning data
*/
conditioning_name: string;
};
/**
* Content Shuffle Processor
* @description Applies content shuffle processing to image
@ -1507,7 +1689,7 @@ export type components = {
};
/**
* FloatCollectionOutput
* @description A collection of floats
* @description Base class for nodes that output a collection of floats
*/
FloatCollectionOutput: {
/**
@ -1518,11 +1700,39 @@ export type components = {
type?: "float_collection_output";
/**
* Collection
* @description The float collection
* @default []
* @description The float collection
*/
collection?: (number)[];
};
/**
* Float Primitive
* @description A float primitive value
*/
FloatInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default float
* @enum {string}
*/
type: "float";
/**
* Param
* @description The float value
* @default 0
*/
param?: number;
};
/**
* Float Range
* @description Creates a range
@ -1566,7 +1776,7 @@ export type components = {
};
/**
* FloatOutput
* @description A float output
* @description Base class for nodes that output a single float
*/
FloatOutput: {
/**
@ -1579,7 +1789,7 @@ export type components = {
* A
* @description The output float
*/
a?: number;
a: number;
};
/** Graph */
Graph: {
@ -1593,7 +1803,7 @@ export type components = {
* @description The nodes in this graph
*/
nodes?: {
[key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"]) | undefined;
[key: string]: (components["schemas"]["BooleanInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"]) | undefined;
};
/**
* Edges
@ -1636,7 +1846,7 @@ export type components = {
* @description The results of node executions
*/
results: {
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
[key: string]: (components["schemas"]["BooleanOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
};
/**
* Errors
@ -1870,7 +2080,7 @@ export type components = {
};
/**
* ImageCollectionOutput
* @description A collection of images
* @description Base class for nodes that output a collection of images
*/
ImageCollectionOutput: {
/**
@ -1881,8 +2091,7 @@ export type components = {
type?: "image_collection_output";
/**
* Collection
* @description The output images
* @default []
* @description The output images
*/
collection?: (components["schemas"]["ImageField"])[];
};
@ -2045,7 +2254,7 @@ export type components = {
};
/**
* ImageField
* @description An image field used for passing image objects between invocations
* @description An image primitive field
*/
ImageField: {
/**
@ -2128,6 +2337,34 @@ export type components = {
*/
max?: number;
};
/**
* Image
* @description An image primitive value
*/
ImageInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default image
* @enum {string}
*/
type: "image";
/**
* Image
* @description The image to load
*/
image?: components["schemas"]["ImageField"];
};
/**
* Lerp Image
* @description Linear interpolation of all pixels of an image
@ -2286,7 +2523,7 @@ export type components = {
};
/**
* ImageOutput
* @description Base class for invocations that output an image
* @description Base class for nodes that output a single image
*/
ImageOutput: {
/**
@ -2294,7 +2531,7 @@ export type components = {
* @default image_output
* @enum {string}
*/
type: "image_output";
type?: "image_output";
/**
* Image
* @description The output image
@ -2746,10 +2983,10 @@ export type components = {
seed?: number;
};
/**
* IntCollectionOutput
* @description A collection of integers
* IntegerCollectionOutput
* @description Base class for nodes that output a collection of integers
*/
IntCollectionOutput: {
IntegerCollectionOutput: {
/**
* Type
* @default int_collection_output
@ -2758,16 +2995,44 @@ export type components = {
type?: "int_collection_output";
/**
* Collection
* @description The int collection
* @default []
* @description The int collection
*/
collection?: (number)[];
};
/**
* IntOutput
* @description An integer output
* Integer Primitive
* @description An integer primitive value
*/
IntOutput: {
IntegerInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default int
* @enum {string}
*/
type: "int";
/**
* A
* @description The integer value
* @default 0
*/
a?: number;
};
/**
* IntegerOutput
* @description Base class for nodes that output a single integer
*/
IntegerOutput: {
/**
* Type
* @default int_output
@ -2778,7 +3043,7 @@ export type components = {
* A
* @description The output integer
*/
a?: number;
a: number;
};
/**
* IterateInvocation
@ -2831,9 +3096,26 @@ export type components = {
*/
item?: unknown;
};
/**
* LatentsCollectionOutput
* @description Base class for nodes that output a collection of latents tensors
*/
LatentsCollectionOutput: {
/**
* Type
* @default latents_collection_output
* @enum {string}
*/
type?: "latents_collection_output";
/**
* Latents
* @description Latents tensor
*/
latents?: (components["schemas"]["LatentsField"])[];
};
/**
* LatentsField
* @description A latents field used for passing latents between invocations
* @description A latents tensor primitive field
*/
LatentsField: {
/**
@ -2847,9 +3129,37 @@ export type components = {
*/
seed?: number;
};
/**
* Latents
* @description A latents tensor primitive value
*/
LatentsInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default latents
* @enum {string}
*/
type: "latents";
/**
* Latents
* @description The latents tensor
*/
latents?: components["schemas"]["LatentsField"];
};
/**
* LatentsOutput
* @description Base class for invocations that output latents
* @description Base class for nodes that output a single latents tensor
*/
LatentsOutput: {
/**
@ -3120,34 +3430,6 @@ export type components = {
* @enum {string}
*/
LoRAModelFormat: "lycoris" | "diffusers";
/**
* Load Image
* @description Load an image and provide it as output.
*/
LoadImageInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default load_image
* @enum {string}
*/
type: "load_image";
/**
* Image
* @description The image to load
*/
image?: components["schemas"]["ImageField"];
};
/**
* LogLevel
* @description An enumeration.
@ -3397,33 +3679,6 @@ export type components = {
*/
invert?: boolean;
};
/**
* MaskOutput
* @description Base class for invocations that output a mask
*/
MaskOutput: {
/**
* Type
* @default mask
* @enum {string}
*/
type: "mask";
/**
* Mask
* @description The output mask
*/
mask: components["schemas"]["ImageField"];
/**
* Width
* @description The width of the mask in pixels
*/
width?: number;
/**
* Height
* @description The height of the mask in pixels
*/
height?: number;
};
/**
* Mediapipe Face Processor
* @description Applies mediapipe face processing to image
@ -4342,122 +4597,6 @@ export type components = {
*/
total: number;
};
/**
* Float Parameter
* @description A float parameter
*/
ParamFloatInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default param_float
* @enum {string}
*/
type: "param_float";
/**
* Param
* @description The float value
* @default 0
*/
param?: number;
};
/**
* Integer Parameter
* @description An integer parameter
*/
ParamIntInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default param_int
* @enum {string}
*/
type: "param_int";
/**
* A
* @description The integer value
* @default 0
*/
a?: number;
};
/**
* Prompt Parameter
* @description A prompt input parameter
*/
ParamPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default param_prompt
* @enum {string}
*/
type: "param_prompt";
/**
* Prompt
* @description The prompt value
* @default
*/
prompt?: string;
};
/**
* String Parameter
* @description A string parameter
*/
ParamStringInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default param_string
* @enum {string}
*/
type: "param_string";
/**
* Text
* @description The string value
* @default
*/
text?: string;
};
/**
* PIDI Processor
* @description Applies PIDI processing to image
@ -4510,45 +4649,6 @@ export type components = {
*/
scribble?: boolean;
};
/**
* PromptCollectionOutput
* @description Base class for invocations that output a collection of prompts
*/
PromptCollectionOutput: {
/**
* Type
* @default prompt_collection_output
* @enum {string}
*/
type?: "prompt_collection_output";
/**
* Prompt Collection
* @description The output prompt collection
*/
prompt_collection: (string)[];
/**
* Count
* @description The size of the prompt collection
*/
count: number;
};
/**
* PromptOutput
* @description Base class for invocations that output a prompt
*/
PromptOutput: {
/**
* Type
* @default prompt
* @enum {string}
*/
type?: "prompt";
/**
* Prompt
* @description The output prompt
*/
prompt: string;
};
/**
* Prompts from File
* @description Loads prompts from a text file
@ -5499,7 +5599,7 @@ export type components = {
};
/**
* StringCollectionOutput
* @description A collection of strings
* @description Base class for nodes that output a collection of strings
*/
StringCollectionOutput: {
/**
@ -5510,14 +5610,42 @@ export type components = {
type?: "string_collection_output";
/**
* Collection
* @description The output strings
* @default []
* @description The output strings
*/
collection?: (string)[];
};
/**
* String Primitive
* @description A string primitive value
*/
StringInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default string
* @enum {string}
*/
type: "string";
/**
* Text
* @description The string value
* @default
*/
text?: string;
};
/**
* StringOutput
* @description A string output
* @description Base class for nodes that output a single string
*/
StringOutput: {
/**
@ -5815,7 +5943,7 @@ export type components = {
* If a field should be provided a data type that does not exactly match the python type of the field, use this to provide the type that should be used instead. See the node development docs for detail on adding a new field type, which involves client-side changes.
* @enum {string}
*/
UITypeHint: "integer" | "float" | "boolean" | "string" | "enum" | "array" | "ImageField" | "LatentsField" | "ConditioningField" | "ControlField" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "UNetField" | "VaeField" | "ClipField" | "ColorField" | "ImageCollection" | "IntegerCollection" | "FloatCollection" | "StringCollection" | "BooleanCollection" | "Collection" | "CollectionItem" | "Seed" | "FilePath";
UITypeHint: "integer" | "float" | "boolean" | "string" | "array" | "ImageField" | "LatentsField" | "ConditioningField" | "ControlField" | "ColorField" | "ImageCollection" | "ConditioningCollection" | "ColorCollection" | "LatentsCollection" | "IntegerCollection" | "FloatCollection" | "StringCollection" | "BooleanCollection" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "UNetField" | "VaeField" | "ClipField" | "Collection" | "CollectionItem" | "FilePath" | "enum";
/**
* UIComponent
* @description The type of UI component to use for a field, used to override the default components, which are inferred from the field type.
@ -5988,7 +6116,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"];
"application/json": components["schemas"]["BooleanInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"];
};
};
responses: {
@ -6025,7 +6153,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"];
"application/json": components["schemas"]["BooleanInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"];
};
};
responses: {

View File

@ -155,8 +155,6 @@ export type ZoeDepthImageProcessorInvocation =
// Node Outputs
export type ImageOutput = s['ImageOutput'];
export type MaskOutput = s['MaskOutput'];
export type PromptOutput = s['PromptOutput'];
export type IterateInvocationOutput = s['IterateInvocationOutput'];
export type CollectInvocationOutput = s['CollectInvocationOutput'];
export type LatentsOutput = s['LatentsOutput'];

View File

@ -19,7 +19,7 @@ from invokeai.app.services.graph import (
from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.invocations.image import *
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.params import ParamIntInvocation
from invokeai.app.invocations.primitives import IntegerInvocation
from invokeai.app.services.default_graphs import create_text_to_image
import pytest
@ -499,8 +499,8 @@ def test_graph_subgraph_t2i():
g.add_node(n1)
n2 = ParamIntInvocation(id="2", a=512)
n3 = ParamIntInvocation(id="3", a=256)
n2 = IntegerInvocation(id="2", a=512)
n3 = IntegerInvocation(id="3", a=256)
g.add_node(n2)
g.add_node(n3)