From 130249a2dd1fe306b07252cd21fd2c05a0ef2b93 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 9 Jul 2023 15:47:06 -0400 Subject: [PATCH] add model loading support for SDXL --- .../backend/model_management/model_probe.py | 14 ++- .../model_management/models/__init__.py | 10 +- .../backend/model_management/models/base.py | 3 + .../models/stable_diffusion.py | 108 +++++++++++++++++- invokeai/configs/INITIAL_MODELS.yaml | 6 + 5 files changed, 136 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index 44d2682911..07f01188ac 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -38,6 +38,8 @@ class ModelProbe(object): CLASS2TYPE = { 'StableDiffusionPipeline' : ModelType.Main, + 'StableDiffusionXLPipeline' : ModelType.Main, + 'StableDiffusionXLImg2ImgPipeline' : ModelType.Main, 'AutoencoderKL' : ModelType.Vae, 'ControlNetModel' : ModelType.ControlNet, } @@ -99,9 +101,10 @@ class ModelProbe(object): upcast_attention = (base_type==BaseModelType.StableDiffusion2 \ and prediction_type==SchedulerPredictionType.VPrediction), format = format, - image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \ - and prediction_type==SchedulerPredictionType.VPrediction \ - ) else 512, + image_size = 1024 if (base_type==BaseModelType.StableDiffusionXL) else \ + 768 if (base_type==BaseModelType.StableDiffusion2 \ + and prediction_type==SchedulerPredictionType.VPrediction ) else \ + 512 ) except Exception: raise @@ -248,6 +251,9 @@ class PipelineCheckpointProbe(CheckpointProbeBase): return BaseModelType.StableDiffusion1 if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: return BaseModelType.StableDiffusion2 + # TODO: Verify that this is correct! Need an XL checkpoint file for this. 
+ if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: + return BaseModelType.StableDiffusionXL raise Exception("Cannot determine base type") def get_scheduler_prediction_type(self)->SchedulerPredictionType: @@ -360,6 +366,8 @@ class PipelineFolderProbe(FolderProbeBase): return BaseModelType.StableDiffusion1 elif unet_conf['cross_attention_dim'] == 1024: return BaseModelType.StableDiffusion2 + elif unet_conf['cross_attention_dim'] in {1280,2048}: + return BaseModelType.StableDiffusionXL else: raise ValueError(f'Unknown base model for {self.folder_path}') diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 1b381cd2a8..aa94d640f4 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -3,7 +3,7 @@ from enum import Enum from pydantic import BaseModel from typing import Literal, get_origin from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException -from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model +from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model, StableDiffusionXLModel from .vae import VaeModel from .lora import LoRAModel from .controlnet import ControlNetModel # TODO: @@ -24,6 +24,14 @@ MODEL_CLASSES = { ModelType.ControlNet: ControlNetModel, ModelType.TextualInversion: TextualInversionModel, }, + BaseModelType.StableDiffusionXL: { + ModelType.Main: StableDiffusionXLModel, + ModelType.Vae: VaeModel, + # will not work until support written + ModelType.Lora: LoRAModel, + ModelType.ControlNet: ControlNetModel, + ModelType.TextualInversion: TextualInversionModel, + }, #BaseModelType.Kandinsky2_1: { # ModelType.Main: Kandinsky2_1Model, # ModelType.MoVQ: MoVQModel, diff --git a/invokeai/backend/model_management/models/base.py 
b/invokeai/backend/model_management/models/base.py index 57c02bce76..73cbb8eb3e 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -21,6 +21,7 @@ class ModelNotFoundException(Exception): class BaseModelType(str, Enum): StableDiffusion1 = "sd-1" StableDiffusion2 = "sd-2" + StableDiffusionXL = "sdxl" #Kandinsky2_1 = "kandinsky-2.1" class ModelType(str, Enum): @@ -33,7 +34,9 @@ class ModelType(str, Enum): class SubModelType(str, Enum): UNet = "unet" TextEncoder = "text_encoder" + TextEncoder2 = "text_encoder_2" Tokenizer = "tokenizer" + Tokenizer2 = "tokenizer_2" Vae = "vae" Scheduler = "scheduler" SafetyChecker = "safety_checker" diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index c98d5a0ae8..c0b43d6774 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -222,6 +222,105 @@ class StableDiffusion2Model(DiffusersModel): else: return model_path +class StableDiffusionXLModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" + +class StableDiffusionXLModel(DiffusersModel): + + # TODO: check that configs are overwritten properly + class DiffusersConfig(ModelConfigBase): + model_format: Literal[StableDiffusionXLModelFormat.Diffusers] + vae: Optional[str] = Field(None) + variant: ModelVariantType + + class CheckpointConfig(ModelConfigBase): + model_format: Literal[StableDiffusionXLModelFormat.Checkpoint] + vae: Optional[str] = Field(None) + config: str + variant: ModelVariantType + + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): + assert base_model == BaseModelType.StableDiffusionXL + assert model_type == ModelType.Main + super().__init__( + model_path=model_path, + base_model=BaseModelType.StableDiffusionXL, + model_type=ModelType.Main, + ) + + @classmethod + def 
probe_config(cls, path: str, **kwargs): + model_format = cls.detect_format(path) + ckpt_config_path = kwargs.get("config", None) + if model_format == StableDiffusionXLModelFormat.Checkpoint: + if ckpt_config_path: + ckpt_config = OmegaConf.load(ckpt_config_path) + in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] + + else: + checkpoint = read_checkpoint_meta(path) + checkpoint = checkpoint.get('state_dict', checkpoint) + in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] + + elif model_format == StableDiffusionXLModelFormat.Diffusers: + unet_config_path = os.path.join(path, "unet", "config.json") + if os.path.exists(unet_config_path): + with open(unet_config_path, "r") as f: + unet_config = json.loads(f.read()) + in_channels = unet_config['in_channels'] + + else: + raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") + + else: + raise NotImplementedError(f"Unknown stable diffusion XL format: {model_format}") + + if in_channels == 9: + variant = ModelVariantType.Inpaint + elif in_channels == 5: + variant = ModelVariantType.Depth + elif in_channels == 4: + variant = ModelVariantType.Normal + else: + raise Exception("Unknown stable diffusion XL model format") + + if ckpt_config_path is None: + ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusionXL, variant) + + return cls.create_config( + path=path, + model_format=model_format, + + config=ckpt_config_path, + variant=variant, + ) + + @classproperty + def save_to_config(cls) -> bool: + return True + + @classmethod + def detect_format(cls, model_path: str): + if os.path.isdir(model_path): + return StableDiffusionXLModelFormat.Diffusers + else: + return StableDiffusionXLModelFormat.Checkpoint + + @classmethod + def convert_if_required( + cls, + model_path: str, + output_path: str, + config: ModelConfigBase, + base_model: BaseModelType, + ) -> str: + if isinstance(config, cls.CheckpointConfig): + raise 
NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported') + else: + return model_path + + def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): ckpt_configs = { BaseModelType.StableDiffusion1: { @@ -232,6 +331,12 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512) ModelVariantType.Inpaint: "v2-inpainting-inference.yaml", ModelVariantType.Depth: "v2-midas-inference.yaml", + }, + # note that these .yaml files don't yet exist! + BaseModelType.StableDiffusionXL: { + ModelVariantType.Normal: "xl-inference-v.yaml", + ModelVariantType.Inpaint: "xl-inpainting-inference.yaml", + ModelVariantType.Depth: "xl-midas-inference.yaml", + } } @@ -247,9 +352,10 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): # TODO: rework +# Note that convert_ckpt_to_diffusers does not currently support conversion of SDXL models def _convert_ckpt_and_cache( version: BaseModelType, - model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig], + model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig, StableDiffusionXLModel.CheckpointConfig], output_path: str, ) -> str: """ diff --git a/invokeai/configs/INITIAL_MODELS.yaml b/invokeai/configs/INITIAL_MODELS.yaml index 4ba67bc4bc..f7f6ce7b17 100644 --- a/invokeai/configs/INITIAL_MODELS.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml @@ -1,4 +1,10 @@ # This file predefines a few models that the user may want to install. +sdxl/main/stable-diffusion-xl-base: + description: Stable Diffusion XL base model - NOT YET RELEASED!! (70 GB) + repo_id: stabilityai/stable-diffusion-xl-base +sdxl/main/stable-diffusion-xl-refiner: + description: Stable Diffusion XL refiner model - NOT YET RELEASED!! 
(60 GB) + repo_id: stabilityai/stable-diffusion-xl-refiner sd-1/main/stable-diffusion-v1-5: description: Stable Diffusion version 1.5 diffusers model (4.27 GB) repo_id: runwayml/stable-diffusion-v1-5