mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: add missing primitive collections
- add missing primitive collections - remove `Seed` and `LoRAField` (they don't exist)
This commit is contained in:
@ -70,19 +70,3 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
rng = np.random.default_rng(self.seed)
|
||||
return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
||||
|
||||
|
||||
@title("Image Collection")
|
||||
@tags("image", "collection")
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""Load a collection of images and provide it as output."""
|
||||
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Inputs
|
||||
images: list[ImageField] = InputField(
|
||||
default=[], description="The image collection to load", ui_type=UIType.ImageCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.images)
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional, Tuple
|
||||
from typing import Literal, Optional, Tuple, Union
|
||||
from anyio import Condition
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
@ -47,8 +48,8 @@ class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@title("Boolean Primitive")
|
||||
@tags("boolean")
|
||||
@title("Boolean")
|
||||
@tags("primitives", "boolean")
|
||||
class BooleanInvocation(BaseInvocation):
|
||||
"""A boolean primitive value"""
|
||||
|
||||
@ -61,6 +62,22 @@ class BooleanInvocation(BaseInvocation):
|
||||
return BooleanOutput(a=self.a)
|
||||
|
||||
|
||||
@title("Boolean Collection")
|
||||
@tags("primitives", "boolean", "collection")
|
||||
class BooleanCollectionInvocation(BaseInvocation):
|
||||
"""A collection of boolean primitive values"""
|
||||
|
||||
type: Literal["boolean_collection"] = "boolean_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[bool] = InputField(
|
||||
default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||
return BooleanCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Integer
|
||||
@ -84,8 +101,8 @@ class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@title("Integer Primitive")
|
||||
@tags("integer")
|
||||
@title("Integer")
|
||||
@tags("primitives", "integer")
|
||||
class IntegerInvocation(BaseInvocation):
|
||||
"""An integer primitive value"""
|
||||
|
||||
@ -98,6 +115,22 @@ class IntegerInvocation(BaseInvocation):
|
||||
return IntegerOutput(a=self.a)
|
||||
|
||||
|
||||
@title("Integer Collection")
|
||||
@tags("primitives", "integer", "collection")
|
||||
class IntegerCollectionInvocation(BaseInvocation):
|
||||
"""A collection of integer primitive values"""
|
||||
|
||||
type: Literal["integer_collection"] = "integer_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[int] = InputField(
|
||||
default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Float
|
||||
@ -121,8 +154,8 @@ class FloatCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@title("Float Primitive")
|
||||
@tags("float")
|
||||
@title("Float")
|
||||
@tags("primitives", "float")
|
||||
class FloatInvocation(BaseInvocation):
|
||||
"""A float primitive value"""
|
||||
|
||||
@ -135,6 +168,22 @@ class FloatInvocation(BaseInvocation):
|
||||
return FloatOutput(a=self.param)
|
||||
|
||||
|
||||
@title("Float Collection")
|
||||
@tags("primitives", "float", "collection")
|
||||
class FloatCollectionInvocation(BaseInvocation):
|
||||
"""A collection of float primitive values"""
|
||||
|
||||
type: Literal["float_collection"] = "float_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[float] = InputField(
|
||||
default=0, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
return FloatCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region String
|
||||
@ -158,8 +207,8 @@ class StringCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@title("String Primitive")
|
||||
@tags("string")
|
||||
@title("String")
|
||||
@tags("primitives", "string")
|
||||
class StringInvocation(BaseInvocation):
|
||||
"""A string primitive value"""
|
||||
|
||||
@ -172,6 +221,22 @@ class StringInvocation(BaseInvocation):
|
||||
return StringOutput(text=self.text)
|
||||
|
||||
|
||||
@title("String Collection")
|
||||
@tags("primitives", "string", "collection")
|
||||
class StringCollectionInvocation(BaseInvocation):
|
||||
"""A collection of string primitive values"""
|
||||
|
||||
type: Literal["string_collection"] = "string_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[str] = InputField(
|
||||
default=0, description="The collection of string values", ui_type=UIType.StringCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
return StringCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Image
|
||||
@ -204,7 +269,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@title("Image Primitive")
|
||||
@tags("image")
|
||||
@tags("primitives", "image")
|
||||
class ImageInvocation(BaseInvocation):
|
||||
"""An image primitive value"""
|
||||
|
||||
@ -224,6 +289,22 @@ class ImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("Image Collection")
|
||||
@tags("primitives", "image", "collection")
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""A collection of image primitive values"""
|
||||
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[ImageField] = InputField(
|
||||
default=0, description="The collection of image values", ui_type=UIType.ImageCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Latents
|
||||
@ -253,7 +334,7 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
type: Literal["latents_collection_output"] = "latents_collection_output"
|
||||
|
||||
latents: list[LatentsField] = OutputField(
|
||||
collection: list[LatentsField] = OutputField(
|
||||
default_factory=list,
|
||||
description=FieldDescriptions.latents,
|
||||
ui_type=UIType.LatentsCollection,
|
||||
@ -261,7 +342,7 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@title("Latents Primitive")
|
||||
@tags("latents")
|
||||
@tags("primitives", "latents")
|
||||
class LatentsInvocation(BaseInvocation):
|
||||
"""A latents tensor primitive value"""
|
||||
|
||||
@ -276,6 +357,22 @@ class LatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(self.latents.latents_name, latents)
|
||||
|
||||
|
||||
@title("Latents Collection")
|
||||
@tags("primitives", "latents", "collection")
|
||||
class LatentsCollectionInvocation(BaseInvocation):
|
||||
"""A collection of latents tensor primitive values"""
|
||||
|
||||
type: Literal["latents_collection"] = "latents_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[LatentsField] = InputField(
|
||||
default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||
return LatentsCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
@ -320,7 +417,7 @@ class ColorCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@title("Color Primitive")
|
||||
@tags("color")
|
||||
@tags("primitives", "color")
|
||||
class ColorInvocation(BaseInvocation):
|
||||
"""A color primitive value"""
|
||||
|
||||
@ -339,7 +436,7 @@ class ColorInvocation(BaseInvocation):
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive field"""
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
@ -366,7 +463,7 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@title("Conditioning Primitive")
|
||||
@tags("conditioning")
|
||||
@tags("primitives", "conditioning")
|
||||
class ConditioningInvocation(BaseInvocation):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
@ -378,4 +475,20 @@ class ConditioningInvocation(BaseInvocation):
|
||||
return ConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
|
||||
@title("Conditioning Collection")
|
||||
@tags("primitives", "conditioning", "collection")
|
||||
class ConditioningCollectionInvocation(BaseInvocation):
|
||||
"""A collection of conditioning tensor primitive values"""
|
||||
|
||||
type: Literal["conditioning_collection"] = "conditioning_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[ConditioningField] = InputField(
|
||||
default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
||||
return ConditioningCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
Reference in New Issue
Block a user