2023-07-11 15:19:36 +00:00
|
|
|
import torch
|
2023-08-08 00:38:42 +00:00
|
|
|
from typing import Literal
|
|
|
|
from pydantic import Field
|
2023-07-11 15:19:36 +00:00
|
|
|
|
2023-08-08 00:38:42 +00:00
|
|
|
from ...backend.model_management import ModelType, SubModelType
|
2023-07-11 15:19:36 +00:00
|
|
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
|
|
|
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-07-16 16:17:56 +00:00
|
|
|
class SDXLModelLoaderOutput(BaseInvocationOutput):
    """SDXL base model loader output"""

    # Output of SDXLModelLoaderInvocation: references (not loaded weights) to
    # the UNet + scheduler, both tokenizer/text-encoder pairs, and the VAE.
    # Every submodel field defaults to None.

    # fmt: off
    type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"

    unet: UNetField = Field(None, description="UNet submodel")
    # First tokenizer/text-encoder pair.
    clip: ClipField = Field(None, description="Tokenizer and text_encoder submodels")
    # Second tokenizer/text-encoder pair (Tokenizer2 / TextEncoder2).
    clip2: ClipField = Field(None, description="Tokenizer and text_encoder submodels")
    vae: VaeField = Field(None, description="Vae submodel")
    # fmt: on
|
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-07-16 16:36:38 +00:00
|
|
|
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
    """SDXL refiner model loader output"""

    # Output of SDXLRefinerModelLoaderInvocation. Unlike SDXLModelLoaderOutput
    # there is no `clip` field: only the second tokenizer/text-encoder pair
    # (`clip2`) is exposed. Every submodel field defaults to None.

    # fmt: off
    type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"

    unet: UNetField = Field(None, description="UNet submodel")
    clip2: ClipField = Field(None, description="Tokenizer and text_encoder submodels")
    vae: VaeField = Field(None, description="Vae submodel")
    # fmt: on
|
2023-07-27 14:54:01 +00:00
|
|
|
|
|
|
|
|
2023-07-16 16:17:56 +00:00
|
|
|
class SDXLModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl base model, outputting its submodels."""

    type: Literal["sdxl_model_loader"] = "sdxl_model_loader"

    model: MainModelField = Field(description="The model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Model Loader",
                "tags": ["model", "loader", "sdxl"],
                "type_hints": {"model": "model"},
            },
        }

    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
        """Check that the selected main model exists and return references to its submodels.

        The returned output contains ModelInfo references for the UNet,
        scheduler, both tokenizer/text-encoder pairs and the VAE, all taken
        from the same model. No LoRAs are attached (``loras=[]``).

        Raises:
            Exception: if the model manager reports no model with this
                name, base model and ``ModelType.Main`` type.
        """
        base_model = self.model.base_model
        model_name = self.model.model_name
        model_type = ModelType.Main

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        ):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        # Every submodel reference shares the same (name, base, type) triple;
        # build them through one helper so they cannot drift apart.
        def submodel(kind: SubModelType) -> ModelInfo:
            """Return a ModelInfo pointing at ``kind`` inside the selected model."""
            return ModelInfo(
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=kind,
            )

        return SDXLModelLoaderOutput(
            unet=UNetField(
                unet=submodel(SubModelType.UNet),
                scheduler=submodel(SubModelType.Scheduler),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=submodel(SubModelType.Tokenizer),
                text_encoder=submodel(SubModelType.TextEncoder),
                loras=[],
                skipped_layers=0,
            ),
            clip2=ClipField(
                tokenizer=submodel(SubModelType.Tokenizer2),
                text_encoder=submodel(SubModelType.TextEncoder2),
                loras=[],
                skipped_layers=0,
            ),
            vae=VaeField(
                vae=submodel(SubModelType.Vae),
            ),
        )
|
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-07-16 16:36:38 +00:00
|
|
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl refiner model, outputting its submodels."""

    type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"

    model: MainModelField = Field(description="The model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Refiner Model Loader",
                "tags": ["model", "loader", "sdxl_refiner"],
                "type_hints": {"model": "refiner_model"},
            },
        }

    def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
        """Check that the selected refiner model exists and return references to its submodels.

        The returned output contains ModelInfo references for the UNet,
        scheduler, the second tokenizer/text-encoder pair (Tokenizer2 /
        TextEncoder2) and the VAE. No first tokenizer/text-encoder pair is
        produced. No LoRAs are attached (``loras=[]``).

        Raises:
            Exception: if the model manager reports no model with this
                name, base model and ``ModelType.Main`` type.
        """
        base_model = self.model.base_model
        model_name = self.model.model_name
        model_type = ModelType.Main

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        ):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        # Every submodel reference shares the same (name, base, type) triple;
        # build them through one helper so they cannot drift apart.
        def submodel(kind: SubModelType) -> ModelInfo:
            """Return a ModelInfo pointing at ``kind`` inside the selected model."""
            return ModelInfo(
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=kind,
            )

        return SDXLRefinerModelLoaderOutput(
            unet=UNetField(
                unet=submodel(SubModelType.UNet),
                scheduler=submodel(SubModelType.Scheduler),
                loras=[],
            ),
            clip2=ClipField(
                tokenizer=submodel(SubModelType.Tokenizer2),
                text_encoder=submodel(SubModelType.TextEncoder2),
                loras=[],
                skipped_layers=0,
            ),
            vae=VaeField(
                vae=submodel(SubModelType.Vae),
            ),
        )
|