From 0a2964d8c04d2907b574360f9ab5fe6ace85ba86 Mon Sep 17 00:00:00 2001
From: Lincoln Stein <lstein@gmail.com>
Date: Sun, 16 Jul 2023 12:17:56 -0400
Subject: [PATCH] add differentiated sdxl and sdxl_refiner model loaders

---
 invokeai/app/invocations/model.py |   1 -
 invokeai/app/invocations/sdxl.py  | 146 +++++++++++++++++++++++++++++-
 2 files changed, 143 insertions(+), 4 deletions(-)

diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index 4ceb875019..5f1a91ae3b 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -46,7 +46,6 @@ class ModelLoaderOutput(BaseInvocationOutput):
 
     unet: UNetField = Field(default=None, description="UNet submodel")
     clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
-    clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
     vae: VaeField = Field(default=None, description="Vae submodel")
     # fmt: on
 
diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py
index 2899dd975d..39fa1e8b6e 100644
--- a/invokeai/app/invocations/sdxl.py
+++ b/invokeai/app/invocations/sdxl.py
@@ -1,12 +1,11 @@
-import copy
 import torch
 import inspect
 from tqdm import tqdm
 from typing import List, Literal, Optional, Union
 
-from pydantic import BaseModel, Field, validator
+from pydantic import Field, validator
 
-from ...backend.model_management import BaseModelType, ModelType, SubModelType
+from ...backend.model_management import ModelType, SubModelType
 from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                              InvocationConfig, InvocationContext)
 
@@ -14,6 +13,147 @@ from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
 from .compel import ConditioningField
 from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
 
+class SDXLModelLoaderOutput(BaseInvocationOutput):
+    """SDXL base model loader output"""
+
+    # fmt: off
+    type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
+
+    unet: UNetField = Field(default=None, description="UNet submodel")
+    clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
+    clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
+    vae: VaeField = Field(default=None, description="Vae submodel")
+    # fmt: on
+
+class SDXLRefinerModelLoaderOutput(SDXLModelLoaderOutput):
+    """SDXL refiner model loader output"""
+    # fmt: off
+    type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
+    # fmt: on
+
+
+class SDXLModelLoaderInvocation(BaseInvocation):
+    """Loads an sdxl base model, outputting its submodels."""
+
+    # Keep the default equal to the Literal value, as every other class in this patch does.
+    type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
+
+    model: MainModelField = Field(description="The model to load")
+    # TODO: precision?
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "SDXL Model Loader",
+                "tags": ["model", "loader", "sdxl"],
+                "type_hints": {"model": "model"},
+            },
+        }
+
+    @classmethod
+    def _output_class(cls):
+        """Return the output class that invoke() instantiates; subclasses override this."""
+        return SDXLModelLoaderOutput
+
+    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
+        """Return ModelInfo references for the UNet, scheduler, both tokenizer/text-encoder pairs and the VAE.
+
+        Raises:
+            Exception: if the model manager reports that the model does not exist.
+        """
+        base_model = self.model.base_model
+        model_name = self.model.model_name
+        model_type = ModelType.Main
+
+        # TODO: not found exceptions
+        if not context.services.model_manager.model_exists(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+        ):
+            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
+
+        # _output_class is a classmethod that returns a class: call it, then instantiate the result.
+        return self._output_class()(
+            unet=UNetField(
+                unet=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.UNet,
+                ),
+                scheduler=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.Scheduler,
+                ),
+                loras=[],
+            ),
+            clip=ClipField(
+                tokenizer=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.Tokenizer,
+                ),
+                text_encoder=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.TextEncoder,
+                ),
+                loras=[],
+                skipped_layers=0,
+            ),
+            clip2=ClipField(
+                tokenizer=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.Tokenizer2,
+                ),
+                text_encoder=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.TextEncoder2,
+                ),
+                loras=[],
+                skipped_layers=0,
+            ),
+            vae=VaeField(
+                vae=ModelInfo(
+                    model_name=model_name,
+                    base_model=base_model,
+                    model_type=model_type,
+                    submodel=SubModelType.Vae,
+                ),
+            ),
+        )
+
+class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
+    """Loads an sdxl refiner model, outputting its submodels."""
+    # NOTE(review): invoke() is inherited as-is, so refiner loads also emit `clip`
+    # (Tokenizer / TextEncoder) -- confirm refiner models actually ship those submodels.
+    type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "title": "SDXL Refiner Model Loader",
+                "tags": ["model", "loader", "sdxl_refiner"],
+                "type_hints": {"model": "model"},
+            },
+        }
+
+    @classmethod
+    def _output_class(cls):
+        """Return the refiner-specific output class."""
+        return SDXLRefinerModelLoaderOutput
+
 # Text to image
 class SDXLTextToLatentsInvocation(BaseInvocation):
     """Generates latents from conditionings."""