From 75c5ce46bc9867210ae97eda1837b85973b54b9b Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 11 Jul 2023 16:33:08 -0400 Subject: [PATCH] merged SDXLModelLoader into ModelLoader invocation --- invokeai/app/invocations/model.py | 88 +------------------ .../nodes/components/InputFieldComponent.tsx | 2 +- 2 files changed, 2 insertions(+), 88 deletions(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 76cd5b81e3..4ceb875019 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -46,22 +46,10 @@ class ModelLoaderOutput(BaseInvocationOutput): unet: UNetField = Field(default=None, description="UNet submodel") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") + clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)") vae: VaeField = Field(default=None, description="Vae submodel") # fmt: on -class SDXLModelLoaderOutput(BaseInvocationOutput): - """SDXL model loader output""" - - # fmt: off - type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output" - - unet: UNetField = Field(default=None, description="UNet submodel") - clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (2d set)") - vae: VaeField = Field(default=None, description="Vae submodel") - # fmt: on - - class MainModelField(BaseModel): """Main model field""" @@ -136,79 +124,6 @@ class MainModelLoaderInvocation(BaseInvocation): """ return ModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, - ), - loras=[], - skipped_layers=0, - ), - vae=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Vae, - ), - ), - ) - -class SDXLMainModelLoaderInvocation(BaseInvocation): - """Loads an SDXL main model, outputting its submodels.""" - - type: Literal["sdxl_main_model_loader"] = "sdxl_main_model_loader" - - model: MainModelField = Field(description="The SDXL model to load") - # TODO: precision? - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "SDXL Model Loader", - "tags": ["model", "loader", "sdxl"], - "type_hints": {"model": "model"}, - }, - } - - def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( model_name=model_name, @@ -266,7 +181,6 @@ class SDXLMainModelLoaderInvocation(BaseInvocation): ), ) - class LoraLoaderOutput(BaseInvocationOutput): """Model loader output""" diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index ddc6f4c6c3..52e8107f59 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -150,7 +150,7 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { nodeId={nodeId} field={field} template={template} - base_models={['sd-1', 'sd-2']} + base_models={['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']} /> ); }