fix(nodes): restore type annotations for InvocationContext

This commit is contained in:
psychedelicious
2024-02-05 17:16:35 +11:00
parent 281c334531
commit 4ce21087d3
25 changed files with 158 additions and 143 deletions

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional, Union
import torch
from compel import Compel, ReturnedEmbeddingsType
@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
@ -31,10 +32,7 @@ from .baseinvocation import (
)
from .model import ClipField
if TYPE_CHECKING:
from invokeai.app.services.shared.invocation_context import InvocationContext
# unconditioned: Optional[torch.Tensor]
# unconditioned: Optional[torch.Tensor]
# class ConditioningAlgo(str, Enum):
@ -65,7 +63,7 @@ class CompelInvocation(BaseInvocation):
)
@torch.no_grad()
def invoke(self, context) -> ConditioningOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
@ -148,7 +146,7 @@ class CompelInvocation(BaseInvocation):
class SDXLPromptInvocationBase:
def run_clip_compel(
self,
context: "InvocationContext",
context: InvocationContext,
clip_field: ClipField,
prompt: str,
get_pooled: bool,
@ -288,7 +286,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
@torch.no_grad()
def invoke(self, context) -> ConditioningOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
)
@ -373,7 +371,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context) -> ConditioningOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
# TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
@ -418,7 +416,7 @@ class ClipSkipInvocation(BaseInvocation):
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context) -> ClipSkipInvocationOutput:
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers
return ClipSkipInvocationOutput(
clip=self.clip,