merged SDXLModelLoader into ModelLoader invocation

This commit is contained in:
Lincoln Stein 2023-07-11 16:33:08 -04:00
parent 358ced6bab
commit 75c5ce46bc
2 changed files with 2 additions and 88 deletions

View File

@ -46,22 +46,10 @@ class ModelLoaderOutput(BaseInvocationOutput):
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on # fmt: on
class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL model loader output"""
# fmt: off
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (2d set)")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
class MainModelField(BaseModel): class MainModelField(BaseModel):
"""Main model field""" """Main model field"""
@ -136,79 +124,6 @@ class MainModelLoaderInvocation(BaseInvocation):
""" """
return ModelLoaderOutput( return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
),
)
class SDXLMainModelLoaderInvocation(BaseInvocation):
"""Loads an SDXL main model, outputting its submodels."""
type: Literal["sdxl_main_model_loader"] = "sdxl_main_model_loader"
model: MainModelField = Field(description="The SDXL model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Model Loader",
"tags": ["model", "loader", "sdxl"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
return SDXLModelLoaderOutput(
unet=UNetField( unet=UNetField(
unet=ModelInfo( unet=ModelInfo(
model_name=model_name, model_name=model_name,
@ -267,7 +182,6 @@ class SDXLMainModelLoaderInvocation(BaseInvocation):
) )
class LoraLoaderOutput(BaseInvocationOutput): class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""

View File

@ -150,7 +150,7 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
nodeId={nodeId} nodeId={nodeId}
field={field} field={field}
template={template} template={template}
base_models={['sd-1', 'sd-2']} base_models={['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']}
/> />
); );
} }