2023-08-08 00:38:42 +00:00
|
|
|
from typing import Literal
|
2023-07-11 15:19:36 +00:00
|
|
|
|
2023-08-08 00:38:42 +00:00
|
|
|
from ...backend.model_management import ModelType, SubModelType
|
2023-08-14 03:23:09 +00:00
|
|
|
from .baseinvocation import (
|
|
|
|
BaseInvocation,
|
|
|
|
BaseInvocationOutput,
|
|
|
|
FieldDescriptions,
|
|
|
|
Input,
|
|
|
|
InputField,
|
|
|
|
InvocationContext,
|
|
|
|
OutputField,
|
2023-08-15 11:45:40 +00:00
|
|
|
UIType,
|
2023-08-14 03:23:09 +00:00
|
|
|
tags,
|
|
|
|
title,
|
|
|
|
)
|
|
|
|
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
2023-07-11 15:19:36 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-07-16 16:17:56 +00:00
|
|
|
class SDXLModelLoaderOutput(BaseInvocationOutput):
    """SDXL base model loader output"""

    # Fixed Literal tag naming this output type (presumably consumed as a
    # type discriminator by the invocation framework — confirm).
    type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"

    # Handles to the submodels of the loaded SDXL base model. The base model
    # exposes two text-encoder stacks, surfaced as "CLIP 1" and "CLIP 2".
    # Declaration order is intentionally unchanged.
    unet: UNetField = OutputField(title="UNet", description=FieldDescriptions.unet)
    clip: ClipField = OutputField(title="CLIP 1", description=FieldDescriptions.clip)
    clip2: ClipField = OutputField(title="CLIP 2", description=FieldDescriptions.clip)
    vae: VaeField = OutputField(title="VAE", description=FieldDescriptions.vae)
|
2023-07-16 16:17:56 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-07-16 16:36:38 +00:00
|
|
|
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
    """SDXL refiner model loader output"""

    # Fixed Literal tag naming this output type (presumably consumed as a
    # type discriminator by the invocation framework — confirm).
    type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"

    # Submodel handles for the refiner. There is no "CLIP 1" field here: the
    # refiner loader only emits the second text-encoder stack.
    unet: UNetField = OutputField(title="UNet", description=FieldDescriptions.unet)
    clip2: ClipField = OutputField(title="CLIP 2", description=FieldDescriptions.clip)
    vae: VaeField = OutputField(title="VAE", description=FieldDescriptions.vae)
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-08-14 03:23:09 +00:00
|
|
|
|
2023-08-20 10:00:35 +00:00
|
|
|
@title("SDXL Main Model")
@tags("model", "sdxl")
class SDXLModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl base model, outputting its submodels."""

    type: Literal["sdxl_model_loader"] = "sdxl_model_loader"

    # Inputs
    model: MainModelField = InputField(
        description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
    )
    # TODO: precision?

    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
        # Look up the selected model, failing fast if the model manager does not
        # know it. Then return lightweight ModelInfo handles for each submodel
        # (no weights are loaded here).
        model_name = self.model.model_name
        base_model = self.model.base_model
        model_type = ModelType.Main

        manager = context.services.model_manager
        # TODO: not found exceptions
        if not manager.model_exists(model_name=model_name, base_model=base_model, model_type=model_type):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        def part(submodel: SubModelType) -> ModelInfo:
            # Every handle points at the same main-model record; only the
            # submodel selector differs.
            return ModelInfo(
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
            )

        unet = UNetField(
            unet=part(SubModelType.UNet),
            scheduler=part(SubModelType.Scheduler),
            loras=[],
        )
        # First text-encoder stack ("CLIP 1").
        clip = ClipField(
            tokenizer=part(SubModelType.Tokenizer),
            text_encoder=part(SubModelType.TextEncoder),
            loras=[],
            skipped_layers=0,
        )
        # Second text-encoder stack ("CLIP 2").
        clip2 = ClipField(
            tokenizer=part(SubModelType.Tokenizer2),
            text_encoder=part(SubModelType.TextEncoder2),
            loras=[],
            skipped_layers=0,
        )
        vae = VaeField(vae=part(SubModelType.Vae))

        return SDXLModelLoaderOutput(unet=unet, clip=clip, clip2=clip2, vae=vae)
|
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-08-20 10:00:35 +00:00
|
|
|
@title("SDXL Refiner Model")
@tags("model", "sdxl", "refiner")
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl refiner model, outputting its submodels."""

    type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"

    # Inputs
    model: MainModelField = InputField(
        description=FieldDescriptions.sdxl_refiner_model,
        input=Input.Direct,
        ui_type=UIType.SDXLRefinerModel,
    )
    # TODO: precision?

    def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
        # Same flow as the base loader, except the refiner output carries only
        # the UNet, the second text-encoder stack and the VAE.
        model_name = self.model.model_name
        base_model = self.model.base_model
        model_type = ModelType.Main

        manager = context.services.model_manager
        # TODO: not found exceptions
        if not manager.model_exists(model_name=model_name, base_model=base_model, model_type=model_type):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        def part(submodel: SubModelType) -> ModelInfo:
            # All handles reference the same main-model record; only the
            # submodel selector changes.
            return ModelInfo(
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
            )

        unet = UNetField(
            unet=part(SubModelType.UNet),
            scheduler=part(SubModelType.Scheduler),
            loras=[],
        )
        clip2 = ClipField(
            tokenizer=part(SubModelType.Tokenizer2),
            text_encoder=part(SubModelType.TextEncoder2),
            loras=[],
            skipped_layers=0,
        )
        vae = VaeField(vae=part(SubModelType.Vae))

        return SDXLRefinerModelLoaderOutput(unet=unet, clip2=clip2, vae=vae)
|