From 0a2964d8c04d2907b574360f9ab5fe6ace85ba86 Mon Sep 17 00:00:00 2001
From: Lincoln Stein <lstein@gmail.com>
Date: Sun, 16 Jul 2023 12:17:56 -0400
Subject: [PATCH] add differentiated sdxl and sdxl_refiner model loaders

---
 invokeai/app/invocations/model.py |   1 -
 invokeai/app/invocations/sdxl.py  | 135 +++++++++++++++++++++++++++++-
 2 files changed, 132 insertions(+), 4 deletions(-)

diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index 4ceb875019..5f1a91ae3b 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -46,7 +46,6 @@ class ModelLoaderOutput(BaseInvocationOutput):
 
     unet: UNetField = Field(default=None, description="UNet submodel")
     clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
-    clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
     vae: VaeField = Field(default=None, description="Vae submodel")
     # fmt: on
 
diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py
index 2899dd975d..39fa1e8b6e 100644
--- a/invokeai/app/invocations/sdxl.py
+++ b/invokeai/app/invocations/sdxl.py
@@ -1,12 +1,11 @@
-import copy
 import torch
 import inspect
 from tqdm import tqdm
 from typing import List, Literal, Optional, Union
 
-from pydantic import BaseModel, Field, validator
+from pydantic import Field, validator
 
-from ...backend.model_management import BaseModelType, ModelType, SubModelType
+from ...backend.model_management import ModelType, SubModelType
 from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                              InvocationConfig, InvocationContext)
 
@@ -14,6 +13,136 @@ from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
 from .compel import ConditioningField
 from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
 
+class SDXLModelLoaderOutput(BaseInvocationOutput):
+    """SDXL base model loader output"""
+
+    # fmt: off
+    type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
+
+    unet: UNetField = Field(default=None, description="UNet submodel")
+    clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
+    clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
+    vae: VaeField = Field(default=None, description="Vae submodel")
+    # fmt: on
+
+class SDXLRefinerModelLoaderOutput(SDXLModelLoaderOutput):
+    """SDXL refiner model loader output"""
+    # fmt: off
+    type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
+    #fmt: on
+    
+
+class SDXLModelLoaderInvocation(BaseInvocation):
+    """Loads an sdxl base model, outputting its submodels."""
+
+    type: Literal["sdxl_model_loader"] = "sdxl_main_model_loader"
+
+    model: MainModelField = Field(description="The model to load")
+    # TODO: precision?
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "SDXL Model Loader",
+                "tags": ["model", "loader", "sdxl"],
+                "type_hints": {"model": "model"},
+            },
+        }
+
+    @classmethod
+    def _output_class(cls):
+        return SDXLModelLoaderOutput
+
+    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
+        base_model = self.model.base_model
+        model_name = self.model.model_name
+        model_type = ModelType.Main
+
+        # TODO: not found exceptions
+        if not context.services.model_manager.model_exists(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+        ):
+            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
+
+        return self._output_class(
+            unet=UNetField(
+                unet=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.UNet,
+                ),
+                scheduler=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.Scheduler,
+                ),
+                loras=[],
+            ),
+            clip=ClipField(
+                tokenizer=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.Tokenizer,
+                ),
+                text_encoder=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.TextEncoder,
+                ),
+                loras=[],
+                skipped_layers=0,
+            ),
+            clip2=ClipField(
+                tokenizer=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.Tokenizer2,
+                ),
+                text_encoder=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.TextEncoder2,
+                ),
+                loras=[],
+                skipped_layers=0,
+            ),
+            vae=VaeField(
+                vae=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.Vae,
+                ),
+            ),
+        )
+
+class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
+    """Loads an sdxl refiner model, outputting its submodels."""
+    type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "SDXL Refiner Model Loader",
+                "tags": ["model", "loader", "sdxl_refiner"],
+                "type_hints": {"model": "model"},
+            },
+        }
+
+    @classmethod
+    def _output_class(cls):
+        return SDXLRefinerModelLoaderOutput
+
 # Text to image
 class SDXLTextToLatentsInvocation(BaseInvocation):
     """Generates latents from conditionings."""