BREAKING CHANGES: invocations now require model key, not base/type/name

- Implement new model loader and modify invocations and embeddings

- Finish implementation loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fixed up invocations that load and patch models.

- Move seamless and silencewarnings utils into better location
This commit is contained in:
Lincoln Stein
2024-02-05 22:56:32 -05:00
committed by psychedelicious
parent 5745ce9c7d
commit 78ef946e01
31 changed files with 727 additions and 496 deletions

View File

@ -1,7 +1,7 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType
from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -40,45 +40,31 @@ class SDXLModelLoaderInvocation(BaseInvocation):
# TODO: precision?
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
model_key = self.model.key
# TODO: not found exceptions
if not context.models.exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
if not context.services.model_records.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
return SDXLModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.TextEncoder,
),
loras=[],
@ -86,15 +72,11 @@ class SDXLModelLoaderInvocation(BaseInvocation):
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.TextEncoder2,
),
loras=[],
@ -102,9 +84,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Vae,
),
),
@ -129,45 +109,31 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
# TODO: precision?
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
model_key = self.model.key
# TODO: not found exceptions
if not context.models.exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
if not context.services.model_records.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.TextEncoder2,
),
loras=[],
@ -175,9 +141,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=model_key,
submodel=SubModelType.Vae,
),
),