mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
480 lines
14 KiB
Python
480 lines
14 KiB
Python
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
|
from invokeai.app.invocations.fields import (
|
|
ColorField,
|
|
ConditioningField,
|
|
DenoiseMaskField,
|
|
FieldDescriptions,
|
|
ImageField,
|
|
Input,
|
|
InputField,
|
|
LatentsField,
|
|
MaskField,
|
|
OutputField,
|
|
UIComponent,
|
|
)
|
|
from invokeai.app.services.images.images_common import ImageDTO
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
|
|
from .baseinvocation import (
|
|
BaseInvocation,
|
|
BaseInvocationOutput,
|
|
invocation,
|
|
invocation_output,
|
|
)
|
|
|
|
"""
|
|
Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
|
|
- primitive nodes
|
|
- primitive outputs
|
|
- primitive collection outputs
|
|
"""
|
|
|
|
# region Boolean
|
|
|
|
|
|
@invocation_output("boolean_output")
|
|
class BooleanOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single boolean"""
|
|
|
|
value: bool = OutputField(description="The output boolean")
|
|
|
|
|
|
@invocation_output("boolean_collection_output")
|
|
class BooleanCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of booleans"""
|
|
|
|
collection: list[bool] = OutputField(
|
|
description="The output boolean collection",
|
|
)
|
|
|
|
|
|
@invocation(
|
|
"boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.0"
|
|
)
|
|
class BooleanInvocation(BaseInvocation):
|
|
"""A boolean primitive value"""
|
|
|
|
value: bool = InputField(default=False, description="The boolean value")
|
|
|
|
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
|
return BooleanOutput(value=self.value)
|
|
|
|
|
|
@invocation(
|
|
"boolean_collection",
|
|
title="Boolean Collection Primitive",
|
|
tags=["primitives", "boolean", "collection"],
|
|
category="primitives",
|
|
version="1.0.1",
|
|
)
|
|
class BooleanCollectionInvocation(BaseInvocation):
|
|
"""A collection of boolean primitive values"""
|
|
|
|
collection: list[bool] = InputField(default=[], description="The collection of boolean values")
|
|
|
|
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
|
return BooleanCollectionOutput(collection=self.collection)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Integer
|
|
|
|
|
|
@invocation_output("integer_output")
|
|
class IntegerOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single integer"""
|
|
|
|
value: int = OutputField(description="The output integer")
|
|
|
|
|
|
@invocation_output("integer_collection_output")
|
|
class IntegerCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of integers"""
|
|
|
|
collection: list[int] = OutputField(
|
|
description="The int collection",
|
|
)
|
|
|
|
|
|
@invocation(
|
|
"integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.0"
|
|
)
|
|
class IntegerInvocation(BaseInvocation):
|
|
"""An integer primitive value"""
|
|
|
|
value: int = InputField(default=0, description="The integer value")
|
|
|
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
|
return IntegerOutput(value=self.value)
|
|
|
|
|
|
@invocation(
|
|
"integer_collection",
|
|
title="Integer Collection Primitive",
|
|
tags=["primitives", "integer", "collection"],
|
|
category="primitives",
|
|
version="1.0.1",
|
|
)
|
|
class IntegerCollectionInvocation(BaseInvocation):
|
|
"""A collection of integer primitive values"""
|
|
|
|
collection: list[int] = InputField(default=[], description="The collection of integer values")
|
|
|
|
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
|
return IntegerCollectionOutput(collection=self.collection)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Float
|
|
|
|
|
|
@invocation_output("float_output")
|
|
class FloatOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single float"""
|
|
|
|
value: float = OutputField(description="The output float")
|
|
|
|
|
|
@invocation_output("float_collection_output")
|
|
class FloatCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of floats"""
|
|
|
|
collection: list[float] = OutputField(
|
|
description="The float collection",
|
|
)
|
|
|
|
|
|
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.0")
|
|
class FloatInvocation(BaseInvocation):
|
|
"""A float primitive value"""
|
|
|
|
value: float = InputField(default=0.0, description="The float value")
|
|
|
|
def invoke(self, context: InvocationContext) -> FloatOutput:
|
|
return FloatOutput(value=self.value)
|
|
|
|
|
|
@invocation(
|
|
"float_collection",
|
|
title="Float Collection Primitive",
|
|
tags=["primitives", "float", "collection"],
|
|
category="primitives",
|
|
version="1.0.1",
|
|
)
|
|
class FloatCollectionInvocation(BaseInvocation):
|
|
"""A collection of float primitive values"""
|
|
|
|
collection: list[float] = InputField(default=[], description="The collection of float values")
|
|
|
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
|
return FloatCollectionOutput(collection=self.collection)
|
|
|
|
|
|
# endregion
|
|
|
|
# region String
|
|
|
|
|
|
@invocation_output("string_output")
|
|
class StringOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single string"""
|
|
|
|
value: str = OutputField(description="The output string")
|
|
|
|
|
|
@invocation_output("string_collection_output")
|
|
class StringCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of strings"""
|
|
|
|
collection: list[str] = OutputField(
|
|
description="The output strings",
|
|
)
|
|
|
|
|
|
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.0")
|
|
class StringInvocation(BaseInvocation):
|
|
"""A string primitive value"""
|
|
|
|
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
|
|
|
def invoke(self, context: InvocationContext) -> StringOutput:
|
|
return StringOutput(value=self.value)
|
|
|
|
|
|
@invocation(
|
|
"string_collection",
|
|
title="String Collection Primitive",
|
|
tags=["primitives", "string", "collection"],
|
|
category="primitives",
|
|
version="1.0.1",
|
|
)
|
|
class StringCollectionInvocation(BaseInvocation):
|
|
"""A collection of string primitive values"""
|
|
|
|
collection: list[str] = InputField(default=[], description="The collection of string values")
|
|
|
|
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
|
return StringCollectionOutput(collection=self.collection)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Image
|
|
|
|
|
|
@invocation_output("mask_output")
|
|
class MaskOutput(BaseInvocationOutput):
|
|
"""A torch mask tensor.
|
|
dtype: torch.bool
|
|
shape: (1, height, width).
|
|
"""
|
|
|
|
mask: MaskField = OutputField(description="The mask.")
|
|
width: int = OutputField(description="The width of the mask in pixels.")
|
|
height: int = OutputField(description="The height of the mask in pixels.")
|
|
|
|
|
|
@invocation_output("image_output")
|
|
class ImageOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single image"""
|
|
|
|
image: ImageField = OutputField(description="The output image")
|
|
width: int = OutputField(description="The width of the image in pixels")
|
|
height: int = OutputField(description="The height of the image in pixels")
|
|
|
|
@classmethod
|
|
def build(cls, image_dto: ImageDTO) -> "ImageOutput":
|
|
return cls(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
@invocation_output("image_collection_output")
|
|
class ImageCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of images"""
|
|
|
|
collection: list[ImageField] = OutputField(
|
|
description="The output images",
|
|
)
|
|
|
|
|
|
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1")
|
|
class ImageInvocation(BaseInvocation):
|
|
"""An image primitive value"""
|
|
|
|
image: ImageField = InputField(description="The image to load")
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.images.get_pil(self.image.image_name)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=self.image.image_name),
|
|
width=image.width,
|
|
height=image.height,
|
|
)
|
|
|
|
|
|
@invocation(
|
|
"image_collection",
|
|
title="Image Collection Primitive",
|
|
tags=["primitives", "image", "collection"],
|
|
category="primitives",
|
|
version="1.0.0",
|
|
)
|
|
class ImageCollectionInvocation(BaseInvocation):
|
|
"""A collection of image primitive values"""
|
|
|
|
collection: list[ImageField] = InputField(description="The collection of image values")
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
|
return ImageCollectionOutput(collection=self.collection)
|
|
|
|
|
|
# endregion
|
|
|
|
# region DenoiseMask
|
|
|
|
|
|
@invocation_output("denoise_mask_output")
|
|
class DenoiseMaskOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single image"""
|
|
|
|
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
|
|
|
@classmethod
|
|
def build(
|
|
cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: bool = False
|
|
) -> "DenoiseMaskOutput":
|
|
return cls(
|
|
denoise_mask=DenoiseMaskField(
|
|
mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=gradient
|
|
),
|
|
)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Latents
|
|
|
|
|
|
@invocation_output("latents_output")
|
|
class LatentsOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single latents tensor"""
|
|
|
|
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
|
|
width: int = OutputField(description=FieldDescriptions.width)
|
|
height: int = OutputField(description=FieldDescriptions.height)
|
|
|
|
@classmethod
|
|
def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput":
|
|
return cls(
|
|
latents=LatentsField(latents_name=latents_name, seed=seed),
|
|
width=latents.size()[3] * LATENT_SCALE_FACTOR,
|
|
height=latents.size()[2] * LATENT_SCALE_FACTOR,
|
|
)
|
|
|
|
|
|
@invocation_output("latents_collection_output")
|
|
class LatentsCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of latents tensors"""
|
|
|
|
collection: list[LatentsField] = OutputField(
|
|
description=FieldDescriptions.latents,
|
|
)
|
|
|
|
|
|
@invocation(
|
|
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1"
|
|
)
|
|
class LatentsInvocation(BaseInvocation):
|
|
"""A latents tensor primitive value"""
|
|
|
|
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
|
|
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
|
latents = context.tensors.load(self.latents.latents_name)
|
|
|
|
return LatentsOutput.build(self.latents.latents_name, latents)
|
|
|
|
|
|
@invocation(
|
|
"latents_collection",
|
|
title="Latents Collection Primitive",
|
|
tags=["primitives", "latents", "collection"],
|
|
category="primitives",
|
|
version="1.0.0",
|
|
)
|
|
class LatentsCollectionInvocation(BaseInvocation):
|
|
"""A collection of latents tensor primitive values"""
|
|
|
|
collection: list[LatentsField] = InputField(
|
|
description="The collection of latents tensors",
|
|
)
|
|
|
|
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
|
return LatentsCollectionOutput(collection=self.collection)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Color
|
|
|
|
|
|
@invocation_output("color_output")
|
|
class ColorOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single color"""
|
|
|
|
color: ColorField = OutputField(description="The output color")
|
|
|
|
|
|
@invocation_output("color_collection_output")
|
|
class ColorCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of colors"""
|
|
|
|
collection: list[ColorField] = OutputField(
|
|
description="The output colors",
|
|
)
|
|
|
|
|
|
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.0")
|
|
class ColorInvocation(BaseInvocation):
|
|
"""A color primitive value"""
|
|
|
|
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
|
|
|
|
def invoke(self, context: InvocationContext) -> ColorOutput:
|
|
return ColorOutput(color=self.color)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Conditioning
|
|
|
|
|
|
@invocation_output("conditioning_output")
|
|
class ConditioningOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a single conditioning tensor"""
|
|
|
|
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
|
|
|
@classmethod
|
|
def build(cls, conditioning_name: str) -> "ConditioningOutput":
|
|
return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
|
|
|
|
|
|
@invocation_output("conditioning_collection_output")
|
|
class ConditioningCollectionOutput(BaseInvocationOutput):
|
|
"""Base class for nodes that output a collection of conditioning tensors"""
|
|
|
|
collection: list[ConditioningField] = OutputField(
|
|
description="The output conditioning tensors",
|
|
)
|
|
|
|
|
|
@invocation(
|
|
"conditioning",
|
|
title="Conditioning Primitive",
|
|
tags=["primitives", "conditioning"],
|
|
category="primitives",
|
|
version="1.0.0",
|
|
)
|
|
class ConditioningInvocation(BaseInvocation):
|
|
"""A conditioning tensor primitive value"""
|
|
|
|
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
|
|
|
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
|
return ConditioningOutput(conditioning=self.conditioning)
|
|
|
|
|
|
@invocation(
|
|
"conditioning_collection",
|
|
title="Conditioning Collection Primitive",
|
|
tags=["primitives", "conditioning", "collection"],
|
|
category="primitives",
|
|
version="1.0.1",
|
|
)
|
|
class ConditioningCollectionInvocation(BaseInvocation):
|
|
"""A collection of conditioning tensor primitive values"""
|
|
|
|
collection: list[ConditioningField] = InputField(
|
|
default=[],
|
|
description="The collection of conditioning tensors",
|
|
)
|
|
|
|
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
|
return ConditioningCollectionOutput(collection=self.collection)
|
|
|
|
|
|
# endregion
|