fix(nodes): restore type annotations for InvocationContext

This commit is contained in:
psychedelicious
2024-02-05 17:16:35 +11:00
parent 281c334531
commit 4ce21087d3
25 changed files with 158 additions and 143 deletions

View File

@ -3,7 +3,7 @@
import math
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import TYPE_CHECKING, List, Literal, Optional, Union
from typing import List, Literal, Optional, Union
import einops
import numpy as np
@ -42,6 +42,7 @@ from invokeai.app.invocations.primitives import (
LatentsOutput,
)
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
@ -70,9 +71,6 @@ from .baseinvocation import (
from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField
if TYPE_CHECKING:
from invokeai.app.services.shared.invocation_context import InvocationContext
if choose_torch_device() == torch.device("mps"):
from torch import mps
@ -177,7 +175,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
def get_scheduler(
context: "InvocationContext",
context: InvocationContext,
scheduler_info: ModelInfo,
scheduler_name: str,
seed: int,
@ -300,7 +298,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def get_conditioning_data(
self,
context: "InvocationContext",
context: InvocationContext,
scheduler,
unet,
seed,
@ -369,7 +367,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_control_data(
self,
context: "InvocationContext",
context: InvocationContext,
control_input: Union[ControlField, List[ControlField]],
latents_shape: List[int],
exit_stack: ExitStack,
@ -442,7 +440,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_ip_adapter_data(
self,
context: "InvocationContext",
context: InvocationContext,
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
conditioning_data: ConditioningData,
exit_stack: ExitStack,
@ -509,7 +507,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def run_t2i_adapters(
self,
context: "InvocationContext",
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
do_classifier_free_guidance: bool,
@ -618,7 +616,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
return num_inference_steps, timesteps, init_timestep
def prep_inpaint_mask(self, context: "InvocationContext", latents):
def prep_inpaint_mask(self, context: InvocationContext, latents):
if self.denoise_mask is None:
return None, None