fix(nodes): restore type annotations for InvocationContext

This commit is contained in:
psychedelicious
2024-02-05 17:16:35 +11:00
parent 281c334531
commit 4ce21087d3
25 changed files with 158 additions and 143 deletions

View File

@ -17,6 +17,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import uuid_string
# in 3.10 this would be "from types import NoneType"
@ -201,7 +202,7 @@ class GraphInvocation(BaseInvocation):
# TODO: figure out how to create a default here
graph: "Graph" = InputField(description="The graph to run", default=None)
def invoke(self, context) -> GraphInvocationOutput:
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
"""Invoke with provided services and return outputs."""
return GraphInvocationOutput()
@ -227,7 +228,7 @@ class IterateInvocation(BaseInvocation):
)
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
def invoke(self, context) -> IterateInvocationOutput:
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values"""
return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection))
@ -254,7 +255,7 @@ class CollectInvocation(BaseInvocation):
description="The collection, will be provided on execution", default=[], ui_hidden=True
)
def invoke(self, context) -> CollectInvocationOutput:
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
"""Invoke with provided services and return outputs."""
return CollectInvocationOutput(collection=copy.copy(self.collection))