feat(nodes): move all invocation metadata (type, title, tags, category) to decorator

All invocation metadata (type, title, tags and category) are now defined in decorators.

The decorators add the `type: Literal["invocation_type"]: "invocation_type"` field to the invocation.

Category is a new invocation metadata, but it is not used by the frontend just yet.

- `@invocation()` decorator for invocations

```py
@invocation(
    "sdxl_compel_prompt",
    title="SDXL Prompt",
    tags=["sdxl", "compel", "prompt"],
    category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
    ...
```

- `@invocation_output()` decorator for invocation outputs

```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
    ...
```

- update invocation docs
- add category to decorator
- regen frontend types
This commit is contained in:
psychedelicious
2023-08-30 18:35:12 +10:00
parent ae05d34584
commit 044d4c107a
23 changed files with 1523 additions and 2178 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, Tuple
from typing import Optional, Tuple
import torch
from pydantic import BaseModel, Field
@ -15,8 +15,8 @@ from .baseinvocation import (
OutputField,
UIComponent,
UIType,
tags,
title,
invocation,
invocation_output,
)
"""
@ -29,44 +29,39 @@ Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
# region Boolean
@invocation_output("boolean_output")
class BooleanOutput(BaseInvocationOutput):
"""Base class for nodes that output a single boolean"""
type: Literal["boolean_output"] = "boolean_output"
value: bool = OutputField(description="The output boolean")
@invocation_output("boolean_collection_output")
class BooleanCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of booleans"""
type: Literal["boolean_collection_output"] = "boolean_collection_output"
# Outputs
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
@title("Boolean Primitive")
@tags("primitives", "boolean")
@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
class BooleanInvocation(BaseInvocation):
"""A boolean primitive value"""
type: Literal["boolean"] = "boolean"
# Inputs
value: bool = InputField(default=False, description="The boolean value")
def invoke(self, context: InvocationContext) -> BooleanOutput:
return BooleanOutput(value=self.value)
@title("Boolean Primitive Collection")
@tags("primitives", "boolean", "collection")
@invocation(
"boolean_collection",
title="Boolean Collection Primitive",
tags=["primitives", "boolean", "collection"],
category="primitives",
)
class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values"""
type: Literal["boolean_collection"] = "boolean_collection"
# Inputs
collection: list[bool] = InputField(
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
)
@ -80,44 +75,39 @@ class BooleanCollectionInvocation(BaseInvocation):
# region Integer
@invocation_output("integer_output")
class IntegerOutput(BaseInvocationOutput):
"""Base class for nodes that output a single integer"""
type: Literal["integer_output"] = "integer_output"
value: int = OutputField(description="The output integer")
@invocation_output("integer_collection_output")
class IntegerCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of integers"""
type: Literal["integer_collection_output"] = "integer_collection_output"
# Outputs
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
@title("Integer Primitive")
@tags("primitives", "integer")
@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
class IntegerInvocation(BaseInvocation):
"""An integer primitive value"""
type: Literal["integer"] = "integer"
# Inputs
value: int = InputField(default=0, description="The integer value")
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(value=self.value)
@title("Integer Primitive Collection")
@tags("primitives", "integer", "collection")
@invocation(
"integer_collection",
title="Integer Collection Primitive",
tags=["primitives", "integer", "collection"],
category="primitives",
)
class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values"""
type: Literal["integer_collection"] = "integer_collection"
# Inputs
collection: list[int] = InputField(
default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
)
@ -131,44 +121,39 @@ class IntegerCollectionInvocation(BaseInvocation):
# region Float
@invocation_output("float_output")
class FloatOutput(BaseInvocationOutput):
"""Base class for nodes that output a single float"""
type: Literal["float_output"] = "float_output"
value: float = OutputField(description="The output float")
@invocation_output("float_collection_output")
class FloatCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of floats"""
type: Literal["float_collection_output"] = "float_collection_output"
# Outputs
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
@title("Float Primitive")
@tags("primitives", "float")
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
class FloatInvocation(BaseInvocation):
"""A float primitive value"""
type: Literal["float"] = "float"
# Inputs
value: float = InputField(default=0.0, description="The float value")
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(value=self.value)
@title("Float Primitive Collection")
@tags("primitives", "float", "collection")
@invocation(
"float_collection",
title="Float Collection Primitive",
tags=["primitives", "float", "collection"],
category="primitives",
)
class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values"""
type: Literal["float_collection"] = "float_collection"
# Inputs
collection: list[float] = InputField(
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
)
@ -182,44 +167,39 @@ class FloatCollectionInvocation(BaseInvocation):
# region String
@invocation_output("string_output")
class StringOutput(BaseInvocationOutput):
"""Base class for nodes that output a single string"""
type: Literal["string_output"] = "string_output"
value: str = OutputField(description="The output string")
@invocation_output("string_collection_output")
class StringCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of strings"""
type: Literal["string_collection_output"] = "string_collection_output"
# Outputs
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
@title("String Primitive")
@tags("primitives", "string")
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
class StringInvocation(BaseInvocation):
"""A string primitive value"""
type: Literal["string"] = "string"
# Inputs
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(value=self.value)
@title("String Primitive Collection")
@tags("primitives", "string", "collection")
@invocation(
"string_collection",
title="String Collection Primitive",
tags=["primitives", "string", "collection"],
category="primitives",
)
class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values"""
type: Literal["string_collection"] = "string_collection"
# Inputs
collection: list[str] = InputField(
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
)
@ -239,33 +219,26 @@ class ImageField(BaseModel):
image_name: str = Field(description="The name of the image")
@invocation_output("image_output")
class ImageOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
type: Literal["image_output"] = "image_output"
image: ImageField = OutputField(description="The output image")
width: int = OutputField(description="The width of the image in pixels")
height: int = OutputField(description="The height of the image in pixels")
@invocation_output("image_collection_output")
class ImageCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of images"""
type: Literal["image_collection_output"] = "image_collection_output"
# Outputs
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
@title("Image Primitive")
@tags("primitives", "image")
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
class ImageInvocation(BaseInvocation):
"""An image primitive value"""
# Metadata
type: Literal["image"] = "image"
# Inputs
image: ImageField = InputField(description="The image to load")
def invoke(self, context: InvocationContext) -> ImageOutput:
@ -278,14 +251,15 @@ class ImageInvocation(BaseInvocation):
)
@title("Image Primitive Collection")
@tags("primitives", "image", "collection")
@invocation(
"image_collection",
title="Image Collection Primitive",
tags=["primitives", "image", "collection"],
category="primitives",
)
class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values"""
type: Literal["image_collection"] = "image_collection"
# Inputs
collection: list[ImageField] = InputField(
default=0, description="The collection of image values", ui_type=UIType.ImageCollection
)
@ -306,10 +280,10 @@ class DenoiseMaskField(BaseModel):
masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
@invocation_output("denoise_mask_output")
class DenoiseMaskOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
type: Literal["denoise_mask_output"] = "denoise_mask_output"
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
@ -325,11 +299,10 @@ class LatentsField(BaseModel):
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
@invocation_output("latents_output")
class LatentsOutput(BaseInvocationOutput):
"""Base class for nodes that output a single latents tensor"""
type: Literal["latents_output"] = "latents_output"
latents: LatentsField = OutputField(
description=FieldDescriptions.latents,
)
@ -337,25 +310,20 @@ class LatentsOutput(BaseInvocationOutput):
height: int = OutputField(description=FieldDescriptions.height)
@invocation_output("latents_collection_output")
class LatentsCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of latents tensors"""
type: Literal["latents_collection_output"] = "latents_collection_output"
collection: list[LatentsField] = OutputField(
description=FieldDescriptions.latents,
ui_type=UIType.LatentsCollection,
)
@title("Latents Primitive")
@tags("primitives", "latents")
@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives")
class LatentsInvocation(BaseInvocation):
"""A latents tensor primitive value"""
type: Literal["latents"] = "latents"
# Inputs
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -364,14 +332,15 @@ class LatentsInvocation(BaseInvocation):
return build_latents_output(self.latents.latents_name, latents)
@title("Latents Primitive Collection")
@tags("primitives", "latents", "collection")
@invocation(
"latents_collection",
title="Latents Collection Primitive",
tags=["primitives", "latents", "collection"],
category="primitives",
)
class LatentsCollectionInvocation(BaseInvocation):
"""A collection of latents tensor primitive values"""
type: Literal["latents_collection"] = "latents_collection"
# Inputs
collection: list[LatentsField] = InputField(
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
)
@ -405,30 +374,24 @@ class ColorField(BaseModel):
return (self.r, self.g, self.b, self.a)
@invocation_output("color_output")
class ColorOutput(BaseInvocationOutput):
"""Base class for nodes that output a single color"""
type: Literal["color_output"] = "color_output"
color: ColorField = OutputField(description="The output color")
@invocation_output("color_collection_output")
class ColorCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of colors"""
type: Literal["color_collection_output"] = "color_collection_output"
# Outputs
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
@title("Color Primitive")
@tags("primitives", "color")
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
class ColorInvocation(BaseInvocation):
"""A color primitive value"""
type: Literal["color"] = "color"
# Inputs
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
def invoke(self, context: InvocationContext) -> ColorOutput:
@ -446,47 +409,47 @@ class ConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning tensor")
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""
type: Literal["conditioning_output"] = "conditioning_output"
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
@invocation_output("conditioning_collection_output")
class ConditioningCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of conditioning tensors"""
type: Literal["conditioning_collection_output"] = "conditioning_collection_output"
# Outputs
collection: list[ConditioningField] = OutputField(
description="The output conditioning tensors",
ui_type=UIType.ConditioningCollection,
)
@title("Conditioning Primitive")
@tags("primitives", "conditioning")
@invocation(
"conditioning",
title="Conditioning Primitive",
tags=["primitives", "conditioning"],
category="primitives",
)
class ConditioningInvocation(BaseInvocation):
"""A conditioning tensor primitive value"""
type: Literal["conditioning"] = "conditioning"
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
def invoke(self, context: InvocationContext) -> ConditioningOutput:
return ConditioningOutput(conditioning=self.conditioning)
@title("Conditioning Primitive Collection")
@tags("primitives", "conditioning", "collection")
@invocation(
"conditioning_collection",
title="Conditioning Collection Primitive",
tags=["primitives", "conditioning", "collection"],
category="primitives",
)
class ConditioningCollectionInvocation(BaseInvocation):
"""A collection of conditioning tensor primitive values"""
type: Literal["conditioning_collection"] = "conditioning_collection"
# Inputs
collection: list[ConditioningField] = InputField(
default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
)