fix(nodes): restore type annotations for InvocationContext

This commit is contained in:
psychedelicious
2024-02-05 17:16:35 +11:00
parent 281c334531
commit 4ce21087d3
25 changed files with 158 additions and 143 deletions

View File

@ -4,6 +4,7 @@ from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from ...backend.model_management import BaseModelType, ModelType, SubModelType
@ -109,7 +110,7 @@ class MainModelLoaderInvocation(BaseInvocation):
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
def invoke(self, context) -> ModelLoaderOutput:
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
@ -221,7 +222,7 @@ class LoraLoaderInvocation(BaseInvocation):
title="CLIP",
)
def invoke(self, context) -> LoraLoaderOutput:
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
@ -310,7 +311,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
title="CLIP 2",
)
def invoke(self, context) -> SDXLLoraLoaderOutput:
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
@ -393,7 +394,7 @@ class VaeLoaderInvocation(BaseInvocation):
title="VAE",
)
def invoke(self, context) -> VAEOutput:
def invoke(self, context: InvocationContext) -> VAEOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
@ -448,7 +449,7 @@ class SeamlessModeInvocation(BaseInvocation):
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
def invoke(self, context) -> SeamlessModeOutput:
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
unet = copy.deepcopy(self.unet)
vae = copy.deepcopy(self.vae)
@ -484,6 +485,6 @@ class FreeUInvocation(BaseInvocation):
s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
def invoke(self, context) -> UNetOutput:
def invoke(self, context: InvocationContext) -> UNetOutput:
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
return UNetOutput(unet=self.unet)