InvokeAI/invokeai/app/invocations/sdxl.py
Lincoln Stein 79d028ecbd BREAKING CHANGES: invocations now require model key, not base/type/name
- Implement new model loader and modify invocations and embeddings

- Finish implementing loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fixed up invocations that load and patch models.

- Move seamless and silencewarnings utils into better location
2024-02-08 23:26:41 -05:00

153 lines
4.9 KiB
Python

from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
@invocation_output("sdxl_model_loader_output")
class SDXLModelLoaderOutput(BaseInvocationOutput):
    """SDXL base model loader output"""

    # NOTE(review): the class docstring may be surfaced at runtime as the
    # schema description for this output -- confirm before rewording it.

    # Four outputs: a UNet field, two CLIP fields and a VAE field.
    unet: UNetField = OutputField(
        title="UNet",
        description=FieldDescriptions.unet,
    )
    # Both CLIP outputs reuse the same description text; only the title
    # ("CLIP 1" / "CLIP 2") tells them apart.
    clip: ClipField = OutputField(
        title="CLIP 1",
        description=FieldDescriptions.clip,
    )
    clip2: ClipField = OutputField(
        title="CLIP 2",
        description=FieldDescriptions.clip,
    )
    vae: VaeField = OutputField(
        title="VAE",
        description=FieldDescriptions.vae,
    )
@invocation_output("sdxl_refiner_model_loader_output")
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
    """SDXL refiner model loader output"""

    # NOTE(review): the class docstring may be surfaced at runtime as the
    # schema description for this output -- confirm before rewording it.

    # Three outputs: a UNet field, a single CLIP field (named ``clip2`` and
    # titled "CLIP 2") and a VAE field. There is no ``clip`` output here.
    unet: UNetField = OutputField(
        title="UNet",
        description=FieldDescriptions.unet,
    )
    clip2: ClipField = OutputField(
        title="CLIP 2",
        description=FieldDescriptions.clip,
    )
    vae: VaeField = OutputField(
        title="VAE",
        description=FieldDescriptions.vae,
    )
@invocation(
    "sdxl_model_loader",
    title="SDXL Main Model",
    tags=["model", "sdxl"],
    category="model",
    version="1.0.0",
)
class SDXLModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl base model, outputting its submodels."""

    # NOTE(review): the class docstring may be surfaced at runtime as the
    # node description -- confirm before rewording it.

    # The model whose key is fanned out into per-submodel references.
    model: MainModelField = InputField(
        description=FieldDescriptions.sdxl_main_model,
        input=Input.Direct,
        ui_type=UIType.SDXLMainModel,
    )
    # TODO: precision?

    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
        """Describe every submodel of the selected model by key.

        The method only checks that a record exists for ``self.model.key``
        and then builds ``ModelInfo`` references; each reference carries the
        same key and a different ``SubModelType``.

        Args:
            context: Invocation context; its ``services.model_records`` is
                queried for the existence of the model key.

        Returns:
            An ``SDXLModelLoaderOutput`` with ``unet``, ``clip``, ``clip2``
            and ``vae`` populated. LoRA lists are empty and
            ``skipped_layers`` is 0 on both CLIP fields.

        Raises:
            Exception: If no record exists for the model key.
        """
        model_key = self.model.key

        # TODO: not found exceptions
        known = context.services.model_records.exists(model_key)
        if not known:
            raise Exception(f"Unknown model: {model_key}")

        def submodel_ref(part: SubModelType) -> ModelInfo:
            # Every reference shares the key; only the submodel differs.
            return ModelInfo(key=model_key, submodel=part)

        unet_field = UNetField(
            unet=submodel_ref(SubModelType.UNet),
            scheduler=submodel_ref(SubModelType.Scheduler),
            loras=[],
        )
        # First tokenizer / text-encoder pair.
        clip_field = ClipField(
            tokenizer=submodel_ref(SubModelType.Tokenizer),
            text_encoder=submodel_ref(SubModelType.TextEncoder),
            loras=[],
            skipped_layers=0,
        )
        # Second tokenizer / text-encoder pair.
        clip2_field = ClipField(
            tokenizer=submodel_ref(SubModelType.Tokenizer2),
            text_encoder=submodel_ref(SubModelType.TextEncoder2),
            loras=[],
            skipped_layers=0,
        )
        vae_field = VaeField(vae=submodel_ref(SubModelType.Vae))

        return SDXLModelLoaderOutput(
            unet=unet_field,
            clip=clip_field,
            clip2=clip2_field,
            vae=vae_field,
        )
@invocation(
    "sdxl_refiner_model_loader",
    title="SDXL Refiner Model",
    tags=["model", "sdxl", "refiner"],
    category="model",
    version="1.0.0",
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl refiner model, outputting its submodels."""

    # NOTE(review): the class docstring may be surfaced at runtime as the
    # node description -- confirm before rewording it.

    # The model whose key is fanned out into per-submodel references.
    model: MainModelField = InputField(
        description=FieldDescriptions.sdxl_refiner_model,
        input=Input.Direct,
        ui_type=UIType.SDXLRefinerModel,
    )
    # TODO: precision?

    def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
        """Describe every submodel of the selected refiner model by key.

        The method only checks that a record exists for ``self.model.key``
        and then builds ``ModelInfo`` references; each reference carries the
        same key and a different ``SubModelType``.

        Args:
            context: Invocation context; its ``services.model_records`` is
                queried for the existence of the model key.

        Returns:
            An ``SDXLRefinerModelLoaderOutput`` with ``unet``, ``clip2`` and
            ``vae`` populated. The LoRA lists are empty and
            ``skipped_layers`` is 0 on the CLIP field.

        Raises:
            Exception: If no record exists for the model key.
        """
        model_key = self.model.key

        # TODO: not found exceptions
        known = context.services.model_records.exists(model_key)
        if not known:
            raise Exception(f"Unknown model: {model_key}")

        def submodel_ref(part: SubModelType) -> ModelInfo:
            # Every reference shares the key; only the submodel differs.
            return ModelInfo(key=model_key, submodel=part)

        unet_field = UNetField(
            unet=submodel_ref(SubModelType.UNet),
            scheduler=submodel_ref(SubModelType.Scheduler),
            loras=[],
        )
        # Only the "2" tokenizer / text-encoder pair is referenced here.
        clip2_field = ClipField(
            tokenizer=submodel_ref(SubModelType.Tokenizer2),
            text_encoder=submodel_ref(SubModelType.TextEncoder2),
            loras=[],
            skipped_layers=0,
        )
        vae_field = VaeField(vae=submodel_ref(SubModelType.Vae))

        return SDXLRefinerModelLoaderOutput(
            unet=unet_field,
            clip2=clip2_field,
            vae=vae_field,
        )