mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(nodes): restore type annotations for InvocationContext
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
@ -31,10 +32,7 @@ from .baseinvocation import (
|
||||
)
|
||||
from .model import ClipField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
|
||||
# class ConditioningAlgo(str, Enum):
|
||||
@ -65,7 +63,7 @@ class CompelInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context) -> ConditioningOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||
|
||||
@ -148,7 +146,7 @@ class CompelInvocation(BaseInvocation):
|
||||
class SDXLPromptInvocationBase:
|
||||
def run_clip_compel(
|
||||
self,
|
||||
context: "InvocationContext",
|
||||
context: InvocationContext,
|
||||
clip_field: ClipField,
|
||||
prompt: str,
|
||||
get_pooled: bool,
|
||||
@ -288,7 +286,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context) -> ConditioningOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
||||
)
|
||||
@ -373,7 +371,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context) -> ConditioningOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
# TODO: if there will appear lora for refiner - write proper prefix
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||
|
||||
@ -418,7 +416,7 @@ class ClipSkipInvocation(BaseInvocation):
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context) -> ClipSkipInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
return ClipSkipInvocationOutput(
|
||||
clip=self.clip,
|
||||
|
Reference in New Issue
Block a user