add differentiated sdxl and sdxl_refiner model loaders

This commit is contained in:
Lincoln Stein
2023-07-16 12:17:56 -04:00
parent fe78a08e37
commit 0a2964d8c0
2 changed files with 132 additions and 4 deletions

View File

@ -46,7 +46,6 @@ class ModelLoaderOutput(BaseInvocationOutput):
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on # fmt: on

View File

@ -1,12 +1,11 @@
import copy
import torch import torch
import inspect import inspect
from tqdm import tqdm from tqdm import tqdm
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, validator from pydantic import Field, validator
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
@ -14,6 +13,136 @@ from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .compel import ConditioningField from .compel import ConditioningField
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
# fmt: off
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
class SDXLRefinerModelLoaderOutput(SDXLModelLoaderOutput):
"""SDXL refiner model loader output"""
# fmt: off
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
#fmt: on
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
type: Literal["sdxl_model_loader"] = "sdxl_main_model_loader"
model: MainModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Model Loader",
"tags": ["model", "loader", "sdxl"],
"type_hints": {"model": "model"},
},
}
@classmethod
def _output_class(cls):
return SDXLModelLoaderOutput
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
return self._output_class(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
),
)
class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Model Loader",
"tags": ["model", "loader", "sdxl_refiner"],
"type_hints": {"model": "model"},
},
}
@classmethod
def _output_class(cls):
return SDXLRefinerModelLoaderOutput
# Text to image # Text to image
class SDXLTextToLatentsInvocation(BaseInvocation): class SDXLTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings.""" """Generates latents from conditionings."""