fix(nodes): restore type annotations for InvocationContext

This commit is contained in:
psychedelicious
2024-02-05 17:16:35 +11:00
parent 281c334531
commit 4ce21087d3
25 changed files with 158 additions and 143 deletions

View File

@ -17,6 +17,7 @@ from invokeai.app.invocations.fields import (
UIComponent,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import (
BaseInvocation,
@ -59,7 +60,7 @@ class BooleanInvocation(BaseInvocation):
value: bool = InputField(default=False, description="The boolean value")
def invoke(self, context) -> BooleanOutput:
def invoke(self, context: InvocationContext) -> BooleanOutput:
return BooleanOutput(value=self.value)
@ -75,7 +76,7 @@ class BooleanCollectionInvocation(BaseInvocation):
collection: list[bool] = InputField(default=[], description="The collection of boolean values")
def invoke(self, context) -> BooleanCollectionOutput:
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
return BooleanCollectionOutput(collection=self.collection)
@ -108,7 +109,7 @@ class IntegerInvocation(BaseInvocation):
value: int = InputField(default=0, description="The integer value")
def invoke(self, context) -> IntegerOutput:
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(value=self.value)
@ -124,7 +125,7 @@ class IntegerCollectionInvocation(BaseInvocation):
collection: list[int] = InputField(default=[], description="The collection of integer values")
def invoke(self, context) -> IntegerCollectionOutput:
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=self.collection)
@ -155,7 +156,7 @@ class FloatInvocation(BaseInvocation):
value: float = InputField(default=0.0, description="The float value")
def invoke(self, context) -> FloatOutput:
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(value=self.value)
@ -171,7 +172,7 @@ class FloatCollectionInvocation(BaseInvocation):
collection: list[float] = InputField(default=[], description="The collection of float values")
def invoke(self, context) -> FloatCollectionOutput:
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
return FloatCollectionOutput(collection=self.collection)
@ -202,7 +203,7 @@ class StringInvocation(BaseInvocation):
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
def invoke(self, context) -> StringOutput:
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(value=self.value)
@ -218,7 +219,7 @@ class StringCollectionInvocation(BaseInvocation):
collection: list[str] = InputField(default=[], description="The collection of string values")
def invoke(self, context) -> StringCollectionOutput:
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
return StringCollectionOutput(collection=self.collection)
@ -261,7 +262,7 @@ class ImageInvocation(
image: ImageField = InputField(description="The image to load")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
return ImageOutput(
@ -283,7 +284,7 @@ class ImageCollectionInvocation(BaseInvocation):
collection: list[ImageField] = InputField(description="The collection of image values")
def invoke(self, context) -> ImageCollectionOutput:
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.collection)
@ -346,7 +347,7 @@ class LatentsInvocation(BaseInvocation):
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
def invoke(self, context) -> LatentsOutput:
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.latents.get(self.latents.latents_name)
return LatentsOutput.build(self.latents.latents_name, latents)
@ -366,7 +367,7 @@ class LatentsCollectionInvocation(BaseInvocation):
description="The collection of latents tensors",
)
def invoke(self, context) -> LatentsCollectionOutput:
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
return LatentsCollectionOutput(collection=self.collection)
@ -397,7 +398,7 @@ class ColorInvocation(BaseInvocation):
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
def invoke(self, context) -> ColorOutput:
def invoke(self, context: InvocationContext) -> ColorOutput:
return ColorOutput(color=self.color)
@ -438,7 +439,7 @@ class ConditioningInvocation(BaseInvocation):
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
def invoke(self, context) -> ConditioningOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
return ConditioningOutput(conditioning=self.conditioning)
@ -457,7 +458,7 @@ class ConditioningCollectionInvocation(BaseInvocation):
description="The collection of conditioning tensors",
)
def invoke(self, context) -> ConditioningCollectionOutput:
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
return ConditioningCollectionOutput(collection=self.collection)