feat(nodes): update all invocations to use new invocation context

Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them.

Supporting minor changes:
- Patch bump for all nodes that use the context
- Update invocation processor to provide new context
- Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node
- Minor change to `ModelManagerService` to support the new wrapped context
- Fanagling of imports to avoid circular dependencies
This commit is contained in:
psychedelicious
2024-01-13 23:23:16 +11:00
parent 9bc2d09889
commit 8637c40661
32 changed files with 716 additions and 1191 deletions

View File

@ -10,7 +10,6 @@ from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
invocation,
invocation_output,
)
@ -102,7 +101,7 @@ class LoRAModelField(BaseModel):
title="Main Model",
tags=["model"],
category="model",
version="1.0.0",
version="1.0.1",
)
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
@ -110,13 +109,13 @@ class MainModelLoaderInvocation(BaseInvocation):
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
def invoke(self, context) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
if not context.models.exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
@ -203,7 +202,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@ -222,14 +221,14 @@ class LoraLoaderInvocation(BaseInvocation):
title="CLIP",
)
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
def invoke(self, context) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
if not context.services.model_manager.model_exists(
if not context.models.exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
@ -285,7 +284,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
title="SDXL LoRA",
tags=["lora", "model"],
category="model",
version="1.0.0",
version="1.0.1",
)
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@ -311,14 +310,14 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
title="CLIP 2",
)
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
def invoke(self, context) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
if not context.services.model_manager.model_exists(
if not context.models.exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
@ -384,7 +383,7 @@ class VAEModelField(BaseModel):
model_config = ConfigDict(protected_namespaces=())
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
@ -394,12 +393,12 @@ class VaeLoaderInvocation(BaseInvocation):
title="VAE",
)
def invoke(self, context: InvocationContext) -> VAEOutput:
def invoke(self, context) -> VAEOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
if not context.services.model_manager.model_exists(
if not context.models.exists(
base_model=base_model,
model_name=model_name,
model_type=model_type,
@ -449,7 +448,7 @@ class SeamlessModeInvocation(BaseInvocation):
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
def invoke(self, context) -> SeamlessModeOutput:
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
unet = copy.deepcopy(self.unet)
vae = copy.deepcopy(self.vae)
@ -485,6 +484,6 @@ class FreeUInvocation(BaseInvocation):
s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
def invoke(self, context: InvocationContext) -> UNetOutput:
def invoke(self, context) -> UNetOutput:
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
return UNetOutput(unet=self.unet)