feat: add missing primitive collections

- add missing primitive collections
- remove `Seed` and `LoRAField` (they don't exist)
This commit is contained in:
psychedelicious
2023-08-15 22:18:37 +10:00
parent fa884134d9
commit 2b7dd3e236
8 changed files with 377 additions and 137 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, Tuple
from typing import Literal, Optional, Tuple, Union
from anyio import Condition
from pydantic import BaseModel, Field
import torch
@ -47,8 +48,8 @@ class BooleanCollectionOutput(BaseInvocationOutput):
)
@title("Boolean Primitive")
@tags("boolean")
@title("Boolean")
@tags("primitives", "boolean")
class BooleanInvocation(BaseInvocation):
"""A boolean primitive value"""
@ -61,6 +62,22 @@ class BooleanInvocation(BaseInvocation):
return BooleanOutput(a=self.a)
@title("Boolean Collection")
@tags("primitives", "boolean", "collection")
class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values"""
type: Literal["boolean_collection"] = "boolean_collection"
# Inputs
collection: list[bool] = InputField(
default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection
)
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
return BooleanCollectionOutput(collection=self.collection)
# endregion
# region Integer
@ -84,8 +101,8 @@ class IntegerCollectionOutput(BaseInvocationOutput):
)
@title("Integer Primitive")
@tags("integer")
@title("Integer")
@tags("primitives", "integer")
class IntegerInvocation(BaseInvocation):
"""An integer primitive value"""
@ -98,6 +115,22 @@ class IntegerInvocation(BaseInvocation):
return IntegerOutput(a=self.a)
@title("Integer Collection")
@tags("primitives", "integer", "collection")
class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values"""
type: Literal["integer_collection"] = "integer_collection"
# Inputs
collection: list[int] = InputField(
default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
)
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=self.collection)
# endregion
# region Float
@ -121,8 +154,8 @@ class FloatCollectionOutput(BaseInvocationOutput):
)
@title("Float Primitive")
@tags("float")
@title("Float")
@tags("primitives", "float")
class FloatInvocation(BaseInvocation):
"""A float primitive value"""
@ -135,6 +168,22 @@ class FloatInvocation(BaseInvocation):
return FloatOutput(a=self.param)
@title("Float Collection")
@tags("primitives", "float", "collection")
class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values"""
type: Literal["float_collection"] = "float_collection"
# Inputs
collection: list[float] = InputField(
default=0, description="The collection of float values", ui_type=UIType.FloatCollection
)
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
return FloatCollectionOutput(collection=self.collection)
# endregion
# region String
@ -158,8 +207,8 @@ class StringCollectionOutput(BaseInvocationOutput):
)
@title("String Primitive")
@tags("string")
@title("String")
@tags("primitives", "string")
class StringInvocation(BaseInvocation):
"""A string primitive value"""
@ -172,6 +221,22 @@ class StringInvocation(BaseInvocation):
return StringOutput(text=self.text)
@title("String Collection")
@tags("primitives", "string", "collection")
class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values"""
type: Literal["string_collection"] = "string_collection"
# Inputs
collection: list[str] = InputField(
default=0, description="The collection of string values", ui_type=UIType.StringCollection
)
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
return StringCollectionOutput(collection=self.collection)
# endregion
# region Image
@ -204,7 +269,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
@title("Image Primitive")
@tags("image")
@tags("primitives", "image")
class ImageInvocation(BaseInvocation):
"""An image primitive value"""
@ -224,6 +289,22 @@ class ImageInvocation(BaseInvocation):
)
@title("Image Collection")
@tags("primitives", "image", "collection")
class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values"""
type: Literal["image_collection"] = "image_collection"
# Inputs
collection: list[ImageField] = InputField(
default=0, description="The collection of image values", ui_type=UIType.ImageCollection
)
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.collection)
# endregion
# region Latents
@ -253,7 +334,7 @@ class LatentsCollectionOutput(BaseInvocationOutput):
type: Literal["latents_collection_output"] = "latents_collection_output"
latents: list[LatentsField] = OutputField(
collection: list[LatentsField] = OutputField(
default_factory=list,
description=FieldDescriptions.latents,
ui_type=UIType.LatentsCollection,
@ -261,7 +342,7 @@ class LatentsCollectionOutput(BaseInvocationOutput):
@title("Latents Primitive")
@tags("latents")
@tags("primitives", "latents")
class LatentsInvocation(BaseInvocation):
"""A latents tensor primitive value"""
@ -276,6 +357,22 @@ class LatentsInvocation(BaseInvocation):
return build_latents_output(self.latents.latents_name, latents)
@title("Latents Collection")
@tags("primitives", "latents", "collection")
class LatentsCollectionInvocation(BaseInvocation):
"""A collection of latents tensor primitive values"""
type: Literal["latents_collection"] = "latents_collection"
# Inputs
collection: list[LatentsField] = InputField(
default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection
)
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
return LatentsCollectionOutput(collection=self.collection)
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
return LatentsOutput(
latents=LatentsField(latents_name=latents_name, seed=seed),
@ -320,7 +417,7 @@ class ColorCollectionOutput(BaseInvocationOutput):
@title("Color Primitive")
@tags("color")
@tags("primitives", "color")
class ColorInvocation(BaseInvocation):
"""A color primitive value"""
@ -339,7 +436,7 @@ class ColorInvocation(BaseInvocation):
class ConditioningField(BaseModel):
"""A conditioning tensor primitive field"""
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
@ -366,7 +463,7 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
@title("Conditioning Primitive")
@tags("conditioning")
@tags("primitives", "conditioning")
class ConditioningInvocation(BaseInvocation):
"""A conditioning tensor primitive value"""
@ -378,4 +475,20 @@ class ConditioningInvocation(BaseInvocation):
return ConditioningOutput(conditioning=self.conditioning)
@title("Conditioning Collection")
@tags("primitives", "conditioning", "collection")
class ConditioningCollectionInvocation(BaseInvocation):
"""A collection of conditioning tensor primitive values"""
type: Literal["conditioning_collection"] = "conditioning_collection"
# Inputs
collection: list[ConditioningField] = InputField(
default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
)
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
return ConditioningCollectionOutput(collection=self.collection)
# endregion