mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): update all invocations to use new invocation context
Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them. Supporting minor changes: - Patch bump for all nodes that use the context - Update invocation processor to provide new context - Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node - Minor change to `ModelManagerService` to support the new wrapped context - Fanagling of imports to avoid circular dependencies
This commit is contained in:
@ -10,7 +10,6 @@ from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -102,7 +101,7 @@ class LoRAModelField(BaseModel):
|
||||
title="Main Model",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
@ -110,13 +109,13 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
def invoke(self, context) -> ModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.models.exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
@ -203,7 +202,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
@ -222,14 +221,14 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
title="CLIP",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
def invoke(self, context) -> LoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.models.exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
@ -285,7 +284,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
title="SDXL LoRA",
|
||||
tags=["lora", "model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
@ -311,14 +310,14 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
title="CLIP 2",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
def invoke(self, context) -> SDXLLoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.models.exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
@ -384,7 +383,7 @@ class VAEModelField(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
@ -394,12 +393,12 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
title="VAE",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
def invoke(self, context) -> VAEOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.models.exists(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
@ -449,7 +448,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
||||
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
|
||||
def invoke(self, context) -> SeamlessModeOutput:
|
||||
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
|
||||
unet = copy.deepcopy(self.unet)
|
||||
vae = copy.deepcopy(self.vae)
|
||||
@ -485,6 +484,6 @@ class FreeUInvocation(BaseInvocation):
|
||||
s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
|
||||
s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> UNetOutput:
|
||||
def invoke(self, context) -> UNetOutput:
|
||||
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
|
||||
return UNetOutput(unet=self.unet)
|
||||
|
Reference in New Issue
Block a user