InvokeAI/invokeai/app/invocations/sdxl.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

149 lines
5.0 KiB
Python
Raw Permalink Normal View History

from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
2023-07-27 14:54:01 +00:00
@invocation_output("sdxl_model_loader_output")
class SDXLModelLoaderOutput(BaseInvocationOutput):
    """SDXL base model loader output"""

    # NOTE: the class docstring above feeds the node's schema description, and field
    # declaration order drives output ordering — change neither casually.

    # UNet reference (the loader also bundles the scheduler and an empty LoRA list into it).
    unet: UNetField = OutputField(title="UNet", description=FieldDescriptions.unet)
    # First tokenizer / text-encoder pair of the checkpoint.
    clip: ClipField = OutputField(title="CLIP 1", description=FieldDescriptions.clip)
    # Second tokenizer / text-encoder pair; reuses the generic CLIP description.
    clip2: ClipField = OutputField(title="CLIP 2", description=FieldDescriptions.clip)
    # VAE submodel reference.
    vae: VaeField = OutputField(title="VAE", description=FieldDescriptions.vae)
2023-07-27 14:54:01 +00:00
@invocation_output("sdxl_refiner_model_loader_output")
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
    """SDXL refiner model loader output"""

    # NOTE: unlike the base-model output there is no `clip` field here — the refiner
    # loader only exposes the second tokenizer / text-encoder pair as `clip2`.
    # The docstring and field order are part of the node schema; keep them stable.

    # UNet reference (bundled with scheduler and an empty LoRA list by the loader).
    unet: UNetField = OutputField(title="UNet", description=FieldDescriptions.unet)
    # Second tokenizer / text-encoder pair of the refiner checkpoint.
    clip2: ClipField = OutputField(title="CLIP 2", description=FieldDescriptions.clip)
    # VAE submodel reference.
    vae: VaeField = OutputField(title="VAE", description=FieldDescriptions.vae)
2023-07-27 14:54:01 +00:00
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
class SDXLModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl base model, outputting its submodels."""

    # NOTE: the class docstring above is used as the node's schema description; keep it stable.

    model: MainModelField = InputField(
        description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
    )

    # TODO: precision?
    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
        """Build submodel references for the selected SDXL base model.

        Nothing is loaded here: every output field is a ``ModelInfo`` holding the same
        model key and a different ``SubModelType``; resolving those references into actual
        weights is presumably done by downstream nodes.

        Args:
            context: Invocation context; only ``context.models.exists`` is used.

        Returns:
            An ``SDXLModelLoaderOutput`` with UNet (+ scheduler), both CLIP stacks and the VAE.

        Raises:
            ValueError: If ``self.model.key`` is not known to ``context.models``.
        """
        model_key = self.model.key

        if not context.models.exists(model_key):
            # ValueError is a subclass of Exception, so handlers written for the previous
            # bare `Exception` still catch it; the message text is unchanged.
            raise ValueError(f"Unknown model: {model_key}")

        def submodel(submodel_type: SubModelType) -> ModelInfo:
            # Every output points at the same checkpoint; only the submodel type differs.
            return ModelInfo(key=model_key, submodel_type=submodel_type)

        return SDXLModelLoaderOutput(
            unet=UNetField(
                unet=submodel(SubModelType.UNet),
                scheduler=submodel(SubModelType.Scheduler),
                loras=[],
            ),
            # Two independent tokenizer / text-encoder pairs: (Tokenizer, TextEncoder) for
            # `clip` and (Tokenizer2, TextEncoder2) for `clip2`.
            clip=ClipField(
                tokenizer=submodel(SubModelType.Tokenizer),
                text_encoder=submodel(SubModelType.TextEncoder),
                loras=[],
                skipped_layers=0,
            ),
            clip2=ClipField(
                tokenizer=submodel(SubModelType.Tokenizer2),
                text_encoder=submodel(SubModelType.TextEncoder2),
                loras=[],
                skipped_layers=0,
            ),
            vae=VaeField(vae=submodel(SubModelType.Vae)),
        )
2023-07-27 14:54:01 +00:00
@invocation(
    "sdxl_refiner_model_loader",
    title="SDXL Refiner Model",
    tags=["model", "sdxl", "refiner"],
    category="model",
    version="1.0.1",
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl refiner model, outputting its submodels."""

    # NOTE: the class docstring above is used as the node's schema description; keep it stable.

    model: MainModelField = InputField(
        description=FieldDescriptions.sdxl_refiner_model,
        input=Input.Direct,
        ui_type=UIType.SDXLRefinerModel,
    )

    # TODO: precision?
    def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
        """Build submodel references for the selected SDXL refiner model.

        Nothing is loaded here: every output field is a ``ModelInfo`` holding the same
        model key and a different ``SubModelType``; resolving those references into actual
        weights is presumably done by downstream nodes.

        Args:
            context: Invocation context; only ``context.models.exists`` is used.

        Returns:
            An ``SDXLRefinerModelLoaderOutput`` with UNet (+ scheduler), the second CLIP
            stack (``clip2``) and the VAE. No first CLIP stack is produced.

        Raises:
            ValueError: If ``self.model.key`` is not known to ``context.models``.
        """
        model_key = self.model.key

        if not context.models.exists(model_key):
            # ValueError is a subclass of Exception, so handlers written for the previous
            # bare `Exception` still catch it; the message text is unchanged.
            raise ValueError(f"Unknown model: {model_key}")

        def submodel(submodel_type: SubModelType) -> ModelInfo:
            # Every output points at the same checkpoint; only the submodel type differs.
            return ModelInfo(key=model_key, submodel_type=submodel_type)

        return SDXLRefinerModelLoaderOutput(
            unet=UNetField(
                unet=submodel(SubModelType.UNet),
                scheduler=submodel(SubModelType.Scheduler),
                loras=[],
            ),
            # Only the second tokenizer / text-encoder pair is exposed for the refiner.
            clip2=ClipField(
                tokenizer=submodel(SubModelType.Tokenizer2),
                text_encoder=submodel(SubModelType.TextEncoder2),
                loras=[],
                skipped_layers=0,
            ),
            vae=VaeField(vae=submodel(SubModelType.Vae)),
        )